Add utils to extract roots from call graph.

Bug: 127694949, 140768815
Change-Id: I7c0e0c38678e7609c52d08cf52e5683d6b25ea62
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/CallGraph.java b/src/main/java/com/android/tools/r8/ir/conversion/CallGraph.java
index 307215a..0b26147 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/CallGraph.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/CallGraph.java
@@ -6,7 +6,6 @@
 
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexEncodedMethod;
-import com.android.tools.r8.ir.conversion.CallGraphBuilder.CycleEliminator;
 import com.android.tools.r8.ir.conversion.CallGraphBuilder.CycleEliminator.CycleEliminationResult;
 import com.android.tools.r8.ir.conversion.CallSiteInformation.CallGraphBasedCallSiteInformation;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
@@ -76,7 +75,14 @@
       caller.callees.remove(this);
     }
 
-    public void cleanForRemoval() {
+    public void cleanCalleesForRemoval() {
+      assert callers.isEmpty();
+      for (Node callee : callees) {
+        callee.callers.remove(this);
+      }
+    }
+
+    public void cleanCallersForRemoval() {
       assert callees.isEmpty();
       for (Node caller : callers) {
         caller.callees.remove(this);
@@ -103,6 +109,10 @@
       return callers.contains(method);
     }
 
+    public boolean isRoot() {
+      return callers.isEmpty();
+    }
+
     public boolean isLeaf() {
       return callees.isEmpty();
     }
@@ -156,14 +166,6 @@
     return new CallGraphBuilder(appView);
   }
 
-  /**
-   * Extract the next set of leaves (nodes with an call (outgoing) degree of 0) if any.
-   *
-   * <p>All nodes in the graph are extracted if called repeatedly until null is returned. Please
-   * note that there are no cycles in this graph (see {@link CycleEliminator#breakCycles}).
-   *
-   * <p>
-   */
   static MethodProcessor createMethodProcessor(
       AppView<AppInfoWithLiveness> appView, ExecutorService executorService, Timing timing)
       throws ExecutionException {
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java b/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java
index 9a55240..ab68915 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java
@@ -7,6 +7,7 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.ir.conversion.CallGraph.Node;
+import com.android.tools.r8.ir.conversion.CallGraphBuilder.CycleEliminator;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.Action;
 import com.android.tools.r8.utils.IROrdering;
@@ -70,7 +71,13 @@
     return waves;
   }
 
-  private static void extractLeaves(Set<Node> nodes, Consumer<Node> fn) {
+  /**
+   * Extract the next set of leaves (nodes with an outgoing call degree of 0) if any.
+   *
+   * <p>All nodes in the graph are extracted if called repeatedly until null is returned. Please
+   * note that there are no cycles in this graph (see {@link CycleEliminator#breakCycles}).
+   */
+  static void extractLeaves(Iterable<Node> nodes, Consumer<Node> fn) {
     Set<Node> removed = Sets.newIdentityHashSet();
     Iterator<Node> nodeIterator = nodes.iterator();
     while (nodeIterator.hasNext()) {
@@ -81,7 +88,7 @@
         removed.add(node);
       }
     }
-    removed.forEach(Node::cleanForRemoval);
+    removed.forEach(Node::cleanCallersForRemoval);
   }
 
   /**
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/PostMethodProcessor.java b/src/main/java/com/android/tools/r8/ir/conversion/PostMethodProcessor.java
new file mode 100644
index 0000000..dab85a9
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/conversion/PostMethodProcessor.java
@@ -0,0 +1,33 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.conversion;
+
+import com.android.tools.r8.ir.conversion.CallGraph.Node;
+import com.google.common.collect.Sets;
+import java.util.Iterator;
+import java.util.Set;
+import java.util.function.Consumer;
+
+public class PostMethodProcessor {
+
+  /**
+   * Extract the next set of roots (nodes with an incoming call degree of 0) if any.
+   *
+   * <p>All nodes in the graph are extracted if called repeatedly until null is returned.
+   */
+  static void extractRoots(Iterable<Node> nodes, Consumer<Node> fn) {
+    Set<Node> removed = Sets.newIdentityHashSet();
+    Iterator<Node> nodeIterator = nodes.iterator();
+    while (nodeIterator.hasNext()) {
+      Node node = nodeIterator.next();
+      if (node.isRoot()) {
+        fn.accept(node);
+        nodeIterator.remove();
+        removed.add(node);
+      }
+    }
+    removed.forEach(Node::cleanCalleesForRemoval);
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/ir/conversion/CallGraphTestBase.java b/src/test/java/com/android/tools/r8/ir/conversion/CallGraphTestBase.java
new file mode 100644
index 0000000..b567014
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/conversion/CallGraphTestBase.java
@@ -0,0 +1,31 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.conversion;
+
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.ParameterAnnotationsList;
+import com.android.tools.r8.ir.conversion.CallGraph.Node;
+
+class CallGraphTestBase extends TestBase {
+  private DexItemFactory dexItemFactory = new DexItemFactory();
+
+  Node createNode(String methodName) {
+    DexMethod signature =
+        dexItemFactory.createMethod(
+            dexItemFactory.objectType,
+            dexItemFactory.createProto(dexItemFactory.voidType),
+            methodName);
+    return new Node(
+        new DexEncodedMethod(signature, null, null, ParameterAnnotationsList.empty(), null));
+  }
+
+  Node createForceInlinedNode(String methodName) {
+    Node node = createNode(methodName);
+    node.method.getMutableOptimizationInfo().markForceInline();
+    return node;
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/ir/callgraph/CycleEliminationTest.java b/src/test/java/com/android/tools/r8/ir/conversion/CycleEliminationTest.java
similarity index 87%
rename from src/test/java/com/android/tools/r8/ir/callgraph/CycleEliminationTest.java
rename to src/test/java/com/android/tools/r8/ir/conversion/CycleEliminationTest.java
index f767bf2..2a7c571 100644
--- a/src/test/java/com/android/tools/r8/ir/callgraph/CycleEliminationTest.java
+++ b/src/test/java/com/android/tools/r8/ir/conversion/CycleEliminationTest.java
@@ -2,7 +2,7 @@
 // for details. All rights reserved. Use of this source code is governed by a
 // BSD-style license that can be found in the LICENSE file.
 
-package com.android.tools.r8.ir.callgraph;
+package com.android.tools.r8.ir.conversion;
 
 import static org.hamcrest.CoreMatchers.containsString;
 import static org.hamcrest.MatcherAssert.assertThat;
@@ -11,12 +11,7 @@
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
-import com.android.tools.r8.TestBase;
 import com.android.tools.r8.errors.CompilationError;
-import com.android.tools.r8.graph.DexEncodedMethod;
-import com.android.tools.r8.graph.DexItemFactory;
-import com.android.tools.r8.graph.DexMethod;
-import com.android.tools.r8.graph.ParameterAnnotationsList;
 import com.android.tools.r8.ir.conversion.CallGraph.Node;
 import com.android.tools.r8.ir.conversion.CallGraphBuilder.CycleEliminator;
 import com.android.tools.r8.utils.InternalOptions;
@@ -29,7 +24,7 @@
 import java.util.function.BooleanSupplier;
 import org.junit.Test;
 
-public class CycleEliminationTest extends TestBase {
+public class CycleEliminationTest extends CallGraphTestBase {
 
   private static class Configuration {
 
@@ -44,8 +39,6 @@
     }
   }
 
-  private DexItemFactory dexItemFactory = new DexItemFactory();
-
   @Test
   public void testSimpleCycle() {
     Node method = createNode("n1");
@@ -197,20 +190,4 @@
       }
     }
   }
-
-  private Node createNode(String methodName) {
-    DexMethod signature =
-        dexItemFactory.createMethod(
-            dexItemFactory.objectType,
-            dexItemFactory.createProto(dexItemFactory.voidType),
-            methodName);
-    return new Node(
-        new DexEncodedMethod(signature, null, null, ParameterAnnotationsList.empty(), null));
-  }
-
-  private Node createForceInlinedNode(String methodName) {
-    Node node = createNode(methodName);
-    node.method.getMutableOptimizationInfo().markForceInline();
-    return node;
-  }
 }
diff --git a/src/test/java/com/android/tools/r8/ir/conversion/NodeExtractionTest.java b/src/test/java/com/android/tools/r8/ir/conversion/NodeExtractionTest.java
new file mode 100644
index 0000000..9a48172
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/conversion/NodeExtractionTest.java
@@ -0,0 +1,224 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.conversion;
+
+import static org.hamcrest.CoreMatchers.hasItem;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import com.android.tools.r8.ir.conversion.CallGraph.Node;
+import com.android.tools.r8.ir.conversion.CallGraphBuilder.CycleEliminator;
+import com.android.tools.r8.utils.InternalOptions;
+import com.google.common.collect.Sets;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
+import org.junit.Test;
+
+public class NodeExtractionTest extends CallGraphTestBase {
+
+  private InternalOptions options = new InternalOptions();
+
+  // Note that building a test graph is intentionally repeated to avoid race conditions and/or
+  // non-deterministic test results due to cycle elimination.
+
+  @Test
+  public void testExtractLeaves_withoutCycle() {
+    Node n1, n2, n3, n4, n5, n6;
+    List<Node> nodes;
+
+    n1 = createNode("n1");
+    n2 = createNode("n2");
+    n3 = createNode("n3");
+    n4 = createNode("n4");
+    n5 = createNode("n5");
+    n6 = createNode("n6");
+
+    n2.addCallerConcurrently(n1);
+    n3.addCallerConcurrently(n2);
+    n4.addCallerConcurrently(n2);
+    n4.addCallerConcurrently(n5);
+    n6.addCallerConcurrently(n5);
+
+    nodes = new ArrayList<>();
+    nodes.add(n1);
+    nodes.add(n2);
+    nodes.add(n3);
+    nodes.add(n4);
+    nodes.add(n5);
+    nodes.add(n6);
+
+    Set<Node> wave = Sets.newIdentityHashSet();
+
+    MethodProcessor.extractLeaves(nodes, wave::add);
+    assertEquals(3, wave.size());
+    assertThat(wave, hasItem(n3));
+    assertThat(wave, hasItem(n4));
+    assertThat(wave, hasItem(n6));
+    wave.clear();
+
+    MethodProcessor.extractLeaves(nodes, wave::add);
+    assertEquals(2, wave.size());
+    assertThat(wave, hasItem(n2));
+    assertThat(wave, hasItem(n5));
+    wave.clear();
+
+    MethodProcessor.extractLeaves(nodes, wave::add);
+    assertEquals(1, wave.size());
+    assertThat(wave, hasItem(n1));
+    assertTrue(nodes.isEmpty());
+  }
+
+  @Test
+  public void testExtractLeaves_withCycle() {
+    Node n1, n2, n3, n4, n5, n6;
+    List<Node> nodes;
+
+    n1 = createNode("n1");
+    n2 = createNode("n2");
+    n3 = createNode("n3");
+    n4 = createNode("n4");
+    n5 = createNode("n5");
+    n6 = createNode("n6");
+
+    n2.addCallerConcurrently(n1);
+    n3.addCallerConcurrently(n2);
+    n4.addCallerConcurrently(n2);
+    n4.addCallerConcurrently(n5);
+    n6.addCallerConcurrently(n5);
+
+    nodes = new ArrayList<>();
+    nodes.add(n1);
+    nodes.add(n2);
+    nodes.add(n3);
+    nodes.add(n4);
+    nodes.add(n5);
+    nodes.add(n6);
+
+    n1.addCallerConcurrently(n3);
+    n3.method.getMutableOptimizationInfo().markForceInline();
+    CycleEliminator cycleEliminator = new CycleEliminator(nodes, options);
+    assertEquals(1, cycleEliminator.breakCycles().numberOfRemovedEdges());
+
+    Set<Node> wave = Sets.newIdentityHashSet();
+
+    MethodProcessor.extractLeaves(nodes, wave::add);
+    assertEquals(3, wave.size());
+    assertThat(wave, hasItem(n3));
+    assertThat(wave, hasItem(n4));
+    assertThat(wave, hasItem(n6));
+    wave.clear();
+
+    MethodProcessor.extractLeaves(nodes, wave::add);
+    assertEquals(2, wave.size());
+    assertThat(wave, hasItem(n2));
+    assertThat(wave, hasItem(n5));
+    wave.clear();
+
+    MethodProcessor.extractLeaves(nodes, wave::add);
+    assertEquals(1, wave.size());
+    assertThat(wave, hasItem(n1));
+    assertTrue(nodes.isEmpty());
+  }
+
+  @Test
+  public void testExtractRoots_withoutCycle() {
+    Node n1, n2, n3, n4, n5, n6;
+    List<Node> nodes;
+
+    n1 = createNode("n1");
+    n2 = createNode("n2");
+    n3 = createNode("n3");
+    n4 = createNode("n4");
+    n5 = createNode("n5");
+    n6 = createNode("n6");
+
+    n2.addCallerConcurrently(n1);
+    n3.addCallerConcurrently(n2);
+    n4.addCallerConcurrently(n2);
+    n4.addCallerConcurrently(n5);
+    n6.addCallerConcurrently(n5);
+
+    nodes = new ArrayList<>();
+    nodes.add(n1);
+    nodes.add(n2);
+    nodes.add(n3);
+    nodes.add(n4);
+    nodes.add(n5);
+    nodes.add(n6);
+
+    Set<Node> wave = Sets.newIdentityHashSet();
+
+    PostMethodProcessor.extractRoots(nodes, wave::add);
+    assertEquals(2, wave.size());
+    assertThat(wave, hasItem(n1));
+    assertThat(wave, hasItem(n5));
+    wave.clear();
+
+    PostMethodProcessor.extractRoots(nodes, wave::add);
+    assertEquals(2, wave.size());
+    assertThat(wave, hasItem(n2));
+    assertThat(wave, hasItem(n6));
+    wave.clear();
+
+    PostMethodProcessor.extractRoots(nodes, wave::add);
+    assertEquals(2, wave.size());
+    assertThat(wave, hasItem(n3));
+    assertThat(wave, hasItem(n4));
+    assertTrue(nodes.isEmpty());
+  }
+
+  @Test
+  public void testExtractRoots_withCycle() {
+    Node n1, n2, n3, n4, n5, n6;
+    List<Node> nodes;
+
+    n1 = createNode("n1");
+    n2 = createNode("n2");
+    n3 = createNode("n3");
+    n4 = createNode("n4");
+    n5 = createNode("n5");
+    n6 = createNode("n6");
+
+    n2.addCallerConcurrently(n1);
+    n3.addCallerConcurrently(n2);
+    n4.addCallerConcurrently(n2);
+    n4.addCallerConcurrently(n5);
+    n6.addCallerConcurrently(n5);
+
+    nodes = new ArrayList<>();
+    nodes.add(n1);
+    nodes.add(n2);
+    nodes.add(n3);
+    nodes.add(n4);
+    nodes.add(n5);
+    nodes.add(n6);
+
+    n1.addCallerConcurrently(n3);
+    n3.method.getMutableOptimizationInfo().markForceInline();
+    CycleEliminator cycleEliminator = new CycleEliminator(nodes, options);
+    assertEquals(1, cycleEliminator.breakCycles().numberOfRemovedEdges());
+
+    Set<Node> wave = Sets.newIdentityHashSet();
+
+    PostMethodProcessor.extractRoots(nodes, wave::add);
+    assertEquals(2, wave.size());
+    assertThat(wave, hasItem(n1));
+    assertThat(wave, hasItem(n5));
+    wave.clear();
+
+    PostMethodProcessor.extractRoots(nodes, wave::add);
+    assertEquals(2, wave.size());
+    assertThat(wave, hasItem(n2));
+    assertThat(wave, hasItem(n6));
+    wave.clear();
+
+    PostMethodProcessor.extractRoots(nodes, wave::add);
+    assertEquals(2, wave.size());
+    assertThat(wave, hasItem(n3));
+    assertThat(wave, hasItem(n4));
+    assertTrue(nodes.isEmpty());
+  }
+}