Refactor call graph construction to CallGraphBuilder
Change-Id: I2dec2b5cbc8f60c475e1e663d7a3034e11d1de45
diff --git a/src/main/java/com/android/tools/r8/graph/DexClass.java b/src/main/java/com/android/tools/r8/graph/DexClass.java
index 16cd797..692acae 100644
--- a/src/main/java/com/android/tools/r8/graph/DexClass.java
+++ b/src/main/java/com/android/tools/r8/graph/DexClass.java
@@ -25,7 +25,6 @@
import java.util.function.Predicate;
public abstract class DexClass extends DexDefinition {
- public static final DexClass[] EMPTY_ARRAY = {};
public interface FieldSetter {
void setField(int index, DexEncodedField field);
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/CallGraph.java b/src/main/java/com/android/tools/r8/ir/conversion/CallGraph.java
index 526ce19..767f0b2 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/CallGraph.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/CallGraph.java
@@ -4,41 +4,22 @@
package com.android.tools.r8.ir.conversion;
-import com.android.tools.r8.errors.CompilationError;
import com.android.tools.r8.graph.AppView;
-import com.android.tools.r8.graph.DexApplication;
-import com.android.tools.r8.graph.DexClass;
-import com.android.tools.r8.graph.DexEncodedField;
import com.android.tools.r8.graph.DexEncodedMethod;
-import com.android.tools.r8.graph.DexField;
import com.android.tools.r8.graph.DexMethod;
-import com.android.tools.r8.graph.DexProgramClass;
-import com.android.tools.r8.graph.DexType;
-import com.android.tools.r8.graph.GraphLense;
-import com.android.tools.r8.graph.GraphLense.GraphLenseLookupResult;
-import com.android.tools.r8.graph.UseRegistry;
-import com.android.tools.r8.ir.code.Invoke.Type;
-import com.android.tools.r8.logging.Log;
+import com.android.tools.r8.ir.conversion.CallGraphBuilder.CycleEliminator;
import com.android.tools.r8.shaking.AppInfoWithLiveness;
import com.android.tools.r8.utils.Action;
import com.android.tools.r8.utils.IROrdering;
-import com.android.tools.r8.utils.InternalOptions;
import com.android.tools.r8.utils.ThreadUtils;
import com.android.tools.r8.utils.ThrowingBiConsumer;
-import com.android.tools.r8.utils.Timing;
import com.google.common.collect.Sets;
-import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
-import java.util.Deque;
-import java.util.Iterator;
-import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
-import java.util.LinkedList;
import java.util.List;
-import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
@@ -63,15 +44,12 @@
*/
public class CallGraph extends CallSiteInformation {
- private CallGraph(InternalOptions options) {
- this.shuffle = options.testing.irOrdering;
- }
+ public static class Node implements Comparable<Node> {
- public static class Node {
+ public static Node[] EMPTY_ARRAY = {};
public final DexEncodedMethod method;
- private int invokeCount = 0;
- private boolean isSelfRecursive = false;
+ private int numberOfCallSites = 0;
// Outgoing calls from this method.
private final Set<Node> callees = new LinkedHashSet<>();
@@ -87,17 +65,29 @@
return method.accessFlags.isBridge();
}
- public void addCallee(Node method) {
- callees.add(method);
- method.callers.add(this);
+ public void addCaller(Node caller) {
+ callers.add(caller);
+ caller.callees.add(this);
+ numberOfCallSites++;
+ }
+
+ public void removeCaller(Node caller) {
+ callers.remove(caller);
+ caller.callees.remove(this);
+ }
+
+ public Node[] getCalleesWithDeterministicOrder() {
+ Node[] sorted = callees.toArray(Node.EMPTY_ARRAY);
+ Arrays.sort(sorted);
+ return sorted;
}
public boolean hasCallee(Node method) {
return callees.contains(method);
}
- boolean isSelfRecursive() {
- return isSelfRecursive;
+ public boolean hasCaller(Node method) {
+ return callers.contains(method);
}
public boolean isLeaf() {
@@ -105,6 +95,11 @@
}
@Override
+ public int compareTo(Node other) {
+ return method.method.slowCompareTo(other.method.method);
+ }
+
+ @Override
public String toString() {
StringBuilder builder = new StringBuilder();
builder.append("MethodNode for: ");
@@ -117,10 +112,7 @@
if (isBridge()) {
builder.append(", bridge");
}
- if (isSelfRecursive()) {
- builder.append(", recursive");
- }
- builder.append(", invoke count ").append(invokeCount);
+ builder.append(", invoke count ").append(numberOfCallSites);
builder.append(").\n");
if (callees.size() > 0) {
builder.append("Callees:\n");
@@ -142,37 +134,30 @@
}
}
- private final Map<DexEncodedMethod, Node> nodes = new LinkedHashMap<>();
+ private final Set<Node> nodes;
private final IROrdering shuffle;
private final Set<DexMethod> singleCallSite = Sets.newIdentityHashSet();
private final Set<DexMethod> doubleCallSite = Sets.newIdentityHashSet();
- public static CallGraph build(
- DexApplication application,
- AppView<AppInfoWithLiveness> appView,
- InternalOptions options,
- Timing timing) {
- CallGraph graph = new CallGraph(options);
- DexClass[] classes = application.classes().toArray(DexClass.EMPTY_ARRAY);
- Arrays.sort(classes, (DexClass a, DexClass b) -> a.type.slowCompareTo(b.type));
- for (DexClass clazz : classes) {
- for (DexEncodedMethod method : clazz.allMethodsSorted()) {
- Node node = graph.ensureMethodNode(method);
- InvokeExtractor extractor = new InvokeExtractor(appView, node, graph);
- method.registerCodeReferences(extractor);
+ CallGraph(AppView<AppInfoWithLiveness> appView, Set<Node> nodes) {
+ this.nodes = nodes;
+ this.shuffle = appView.options().testing.irOrdering;
+
+ for (Node node : nodes) {
+ // For non-pinned methods we know the exact number of call sites.
+ if (!appView.appInfo().isPinned(node.method.method)) {
+ if (node.numberOfCallSites == 1) {
+ singleCallSite.add(node.method.method);
+ } else if (node.numberOfCallSites == 2) {
+ doubleCallSite.add(node.method.method);
+ }
}
}
- assert verifyAllMethodsWithCodeExists(application, graph);
+ }
- timing.begin("Cycle elimination");
- CycleEliminator cycleEliminator = new CycleEliminator(graph.nodes.values(), options);
- cycleEliminator.breakCycles();
- timing.end();
- assert cycleEliminator.breakCycles() == 0; // This time the cycles should be gone.
-
- graph.fillCallSiteSets(appView.appInfo());
- return graph;
+ public static CallGraphBuilder builder(AppView<AppInfoWithLiveness> appView) {
+ return new CallGraphBuilder(appView);
}
/**
@@ -191,30 +176,6 @@
return doubleCallSite.contains(method);
}
- private void fillCallSiteSets(AppInfoWithLiveness appInfo) {
- assert singleCallSite.isEmpty();
- for (Node value : nodes.values()) {
- // For non-pinned methods we know the exact number of call sites.
- if (!appInfo.isPinned(value.method.method)) {
- if (value.invokeCount == 1) {
- singleCallSite.add(value.method.method);
- } else if (value.invokeCount == 2) {
- doubleCallSite.add(value.method.method);
- }
- }
- }
- }
-
- private static boolean verifyAllMethodsWithCodeExists(
- DexApplication application, CallGraph graph) {
- for (DexProgramClass clazz : application.classes()) {
- for (DexEncodedMethod method : clazz.methods()) {
- assert !method.hasCode() || graph.nodes.get(method) != null;
- }
- }
- return true;
- }
-
/**
* Extract the next set of leaves (nodes with an call (outgoing) degree of 0) if any.
*
@@ -228,226 +189,18 @@
return Collections.emptySet();
}
// First identify all leaves before removing them from the graph.
- List<Node> leaves = nodes.values().stream().filter(Node::isLeaf).collect(Collectors.toList());
- leaves.forEach(leaf -> {
+ List<Node> leaves = nodes.stream().filter(Node::isLeaf).collect(Collectors.toList());
+ for (Node leaf : leaves) {
leaf.callers.forEach(caller -> caller.callees.remove(leaf));
- nodes.remove(leaf.method);
- });
+ nodes.remove(leaf);
+ }
Set<DexEncodedMethod> methods =
leaves.stream().map(x -> x.method).collect(Collectors.toCollection(LinkedHashSet::new));
return shuffle.order(methods);
}
- public static class CycleEliminator {
-
- public static final String CYCLIC_FORCE_INLINING_MESSAGE =
- "Unable to satisfy force inlining constraints due to cyclic force inlining";
-
- private static class CallEdge {
-
- private final Node caller;
- private final Node callee;
-
- public CallEdge(Node caller, Node callee) {
- this.caller = caller;
- this.callee = callee;
- }
- }
-
- private final Collection<Node> nodes;
- private final InternalOptions options;
-
- // DFS stack.
- private Deque<Node> stack = new ArrayDeque<>();
-
- // Set of nodes on the DFS stack.
- private Set<Node> stackSet = Sets.newIdentityHashSet();
-
- // Set of nodes that have been visited entirely.
- private Set<Node> marked = Sets.newIdentityHashSet();
-
- private int numberOfCycles = 0;
-
- public CycleEliminator(Collection<Node> nodes, InternalOptions options) {
- this.options = options;
-
- // Call to reorderNodes must happen after assigning options.
- this.nodes =
- options.testing.nondeterministicCycleElimination
- ? reorderNodes(new ArrayList<>(nodes))
- : nodes;
- }
-
- public int breakCycles() {
- // Break cycles in this call graph by removing edges causing cycles.
- for (Node node : nodes) {
- traverse(node);
- }
- int result = numberOfCycles;
- reset();
- return result;
- }
-
- private void reset() {
- assert stack.isEmpty();
- assert stackSet.isEmpty();
- marked.clear();
- numberOfCycles = 0;
- }
-
- private void traverse(Node node) {
- if (marked.contains(node)) {
- // Already visited all nodes that can be reached from this node.
- return;
- }
-
- push(node);
-
- // Sort the callees before calling traverse recursively. This will ensure cycles are broken
- // the same way across multiple invocations of the R8 compiler.
- Node[] callees = node.callees.toArray(new Node[]{});
- Arrays.sort(callees, (Node a, Node b) -> a.method.method.slowCompareTo(b.method.method));
- if (options.testing.nondeterministicCycleElimination) {
- reorderNodes(Arrays.asList(callees));
- }
-
- for (Node callee : callees) {
- if (stackSet.contains(callee)) {
- // Found a cycle that needs to be eliminated.
- numberOfCycles++;
-
- if (edgeRemovalIsSafe(node, callee)) {
- // Break the cycle by removing the edge node->callee.
- callee.callers.remove(node);
- node.callees.remove(callee);
-
- if (Log.ENABLED) {
- Log.info(
- CallGraph.class,
- "Removed call edge from method '%s' to '%s'",
- node.method.toSourceString(),
- callee.method.toSourceString());
- }
- } else {
- // The cycle has a method that is marked as force inline.
- LinkedList<Node> cycle = extractCycle(callee);
-
- if (Log.ENABLED) {
- Log.info(
- CallGraph.class, "Extracted cycle to find an edge that can safely be removed");
- }
-
- // Break the cycle by finding an edge that can be removed without breaking force
- // inlining. If that is not possible, this call fails with a compilation error.
- CallEdge edge = findCallEdgeForRemoval(cycle);
-
- // The edge will be null if this cycle has already been eliminated as a result of
- // another cycle elimination.
- if (edge != null) {
- assert edgeRemovalIsSafe(edge.caller, edge.callee);
-
- // Break the cycle by removing the edge caller->callee.
- edge.caller.callees.remove(edge.callee);
- edge.callee.callers.remove(edge.caller);
-
- if (Log.ENABLED) {
- Log.info(
- CallGraph.class,
- "Removed call edge from force inlined method '%s' to '%s' to ensure that "
- + "force inlining will succeed",
- node.method.toSourceString(),
- callee.method.toSourceString());
- }
- }
-
- // Recover the stack.
- recoverStack(cycle);
- }
- } else {
- traverse(callee);
- }
- }
- pop(node);
- marked.add(node);
- }
-
- private void push(Node node) {
- stack.push(node);
- boolean changed = stackSet.add(node);
- assert changed;
- }
-
- private void pop(Node node) {
- Node popped = stack.pop();
- assert popped == node;
- boolean changed = stackSet.remove(node);
- assert changed;
- }
-
- private LinkedList<Node> extractCycle(Node entry) {
- LinkedList<Node> cycle = new LinkedList<>();
- do {
- assert !stack.isEmpty();
- cycle.add(stack.pop());
- } while (cycle.getLast() != entry);
- return cycle;
- }
-
- private CallEdge findCallEdgeForRemoval(LinkedList<Node> extractedCycle) {
- Node callee = extractedCycle.getLast();
- for (Node caller : extractedCycle) {
- if (!caller.callees.contains(callee)) {
- // No need to break any edges since this cycle has already been broken previously.
- assert !callee.callers.contains(caller);
- return null;
- }
- if (edgeRemovalIsSafe(caller, callee)) {
- return new CallEdge(caller, callee);
- }
- callee = caller;
- }
- throw new CompilationError(CYCLIC_FORCE_INLINING_MESSAGE);
- }
-
- private static boolean edgeRemovalIsSafe(Node caller, Node callee) {
- // All call edges where the callee is a method that should be force inlined must be kept,
- // to guarantee that the IR converter will process the callee before the caller.
- return !callee.method.getOptimizationInfo().forceInline();
- }
-
- private void recoverStack(LinkedList<Node> extractedCycle) {
- Iterator<Node> descendingIt = extractedCycle.descendingIterator();
- while (descendingIt.hasNext()) {
- stack.push(descendingIt.next());
- }
- }
-
- private Collection<Node> reorderNodes(List<Node> nodes) {
- assert options.testing.nondeterministicCycleElimination;
- if (!InternalOptions.DETERMINISTIC_DEBUGGING) {
- Collections.shuffle(nodes);
- }
- return nodes;
- }
- }
-
- synchronized private Node ensureMethodNode(DexEncodedMethod method) {
- return nodes.computeIfAbsent(method, k -> new Node(method));
- }
-
- synchronized private void addCall(Node caller, Node callee) {
- assert caller != null;
- assert callee != null;
- if (caller != callee) {
- caller.addCallee(callee);
- } else {
- caller.isSelfRecursive = true;
- }
- callee.invokeCount++;
- }
-
public boolean isEmpty() {
- return nodes.size() == 0;
+ return nodes.isEmpty();
}
/**
@@ -477,171 +230,4 @@
waveDone.execute();
}
}
-
- public void dump() {
- nodes.forEach((m, n) -> System.out.println(n + "\n"));
- }
-
- private static class InvokeExtractor extends UseRegistry {
-
- private final AppInfoWithLiveness appInfo;
- private final GraphLense graphLense;
- private final Node caller;
- private final CallGraph graph;
-
- InvokeExtractor(AppView<AppInfoWithLiveness> appView, Node caller, CallGraph graph) {
- super(appView.dexItemFactory());
- this.appInfo = appView.appInfo();
- this.graphLense = appView.graphLense();
- this.caller = caller;
- this.graph = graph;
- }
-
- private void addClassInitializerTarget(DexClass clazz) {
- assert clazz != null;
- if (clazz.hasClassInitializer() && clazz.isProgramClass()) {
- addTarget(clazz.getClassInitializer());
- }
- }
-
- private void addClassInitializerTarget(DexType type) {
- assert type.isClassType();
- DexClass clazz = appInfo.definitionFor(type);
- if (clazz != null) {
- addClassInitializerTarget(clazz);
- }
- }
-
- private void addTarget(DexEncodedMethod target) {
- if (!target.accessFlags.isAbstract()) {
- Node callee = graph.ensureMethodNode(target);
- graph.addCall(caller, callee);
- }
- }
-
- private void addPossibleTarget(DexEncodedMethod possibleTarget) {
- DexClass possibleTargetClass =
- appInfo.definitionFor(possibleTarget.method.holder);
- if (possibleTargetClass != null && possibleTargetClass.isProgramClass()) {
- addTarget(possibleTarget);
- }
- }
-
- private void addPossibleTargets(
- DexEncodedMethod definition, Set<DexEncodedMethod> possibleTargets) {
- for (DexEncodedMethod possibleTarget : possibleTargets) {
- if (possibleTarget != definition) {
- addPossibleTarget(possibleTarget);
- }
- }
- }
-
- private void processInvoke(Type type, DexMethod method) {
- DexEncodedMethod source = caller.method;
- GraphLenseLookupResult result = graphLense.lookupMethod(method, source.method, type);
- method = result.getMethod();
- type = result.getType();
- DexEncodedMethod definition = appInfo.lookupSingleTarget(type, method, source.method.holder);
- if (definition != null) {
- assert !source.accessFlags.isBridge() || definition != caller.method;
- DexClass clazz = appInfo.definitionFor(definition.method.holder);
- assert clazz != null;
- if (clazz.isProgramClass()) {
- // For static invokes, the class could be initialized.
- if (type == Type.STATIC) {
- addClassInitializerTarget(clazz);
- }
-
- addTarget(definition);
- // For virtual and interface calls add all potential targets that could be called.
- if (type == Type.VIRTUAL || type == Type.INTERFACE) {
- Set<DexEncodedMethod> possibleTargets;
- if (clazz.isInterface()) {
- possibleTargets = appInfo.lookupInterfaceTargets(definition.method);
- } else {
- possibleTargets = appInfo.lookupVirtualTargets(definition.method);
- }
- addPossibleTargets(definition, possibleTargets);
- }
- }
- }
- }
-
- private void processFieldAccess(DexField field) {
- // Any field access implicitly calls the class initializer.
- if (field.holder.isClassType()) {
- DexEncodedField encodedField = appInfo.resolveField(field);
- if (encodedField != null && encodedField.isStatic()) {
- addClassInitializerTarget(field.holder);
- }
- }
- }
-
- @Override
- public boolean registerInvokeVirtual(DexMethod method) {
- processInvoke(Type.VIRTUAL, method);
- return false;
- }
-
- @Override
- public boolean registerInvokeDirect(DexMethod method) {
- processInvoke(Type.DIRECT, method);
- return false;
- }
-
- @Override
- public boolean registerInvokeStatic(DexMethod method) {
- processInvoke(Type.STATIC, method);
- return false;
- }
-
- @Override
- public boolean registerInvokeInterface(DexMethod method) {
- processInvoke(Type.INTERFACE, method);
- return false;
- }
-
- @Override
- public boolean registerInvokeSuper(DexMethod method) {
- processInvoke(Type.SUPER, method);
- return false;
- }
-
- @Override
- public boolean registerInstanceFieldWrite(DexField field) {
- processFieldAccess(field);
- return false;
- }
-
- @Override
- public boolean registerInstanceFieldRead(DexField field) {
- processFieldAccess(field);
- return false;
- }
-
- @Override
- public boolean registerNewInstance(DexType type) {
- if (type.isClassType()) {
- addClassInitializerTarget(type);
- }
- return false;
- }
-
- @Override
- public boolean registerStaticFieldRead(DexField field) {
- processFieldAccess(field);
- return false;
- }
-
- @Override
- public boolean registerStaticFieldWrite(DexField field) {
- processFieldAccess(field);
- return false;
- }
-
- @Override
- public boolean registerTypeReference(DexType type) {
- return false;
- }
- }
}
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/CallGraphBuilder.java b/src/main/java/com/android/tools/r8/ir/conversion/CallGraphBuilder.java
new file mode 100644
index 0000000..a19bf3c
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/conversion/CallGraphBuilder.java
@@ -0,0 +1,439 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.conversion;
+
+import com.android.tools.r8.errors.CompilationError;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexEncodedField;
+import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.GraphLense.GraphLenseLookupResult;
+import com.android.tools.r8.graph.UseRegistry;
+import com.android.tools.r8.ir.code.Invoke;
+import com.android.tools.r8.ir.conversion.CallGraph.Node;
+import com.android.tools.r8.logging.Log;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.InternalOptions;
+import com.android.tools.r8.utils.Timing;
+import com.google.common.collect.Sets;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.IdentityHashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+public class CallGraphBuilder {
+
+ private final AppView<AppInfoWithLiveness> appView;
+ private final Map<DexMethod, Node> nodes = new IdentityHashMap<>();
+
+ CallGraphBuilder(AppView<AppInfoWithLiveness> appView) {
+ this.appView = appView;
+ }
+
+ public CallGraph build(Timing timing) {
+ for (DexProgramClass clazz : appView.appInfo().classesWithDeterministicOrder()) {
+ processClass(clazz);
+ }
+
+ assert verifyAllMethodsWithCodeExists();
+
+ timing.begin("Cycle elimination");
+ CycleEliminator cycleEliminator = new CycleEliminator(nodes.values(), appView.options());
+ cycleEliminator.breakCycles();
+ timing.end();
+ assert cycleEliminator.breakCycles() == 0; // This time the cycles should be gone.
+
+ return new CallGraph(appView, Sets.newHashSet(nodes.values()));
+ }
+
+ private void processClass(DexProgramClass clazz) {
+ for (DexEncodedMethod method : clazz.allMethodsSorted()) {
+ processMethod(method);
+ }
+ }
+
+ private void processMethod(DexEncodedMethod method) {
+ if (method.hasCode()) {
+ method.registerCodeReferences(new InvokeExtractor(getOrCreateNode(method)));
+ }
+ }
+
+ private Node getOrCreateNode(DexEncodedMethod method) {
+ return nodes.computeIfAbsent(method.method, ignore -> new Node(method));
+ }
+
+ private boolean verifyAllMethodsWithCodeExists() {
+ for (DexProgramClass clazz : appView.appInfo().classes()) {
+ for (DexEncodedMethod method : clazz.methods()) {
+ assert !method.hasCode() || nodes.get(method.method) != null;
+ }
+ }
+ return true;
+ }
+
+ private class InvokeExtractor extends UseRegistry {
+
+ private final Node caller;
+
+ InvokeExtractor(Node caller) {
+ super(appView.dexItemFactory());
+ this.caller = caller;
+ }
+
+ private void addClassInitializerTarget(DexClass clazz) {
+ assert clazz != null;
+ if (clazz.hasClassInitializer() && clazz.isProgramClass()) {
+ addTarget(clazz.getClassInitializer());
+ }
+ }
+
+ private void addClassInitializerTarget(DexType type) {
+ assert type.isClassType();
+ DexClass clazz = appView.definitionFor(type);
+ if (clazz != null) {
+ addClassInitializerTarget(clazz);
+ }
+ }
+
+ private void addTarget(DexEncodedMethod callee) {
+ if (!callee.accessFlags.isAbstract()) {
+ getOrCreateNode(callee).addCaller(caller);
+ }
+ }
+
+ private void addPossibleTarget(DexEncodedMethod possibleTarget) {
+ DexClass possibleTargetClass = appView.definitionFor(possibleTarget.method.holder);
+ if (possibleTargetClass != null && possibleTargetClass.isProgramClass()) {
+ addTarget(possibleTarget);
+ }
+ }
+
+ private void addPossibleTargets(
+ DexEncodedMethod definition, Set<DexEncodedMethod> possibleTargets) {
+ for (DexEncodedMethod possibleTarget : possibleTargets) {
+ if (possibleTarget != definition) {
+ addPossibleTarget(possibleTarget);
+ }
+ }
+ }
+
+ private void processInvoke(Invoke.Type type, DexMethod method) {
+ DexEncodedMethod source = caller.method;
+ GraphLenseLookupResult result =
+ appView.graphLense().lookupMethod(method, source.method, type);
+ method = result.getMethod();
+ type = result.getType();
+ DexEncodedMethod definition =
+ appView.appInfo().lookupSingleTarget(type, method, source.method.holder);
+ if (definition != null) {
+ assert !source.accessFlags.isBridge() || definition != caller.method;
+ DexClass clazz = appView.definitionFor(definition.method.holder);
+ assert clazz != null;
+ if (clazz.isProgramClass()) {
+ // For static invokes, the class could be initialized.
+ if (type == Invoke.Type.STATIC) {
+ addClassInitializerTarget(clazz);
+ }
+
+ addTarget(definition);
+ // For virtual and interface calls add all potential targets that could be called.
+ if (type == Invoke.Type.VIRTUAL || type == Invoke.Type.INTERFACE) {
+ Set<DexEncodedMethod> possibleTargets;
+ if (clazz.isInterface()) {
+ possibleTargets = appView.appInfo().lookupInterfaceTargets(definition.method);
+ } else {
+ possibleTargets = appView.appInfo().lookupVirtualTargets(definition.method);
+ }
+ addPossibleTargets(definition, possibleTargets);
+ }
+ }
+ }
+ }
+
+ private void processFieldAccess(DexField field) {
+ // Any field access implicitly calls the class initializer.
+ if (field.holder.isClassType()) {
+ DexEncodedField encodedField = appView.appInfo().resolveField(field);
+ if (encodedField != null && encodedField.isStatic()) {
+ addClassInitializerTarget(field.holder);
+ }
+ }
+ }
+
+ @Override
+ public boolean registerInvokeVirtual(DexMethod method) {
+ processInvoke(Invoke.Type.VIRTUAL, method);
+ return false;
+ }
+
+ @Override
+ public boolean registerInvokeDirect(DexMethod method) {
+ processInvoke(Invoke.Type.DIRECT, method);
+ return false;
+ }
+
+ @Override
+ public boolean registerInvokeStatic(DexMethod method) {
+ processInvoke(Invoke.Type.STATIC, method);
+ return false;
+ }
+
+ @Override
+ public boolean registerInvokeInterface(DexMethod method) {
+ processInvoke(Invoke.Type.INTERFACE, method);
+ return false;
+ }
+
+ @Override
+ public boolean registerInvokeSuper(DexMethod method) {
+ processInvoke(Invoke.Type.SUPER, method);
+ return false;
+ }
+
+ @Override
+ public boolean registerInstanceFieldWrite(DexField field) {
+ processFieldAccess(field);
+ return false;
+ }
+
+ @Override
+ public boolean registerInstanceFieldRead(DexField field) {
+ processFieldAccess(field);
+ return false;
+ }
+
+ @Override
+ public boolean registerNewInstance(DexType type) {
+ if (type.isClassType()) {
+ addClassInitializerTarget(type);
+ }
+ return false;
+ }
+
+ @Override
+ public boolean registerStaticFieldRead(DexField field) {
+ processFieldAccess(field);
+ return false;
+ }
+
+ @Override
+ public boolean registerStaticFieldWrite(DexField field) {
+ processFieldAccess(field);
+ return false;
+ }
+
+ @Override
+ public boolean registerTypeReference(DexType type) {
+ return false;
+ }
+ }
+
+ public static class CycleEliminator {
+
+ public static final String CYCLIC_FORCE_INLINING_MESSAGE =
+ "Unable to satisfy force inlining constraints due to cyclic force inlining";
+
+ private static class CallEdge {
+
+ private final Node caller;
+ private final Node callee;
+
+ public CallEdge(Node caller, Node callee) {
+ this.caller = caller;
+ this.callee = callee;
+ }
+
+ public void remove() {
+ callee.removeCaller(caller);
+ }
+ }
+
+ private final Collection<Node> nodes;
+ private final InternalOptions options;
+
+ // DFS stack.
+ private Deque<Node> stack = new ArrayDeque<>();
+
+ // Set of nodes on the DFS stack.
+ private Set<Node> stackSet = Sets.newIdentityHashSet();
+
+ // Set of nodes that have been visited entirely.
+ private Set<Node> marked = Sets.newIdentityHashSet();
+
+ private int numberOfCycles = 0;
+
+ public CycleEliminator(Collection<Node> nodes, InternalOptions options) {
+ this.options = options;
+
+ // Call to reorderNodes must happen after assigning options.
+ this.nodes =
+ options.testing.nondeterministicCycleElimination
+ ? reorderNodes(new ArrayList<>(nodes))
+ : nodes;
+ }
+
+ public int breakCycles() {
+ // Break cycles in this call graph by removing edges causing cycles.
+ for (Node node : nodes) {
+ traverse(node);
+ }
+ int result = numberOfCycles;
+ reset();
+ return result;
+ }
+
+ private void reset() {
+ assert stack.isEmpty();
+ assert stackSet.isEmpty();
+ marked.clear();
+ numberOfCycles = 0;
+ }
+
+ private void traverse(Node node) {
+ if (marked.contains(node)) {
+ // Already visited all nodes that can be reached from this node.
+ return;
+ }
+
+ push(node);
+
+ // Sort the callees before calling traverse recursively. This will ensure cycles are broken
+ // the same way across multiple invocations of the R8 compiler.
+ Node[] callees = node.getCalleesWithDeterministicOrder();
+
+ if (options.testing.nondeterministicCycleElimination) {
+ reorderNodes(Arrays.asList(callees));
+ }
+
+ for (Node callee : callees) {
+ if (stackSet.contains(callee)) {
+ // Found a cycle that needs to be eliminated.
+ numberOfCycles++;
+
+ if (edgeRemovalIsSafe(node, callee)) {
+ // Break the cycle by removing the edge node->callee.
+ callee.removeCaller(node);
+
+ if (Log.ENABLED) {
+ Log.info(
+ CallGraph.class,
+ "Removed call edge from method '%s' to '%s'",
+ node.method.toSourceString(),
+ callee.method.toSourceString());
+ }
+ } else {
+ // The cycle has a method that is marked as force inline.
+ LinkedList<Node> cycle = extractCycle(callee);
+
+ if (Log.ENABLED) {
+ Log.info(
+ CallGraph.class, "Extracted cycle to find an edge that can safely be removed");
+ }
+
+ // Break the cycle by finding an edge that can be removed without breaking force
+ // inlining. If that is not possible, this call fails with a compilation error.
+ CallEdge edge = findCallEdgeForRemoval(cycle);
+
+ // The edge will be null if this cycle has already been eliminated as a result of
+ // another cycle elimination.
+ if (edge != null) {
+ assert edgeRemovalIsSafe(edge.caller, edge.callee);
+
+ // Break the cycle by removing the edge caller->callee.
+ edge.remove();
+
+ if (Log.ENABLED) {
+ Log.info(
+ CallGraph.class,
+ "Removed call edge from force inlined method '%s' to '%s' to ensure that "
+ + "force inlining will succeed",
+ node.method.toSourceString(),
+ callee.method.toSourceString());
+ }
+ }
+
+ // Recover the stack.
+ recoverStack(cycle);
+ }
+ } else {
+ traverse(callee);
+ }
+ }
+ pop(node);
+ marked.add(node);
+ }
+
+ private void push(Node node) {
+ stack.push(node);
+ boolean changed = stackSet.add(node);
+ assert changed;
+ }
+
+ private void pop(Node node) {
+ Node popped = stack.pop();
+ assert popped == node;
+ boolean changed = stackSet.remove(node);
+ assert changed;
+ }
+
+ private LinkedList<Node> extractCycle(Node entry) {
+ LinkedList<Node> cycle = new LinkedList<>();
+ do {
+ assert !stack.isEmpty();
+ cycle.add(stack.pop());
+ } while (cycle.getLast() != entry);
+ return cycle;
+ }
+
+ private CallEdge findCallEdgeForRemoval(LinkedList<Node> extractedCycle) {
+ Node callee = extractedCycle.getLast();
+ for (Node caller : extractedCycle) {
+ if (!caller.hasCallee(callee)) {
+ // No need to break any edges since this cycle has already been broken previously.
+ assert !callee.hasCaller(caller);
+ return null;
+ }
+ if (edgeRemovalIsSafe(caller, callee)) {
+ return new CallEdge(caller, callee);
+ }
+ callee = caller;
+ }
+ throw new CompilationError(CYCLIC_FORCE_INLINING_MESSAGE);
+ }
+
+ private static boolean edgeRemovalIsSafe(Node caller, Node callee) {
+ // All call edges where the callee is a method that should be force inlined must be kept,
+ // to guarantee that the IR converter will process the callee before the caller.
+ return !callee.method.getOptimizationInfo().forceInline();
+ }
+
+ private void recoverStack(LinkedList<Node> extractedCycle) {
+ Iterator<Node> descendingIt = extractedCycle.descendingIterator();
+ while (descendingIt.hasNext()) {
+ stack.push(descendingIt.next());
+ }
+ }
+
+ private Collection<Node> reorderNodes(List<Node> nodes) {
+ assert options.testing.nondeterministicCycleElimination;
+ if (!InternalOptions.DETERMINISTIC_DEBUGGING) {
+ Collections.shuffle(nodes);
+ }
+ return nodes;
+ }
+ }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 4bd5bf5..e617152 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -510,7 +510,7 @@
OptimizationFeedbackDelayed feedback = delayedOptimizationFeedback;
{
timing.begin("Build call graph");
- CallGraph callGraph = CallGraph.build(application, appView.withLiveness(), options, timing);
+ CallGraph callGraph = CallGraph.builder(appView.withLiveness()).build(timing);
timing.end();
timing.begin("IR conversion phase 1");
BiConsumer<IRCode, DexEncodedMethod> outlineHandler =
diff --git a/src/test/java/com/android/tools/r8/internal/R8GMSCoreV10TreeShakeJarVerificationTest.java b/src/test/java/com/android/tools/r8/internal/R8GMSCoreV10TreeShakeJarVerificationTest.java
index 62c6bc8..4374d0e 100644
--- a/src/test/java/com/android/tools/r8/internal/R8GMSCoreV10TreeShakeJarVerificationTest.java
+++ b/src/test/java/com/android/tools/r8/internal/R8GMSCoreV10TreeShakeJarVerificationTest.java
@@ -7,9 +7,7 @@
import com.android.tools.r8.CompilationMode;
import com.android.tools.r8.utils.AndroidApp;
-import org.junit.Rule;
import org.junit.Test;
-import org.junit.rules.ExpectedException;
public class R8GMSCoreV10TreeShakeJarVerificationTest
extends R8GMSCoreTreeShakeJarVerificationTest {
diff --git a/src/test/java/com/android/tools/r8/ir/callgraph/CycleEliminationTest.java b/src/test/java/com/android/tools/r8/ir/callgraph/CycleEliminationTest.java
index 8156016..accc3e5 100644
--- a/src/test/java/com/android/tools/r8/ir/callgraph/CycleEliminationTest.java
+++ b/src/test/java/com/android/tools/r8/ir/callgraph/CycleEliminationTest.java
@@ -4,17 +4,20 @@
package com.android.tools.r8.ir.callgraph;
+import static org.hamcrest.CoreMatchers.containsString;
+import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
import com.android.tools.r8.TestBase;
import com.android.tools.r8.errors.CompilationError;
import com.android.tools.r8.graph.DexEncodedMethod;
import com.android.tools.r8.graph.DexItemFactory;
import com.android.tools.r8.graph.DexMethod;
-import com.android.tools.r8.ir.conversion.CallGraph.CycleEliminator;
import com.android.tools.r8.ir.conversion.CallGraph.Node;
+import com.android.tools.r8.ir.conversion.CallGraphBuilder.CycleEliminator;
import com.android.tools.r8.utils.InternalOptions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
@@ -23,9 +26,7 @@
import java.util.List;
import java.util.Set;
import java.util.function.BooleanSupplier;
-import org.junit.Rule;
import org.junit.Test;
-import org.junit.rules.ExpectedException;
public class CycleEliminationTest extends TestBase {
@@ -44,8 +45,6 @@
private DexItemFactory dexItemFactory = new DexItemFactory();
- @Rule public final ExpectedException exception = ExpectedException.none();
-
@Test
public void testSimpleCycle() {
Node method = createNode("n1");
@@ -58,8 +57,8 @@
for (Collection<Node> nodes : orderings) {
// Create a cycle between the two nodes.
- method.addCallee(forceInlinedMethod);
- forceInlinedMethod.addCallee(method);
+ forceInlinedMethod.addCaller(method);
+ method.addCaller(forceInlinedMethod);
// Check that the cycle eliminator finds the cycle.
CycleEliminator cycleEliminator = new CycleEliminator(nodes, new InternalOptions());
@@ -80,17 +79,18 @@
Node forceInlinedMethod = createForceInlinedNode("n2");
// Create a cycle between the two nodes.
- method.addCallee(forceInlinedMethod);
- forceInlinedMethod.addCallee(method);
+ forceInlinedMethod.addCaller(method);
+ method.addCaller(forceInlinedMethod);
CycleEliminator cycleEliminator =
new CycleEliminator(ImmutableList.of(method, forceInlinedMethod), new InternalOptions());
- exception.expect(CompilationError.class);
- exception.expectMessage(CycleEliminator.CYCLIC_FORCE_INLINING_MESSAGE);
-
- // Should throw because force inlining will fail.
- cycleEliminator.breakCycles();
+ try {
+ cycleEliminator.breakCycles();
+ fail("Force inlining should fail");
+ } catch (CompilationError e) {
+ assertThat(e.toString(), containsString(CycleEliminator.CYCLIC_FORCE_INLINING_MESSAGE));
+ }
}
@Test
@@ -160,12 +160,12 @@
for (Configuration configuration : configurations) {
// Create a cycle between the three nodes.
- n1.addCallee(n2);
- n2.addCallee(n3);
- n3.addCallee(n1);
+ n2.addCaller(n1);
+ n3.addCaller(n2);
+ n1.addCaller(n3);
// Create a cycle in the graph between node n1 and n2.
- n2.addCallee(n1);
+ n1.addCaller(n2);
for (Node node : configuration.nodes) {
if (configuration.forceInline.contains(node)) {