Add utils to extract roots from call graph.
Bug: 127694949, 140768815
Change-Id: I7c0e0c38678e7609c52d08cf52e5683d6b25ea62
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/CallGraph.java b/src/main/java/com/android/tools/r8/ir/conversion/CallGraph.java
index 307215a..0b26147 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/CallGraph.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/CallGraph.java
@@ -6,7 +6,6 @@
import com.android.tools.r8.graph.AppView;
import com.android.tools.r8.graph.DexEncodedMethod;
-import com.android.tools.r8.ir.conversion.CallGraphBuilder.CycleEliminator;
import com.android.tools.r8.ir.conversion.CallGraphBuilder.CycleEliminator.CycleEliminationResult;
import com.android.tools.r8.ir.conversion.CallSiteInformation.CallGraphBasedCallSiteInformation;
import com.android.tools.r8.shaking.AppInfoWithLiveness;
@@ -76,7 +75,14 @@
caller.callees.remove(this);
}
- public void cleanForRemoval() {
+ public void cleanCalleesForRemoval() {
+ assert callers.isEmpty();
+ for (Node callee : callees) {
+ callee.callers.remove(this);
+ }
+ }
+
+ public void cleanCallersForRemoval() {
assert callees.isEmpty();
for (Node caller : callers) {
caller.callees.remove(this);
@@ -103,6 +109,10 @@
return callers.contains(method);
}
+ public boolean isRoot() {
+ return callers.isEmpty();
+ }
+
public boolean isLeaf() {
return callees.isEmpty();
}
@@ -156,14 +166,6 @@
return new CallGraphBuilder(appView);
}
- /**
- * Extract the next set of leaves (nodes with an call (outgoing) degree of 0) if any.
- *
- * <p>All nodes in the graph are extracted if called repeatedly until null is returned. Please
- * note that there are no cycles in this graph (see {@link CycleEliminator#breakCycles}).
- *
- * <p>
- */
static MethodProcessor createMethodProcessor(
AppView<AppInfoWithLiveness> appView, ExecutorService executorService, Timing timing)
throws ExecutionException {
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java b/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java
index 9a55240..ab68915 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java
@@ -7,6 +7,7 @@
import com.android.tools.r8.graph.AppView;
import com.android.tools.r8.graph.DexEncodedMethod;
import com.android.tools.r8.ir.conversion.CallGraph.Node;
+import com.android.tools.r8.ir.conversion.CallGraphBuilder.CycleEliminator;
import com.android.tools.r8.shaking.AppInfoWithLiveness;
import com.android.tools.r8.utils.Action;
import com.android.tools.r8.utils.IROrdering;
@@ -70,7 +71,13 @@
return waves;
}
- private static void extractLeaves(Set<Node> nodes, Consumer<Node> fn) {
+ /**
+ * Extract the next set of leaves (nodes with an outgoing call degree of 0) if any.
+ *
+ * <p>All nodes in the graph are extracted if called repeatedly until null is returned. Please
+ * note that there are no cycles in this graph (see {@link CycleEliminator#breakCycles}).
+ */
+ static void extractLeaves(Iterable<Node> nodes, Consumer<Node> fn) {
Set<Node> removed = Sets.newIdentityHashSet();
Iterator<Node> nodeIterator = nodes.iterator();
while (nodeIterator.hasNext()) {
@@ -81,7 +88,7 @@
removed.add(node);
}
}
- removed.forEach(Node::cleanForRemoval);
+ removed.forEach(Node::cleanCallersForRemoval);
}
/**
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/PostMethodProcessor.java b/src/main/java/com/android/tools/r8/ir/conversion/PostMethodProcessor.java
new file mode 100644
index 0000000..dab85a9
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/conversion/PostMethodProcessor.java
@@ -0,0 +1,33 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.conversion;
+
+import com.android.tools.r8.ir.conversion.CallGraph.Node;
+import com.google.common.collect.Sets;
+import java.util.Iterator;
+import java.util.Set;
+import java.util.function.Consumer;
+
+public class PostMethodProcessor {
+
+ /**
+ * Extract the next set of roots (nodes with an incoming call degree of 0) if any.
+ *
+ * <p>All nodes in the graph are extracted if called repeatedly until null is returned.
+ */
+ static void extractRoots(Iterable<Node> nodes, Consumer<Node> fn) {
+ Set<Node> removed = Sets.newIdentityHashSet();
+ Iterator<Node> nodeIterator = nodes.iterator();
+ while (nodeIterator.hasNext()) {
+ Node node = nodeIterator.next();
+ if (node.isRoot()) {
+ fn.accept(node);
+ nodeIterator.remove();
+ removed.add(node);
+ }
+ }
+ removed.forEach(Node::cleanCalleesForRemoval);
+ }
+}
diff --git a/src/test/java/com/android/tools/r8/ir/conversion/CallGraphTestBase.java b/src/test/java/com/android/tools/r8/ir/conversion/CallGraphTestBase.java
new file mode 100644
index 0000000..b567014
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/conversion/CallGraphTestBase.java
@@ -0,0 +1,31 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.conversion;
+
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.ParameterAnnotationsList;
+import com.android.tools.r8.ir.conversion.CallGraph.Node;
+
+class CallGraphTestBase extends TestBase {
+ private DexItemFactory dexItemFactory = new DexItemFactory();
+
+ Node createNode(String methodName) {
+ DexMethod signature =
+ dexItemFactory.createMethod(
+ dexItemFactory.objectType,
+ dexItemFactory.createProto(dexItemFactory.voidType),
+ methodName);
+ return new Node(
+ new DexEncodedMethod(signature, null, null, ParameterAnnotationsList.empty(), null));
+ }
+
+ Node createForceInlinedNode(String methodName) {
+ Node node = createNode(methodName);
+ node.method.getMutableOptimizationInfo().markForceInline();
+ return node;
+ }
+}
diff --git a/src/test/java/com/android/tools/r8/ir/callgraph/CycleEliminationTest.java b/src/test/java/com/android/tools/r8/ir/conversion/CycleEliminationTest.java
similarity index 87%
rename from src/test/java/com/android/tools/r8/ir/callgraph/CycleEliminationTest.java
rename to src/test/java/com/android/tools/r8/ir/conversion/CycleEliminationTest.java
index f767bf2..2a7c571 100644
--- a/src/test/java/com/android/tools/r8/ir/callgraph/CycleEliminationTest.java
+++ b/src/test/java/com/android/tools/r8/ir/conversion/CycleEliminationTest.java
@@ -2,7 +2,7 @@
// for details. All rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.
-package com.android.tools.r8.ir.callgraph;
+package com.android.tools.r8.ir.conversion;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.MatcherAssert.assertThat;
@@ -11,12 +11,7 @@
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
-import com.android.tools.r8.TestBase;
import com.android.tools.r8.errors.CompilationError;
-import com.android.tools.r8.graph.DexEncodedMethod;
-import com.android.tools.r8.graph.DexItemFactory;
-import com.android.tools.r8.graph.DexMethod;
-import com.android.tools.r8.graph.ParameterAnnotationsList;
import com.android.tools.r8.ir.conversion.CallGraph.Node;
import com.android.tools.r8.ir.conversion.CallGraphBuilder.CycleEliminator;
import com.android.tools.r8.utils.InternalOptions;
@@ -29,7 +24,7 @@
import java.util.function.BooleanSupplier;
import org.junit.Test;
-public class CycleEliminationTest extends TestBase {
+public class CycleEliminationTest extends CallGraphTestBase {
private static class Configuration {
@@ -44,8 +39,6 @@
}
}
- private DexItemFactory dexItemFactory = new DexItemFactory();
-
@Test
public void testSimpleCycle() {
Node method = createNode("n1");
@@ -197,20 +190,4 @@
}
}
}
-
- private Node createNode(String methodName) {
- DexMethod signature =
- dexItemFactory.createMethod(
- dexItemFactory.objectType,
- dexItemFactory.createProto(dexItemFactory.voidType),
- methodName);
- return new Node(
- new DexEncodedMethod(signature, null, null, ParameterAnnotationsList.empty(), null));
- }
-
- private Node createForceInlinedNode(String methodName) {
- Node node = createNode(methodName);
- node.method.getMutableOptimizationInfo().markForceInline();
- return node;
- }
}
diff --git a/src/test/java/com/android/tools/r8/ir/conversion/NodeExtractionTest.java b/src/test/java/com/android/tools/r8/ir/conversion/NodeExtractionTest.java
new file mode 100644
index 0000000..9a48172
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/conversion/NodeExtractionTest.java
@@ -0,0 +1,224 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.conversion;
+
+import static org.hamcrest.CoreMatchers.hasItem;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import com.android.tools.r8.ir.conversion.CallGraph.Node;
+import com.android.tools.r8.ir.conversion.CallGraphBuilder.CycleEliminator;
+import com.android.tools.r8.utils.InternalOptions;
+import com.google.common.collect.Sets;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
+import org.junit.Test;
+
+public class NodeExtractionTest extends CallGraphTestBase {
+
+ private InternalOptions options = new InternalOptions();
+
+ // Note that building a test graph is intentionally repeated to avoid race conditions and/or
+ // non-deterministic test results due to cycle elimination.
+
+ @Test
+ public void testExtractLeaves_withoutCycle() {
+ Node n1, n2, n3, n4, n5, n6;
+ List<Node> nodes;
+
+ n1 = createNode("n1");
+ n2 = createNode("n2");
+ n3 = createNode("n3");
+ n4 = createNode("n4");
+ n5 = createNode("n5");
+ n6 = createNode("n6");
+
+ n2.addCallerConcurrently(n1);
+ n3.addCallerConcurrently(n2);
+ n4.addCallerConcurrently(n2);
+ n4.addCallerConcurrently(n5);
+ n6.addCallerConcurrently(n5);
+
+ nodes = new ArrayList<>();
+ nodes.add(n1);
+ nodes.add(n2);
+ nodes.add(n3);
+ nodes.add(n4);
+ nodes.add(n5);
+ nodes.add(n6);
+
+ Set<Node> wave = Sets.newIdentityHashSet();
+
+ MethodProcessor.extractLeaves(nodes, wave::add);
+ assertEquals(3, wave.size());
+ assertThat(wave, hasItem(n3));
+ assertThat(wave, hasItem(n4));
+ assertThat(wave, hasItem(n6));
+ wave.clear();
+
+ MethodProcessor.extractLeaves(nodes, wave::add);
+ assertEquals(2, wave.size());
+ assertThat(wave, hasItem(n2));
+ assertThat(wave, hasItem(n5));
+ wave.clear();
+
+ MethodProcessor.extractLeaves(nodes, wave::add);
+ assertEquals(1, wave.size());
+ assertThat(wave, hasItem(n1));
+ assertTrue(nodes.isEmpty());
+ }
+
+ @Test
+ public void testExtractLeaves_withCycle() {
+ Node n1, n2, n3, n4, n5, n6;
+ List<Node> nodes;
+
+ n1 = createNode("n1");
+ n2 = createNode("n2");
+ n3 = createNode("n3");
+ n4 = createNode("n4");
+ n5 = createNode("n5");
+ n6 = createNode("n6");
+
+ n2.addCallerConcurrently(n1);
+ n3.addCallerConcurrently(n2);
+ n4.addCallerConcurrently(n2);
+ n4.addCallerConcurrently(n5);
+ n6.addCallerConcurrently(n5);
+
+ nodes = new ArrayList<>();
+ nodes.add(n1);
+ nodes.add(n2);
+ nodes.add(n3);
+ nodes.add(n4);
+ nodes.add(n5);
+ nodes.add(n6);
+
+ n1.addCallerConcurrently(n3);
+ n3.method.getMutableOptimizationInfo().markForceInline();
+ CycleEliminator cycleEliminator = new CycleEliminator(nodes, options);
+ assertEquals(1, cycleEliminator.breakCycles().numberOfRemovedEdges());
+
+ Set<Node> wave = Sets.newIdentityHashSet();
+
+ MethodProcessor.extractLeaves(nodes, wave::add);
+ assertEquals(3, wave.size());
+ assertThat(wave, hasItem(n3));
+ assertThat(wave, hasItem(n4));
+ assertThat(wave, hasItem(n6));
+ wave.clear();
+
+ MethodProcessor.extractLeaves(nodes, wave::add);
+ assertEquals(2, wave.size());
+ assertThat(wave, hasItem(n2));
+ assertThat(wave, hasItem(n5));
+ wave.clear();
+
+ MethodProcessor.extractLeaves(nodes, wave::add);
+ assertEquals(1, wave.size());
+ assertThat(wave, hasItem(n1));
+ assertTrue(nodes.isEmpty());
+ }
+
+ @Test
+ public void testExtractRoots_withoutCycle() {
+ Node n1, n2, n3, n4, n5, n6;
+ List<Node> nodes;
+
+ n1 = createNode("n1");
+ n2 = createNode("n2");
+ n3 = createNode("n3");
+ n4 = createNode("n4");
+ n5 = createNode("n5");
+ n6 = createNode("n6");
+
+ n2.addCallerConcurrently(n1);
+ n3.addCallerConcurrently(n2);
+ n4.addCallerConcurrently(n2);
+ n4.addCallerConcurrently(n5);
+ n6.addCallerConcurrently(n5);
+
+ nodes = new ArrayList<>();
+ nodes.add(n1);
+ nodes.add(n2);
+ nodes.add(n3);
+ nodes.add(n4);
+ nodes.add(n5);
+ nodes.add(n6);
+
+ Set<Node> wave = Sets.newIdentityHashSet();
+
+ PostMethodProcessor.extractRoots(nodes, wave::add);
+ assertEquals(2, wave.size());
+ assertThat(wave, hasItem(n1));
+ assertThat(wave, hasItem(n5));
+ wave.clear();
+
+ PostMethodProcessor.extractRoots(nodes, wave::add);
+ assertEquals(2, wave.size());
+ assertThat(wave, hasItem(n2));
+ assertThat(wave, hasItem(n6));
+ wave.clear();
+
+ PostMethodProcessor.extractRoots(nodes, wave::add);
+ assertEquals(2, wave.size());
+ assertThat(wave, hasItem(n3));
+ assertThat(wave, hasItem(n4));
+ assertTrue(nodes.isEmpty());
+ }
+
+ @Test
+ public void testExtractRoots_withCycle() {
+ Node n1, n2, n3, n4, n5, n6;
+ List<Node> nodes;
+
+ n1 = createNode("n1");
+ n2 = createNode("n2");
+ n3 = createNode("n3");
+ n4 = createNode("n4");
+ n5 = createNode("n5");
+ n6 = createNode("n6");
+
+ n2.addCallerConcurrently(n1);
+ n3.addCallerConcurrently(n2);
+ n4.addCallerConcurrently(n2);
+ n4.addCallerConcurrently(n5);
+ n6.addCallerConcurrently(n5);
+
+ nodes = new ArrayList<>();
+ nodes.add(n1);
+ nodes.add(n2);
+ nodes.add(n3);
+ nodes.add(n4);
+ nodes.add(n5);
+ nodes.add(n6);
+
+ n1.addCallerConcurrently(n3);
+ n3.method.getMutableOptimizationInfo().markForceInline();
+ CycleEliminator cycleEliminator = new CycleEliminator(nodes, options);
+ assertEquals(1, cycleEliminator.breakCycles().numberOfRemovedEdges());
+
+ Set<Node> wave = Sets.newIdentityHashSet();
+
+ PostMethodProcessor.extractRoots(nodes, wave::add);
+ assertEquals(2, wave.size());
+ assertThat(wave, hasItem(n1));
+ assertThat(wave, hasItem(n5));
+ wave.clear();
+
+ PostMethodProcessor.extractRoots(nodes, wave::add);
+ assertEquals(2, wave.size());
+ assertThat(wave, hasItem(n2));
+ assertThat(wave, hasItem(n6));
+ wave.clear();
+
+ PostMethodProcessor.extractRoots(nodes, wave::add);
+ assertEquals(2, wave.size());
+ assertThat(wave, hasItem(n3));
+ assertThat(wave, hasItem(n4));
+ assertTrue(nodes.isEmpty());
+ }
+}