Build IR processing waves up-front to enable gc-ing of call graph

Bug: 131781627
Change-Id: Ibac5b889b393b760e0a4633ba114ad530f3c32f3
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/CallGraph.java b/src/main/java/com/android/tools/r8/ir/conversion/CallGraph.java
index e03f33c..a37e32a 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/CallGraph.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/CallGraph.java
@@ -6,43 +6,29 @@
 
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexEncodedMethod;
-import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.ir.conversion.CallGraphBuilder.CycleEliminator;
+import com.android.tools.r8.ir.conversion.CallSiteInformation.CallGraphBasedCallSiteInformation;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
-import com.android.tools.r8.utils.Action;
-import com.android.tools.r8.utils.IROrdering;
-import com.android.tools.r8.utils.ThreadUtils;
-import com.android.tools.r8.utils.ThrowingBiConsumer;
-import com.google.common.collect.Sets;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.LinkedHashSet;
-import java.util.List;
 import java.util.Set;
 import java.util.TreeSet;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Future;
-import java.util.function.Predicate;
-import java.util.stream.Collectors;
 
 /**
  * Call graph representation.
- * <p>
- * Each node in the graph contain the methods called and the calling methods. For virtual and
+ *
+ * <p>Each node in the graph contain the methods called and the calling methods. For virtual and
  * interface calls all potential calls from subtypes are recorded.
- * <p>
- * Only methods in the program - not library methods - are represented.
- * <p>
- * The directional edges are represented as sets of nodes in each node (called methods and callees).
- * <p>
- * A call from method <code>a</code> to method <code>b</code> is only present once no matter how
+ *
+ * <p>Only methods in the program - not library methods - are represented.
+ *
+ * <p>The directional edges are represented as sets of nodes in each node (called methods and
+ * callees).
+ *
+ * <p>A call from method <code>a</code> to method <code>b</code> is only present once no matter how
  * many calls of <code>a</code> there are in <code>a</code>.
- * <p>
- * Recursive calls are not present.
+ *
+ * <p>Recursive calls are not present.
  */
-public class CallGraph extends CallSiteInformation {
+public class CallGraph {
 
   public static class Node implements Comparable<Node> {
 
@@ -82,10 +68,21 @@
       caller.callees.remove(this);
     }
 
+    public void cleanForRemoval() {
+      assert callees.isEmpty();
+      for (Node caller : callers) {
+        caller.callees.remove(this);
+      }
+    }
+
     public Node[] getCalleesWithDeterministicOrder() {
       return callees.toArray(Node.EMPTY_ARRAY);
     }
 
+    public int getNumberOfCallSites() {
+      return numberOfCallSites;
+    }
+
     public boolean hasCallee(Node method) {
       return callees.contains(method);
     }
@@ -135,46 +132,18 @@
     }
   }
 
-  private final Set<Node> nodes;
-  private final IROrdering shuffle;
+  final Set<Node> nodes;
 
-  private final Set<DexMethod> singleCallSite = Sets.newIdentityHashSet();
-  private final Set<DexMethod> doubleCallSite = Sets.newIdentityHashSet();
-
-  CallGraph(AppView<AppInfoWithLiveness> appView, Set<Node> nodes) {
+  CallGraph(Set<Node> nodes) {
     this.nodes = nodes;
-    this.shuffle = appView.options().testing.irOrdering;
-
-    for (Node node : nodes) {
-      // For non-pinned methods we know the exact number of call sites.
-      if (!appView.appInfo().isPinned(node.method.method)) {
-        if (node.numberOfCallSites == 1) {
-          singleCallSite.add(node.method.method);
-        } else if (node.numberOfCallSites == 2) {
-          doubleCallSite.add(node.method.method);
-        }
-      }
-    }
   }
 
   public static CallGraphBuilder builder(AppView<AppInfoWithLiveness> appView) {
     return new CallGraphBuilder(appView);
   }
 
-  /**
-   * Check if the <code>method</code> is guaranteed to only have a single call site.
-   * <p>
-   * For pinned methods (methods kept through Proguard keep rules) this will always answer
-   * <code>false</code>.
-   */
-  @Override
-  public boolean hasSingleCallSite(DexMethod method) {
-    return singleCallSite.contains(method);
-  }
-
-  @Override
-  public boolean hasDoubleCallSite(DexMethod method) {
-    return doubleCallSite.contains(method);
+  CallSiteInformation createCallSiteInformation(AppView<AppInfoWithLiveness> appView) {
+    return new CallGraphBasedCallSiteInformation(appView, this);
   }
 
   /**
@@ -185,50 +154,7 @@
    *
    * <p>
    */
-  private Collection<DexEncodedMethod> extractLeaves() {
-    if (isEmpty()) {
-      return Collections.emptySet();
-    }
-    // First identify all leaves before removing them from the graph.
-    List<Node> leaves = nodes.stream().filter(Node::isLeaf).collect(Collectors.toList());
-    for (Node leaf : leaves) {
-      leaf.callers.forEach(caller -> caller.callees.remove(leaf));
-      nodes.remove(leaf);
-    }
-    Set<DexEncodedMethod> methods =
-        leaves.stream().map(x -> x.method).collect(Collectors.toCollection(LinkedHashSet::new));
-    return shuffle.order(methods);
-  }
-
-  public boolean isEmpty() {
-    return nodes.isEmpty();
-  }
-
-  /**
-   * Applies the given method to all leaf nodes of the graph.
-   *
-   * <p>As second parameter, a predicate that can be used to decide whether another method is
-   * processed at the same time is passed. This can be used to avoid races in concurrent processing.
-   */
-  public <E extends Exception> void forEachMethod(
-      ThrowingBiConsumer<DexEncodedMethod, Predicate<DexEncodedMethod>, E> consumer,
-      Action waveStart,
-      Action waveDone,
-      ExecutorService executorService)
-      throws ExecutionException {
-    while (!isEmpty()) {
-      Collection<DexEncodedMethod> methods = extractLeaves();
-      assert methods.size() > 0;
-      List<Future<?>> futures = new ArrayList<>();
-      waveStart.execute();
-      for (DexEncodedMethod method : methods) {
-        futures.add(executorService.submit(() -> {
-          consumer.accept(method, methods::contains);
-          return null; // we want a Callable not a Runnable to be able to throw
-        }));
-      }
-      ThreadUtils.awaitFutures(futures);
-      waveDone.execute();
-    }
+  MethodProcessingOrder createMethodProcessingOrder(AppView<?> appView) {
+    return new MethodProcessingOrder(appView, this);
   }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/CallGraphBuilder.java b/src/main/java/com/android/tools/r8/ir/conversion/CallGraphBuilder.java
index 28a6b84..699cd86 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/CallGraphBuilder.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/CallGraphBuilder.java
@@ -77,7 +77,7 @@
     timing.end();
     assert cycleEliminator.breakCycles() == 0; // This time the cycles should be gone.
 
-    return new CallGraph(appView, nodesWithDeterministicOrder);
+    return new CallGraph(nodesWithDeterministicOrder);
   }
 
   private void processClass(DexProgramClass clazz) {
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/CallSiteInformation.java b/src/main/java/com/android/tools/r8/ir/conversion/CallSiteInformation.java
index 60e4326..36b5a83 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/CallSiteInformation.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/CallSiteInformation.java
@@ -3,7 +3,13 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.ir.conversion;
 
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.ir.conversion.CallGraph.Node;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.google.common.collect.Sets;
+import java.util.Set;
 
 public abstract class CallSiteInformation {
 
@@ -35,4 +41,53 @@
       return false;
     }
   }
+
+  static class CallGraphBasedCallSiteInformation extends CallSiteInformation {
+
+    private final Set<DexMethod> singleCallSite = Sets.newIdentityHashSet();
+    private final Set<DexMethod> doubleCallSite = Sets.newIdentityHashSet();
+
+    CallGraphBasedCallSiteInformation(AppView<AppInfoWithLiveness> appView, CallGraph graph) {
+      for (Node node : graph.nodes) {
+        DexEncodedMethod encodedMethod = node.method;
+        DexMethod method = encodedMethod.method;
+
+        // For non-pinned methods and methods that override library methods we do not know the exact
+        // number of call sites.
+        if (appView.appInfo().isPinned(method)
+            || encodedMethod.isLibraryMethodOverride().isTrue()) {
+          continue;
+        }
+
+        int numberOfCallSites = node.getNumberOfCallSites();
+        if (numberOfCallSites == 1) {
+          singleCallSite.add(method);
+        } else if (numberOfCallSites == 2) {
+          doubleCallSite.add(method);
+        }
+      }
+    }
+
+    /**
+     * Checks if the given method only has a single call site.
+     *
+     * <p>For pinned methods (methods kept through Proguard keep rules) and methods that override a
+     * library method this always returns false.
+     */
+    @Override
+    public boolean hasSingleCallSite(DexMethod method) {
+      return singleCallSite.contains(method);
+    }
+
+    /**
+     * Checks if the given method only has two call sites.
+     *
+     * <p>For pinned methods (methods kept through Proguard keep rules) and methods that override a
+     * library method this always returns false.
+     */
+    @Override
+    public boolean hasDoubleCallSite(DexMethod method) {
+      return doubleCallSite.contains(method);
+    }
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index d104823..c6c1ba8 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -502,6 +502,8 @@
 
   public DexApplication optimize(DexApplication application, ExecutorService executorService)
       throws ExecutionException {
+    AppView<AppInfoWithLiveness> appViewWithLiveness = appView.withLiveness();
+
     if (options.isShrinking()) {
       assert !removeLambdaDeserializationMethods();
     } else {
@@ -527,13 +529,17 @@
       timing.begin("Build call graph");
       CallGraph callGraph =
           CallGraph.builder(appView.withLiveness()).build(executorService, timing);
+      CallSiteInformation callSiteInformation =
+          callGraph.createCallSiteInformation(appViewWithLiveness);
+      MethodProcessingOrder methodProcessingOrder = callGraph.createMethodProcessingOrder(appView);
       timing.end();
       timing.begin("IR conversion phase 1");
       BiConsumer<IRCode, DexEncodedMethod> outlineHandler =
           outliner == null ? Outliner::noProcessing : outliner.identifyCandidateMethods();
-      callGraph.forEachMethod(
+      methodProcessingOrder.forEachMethod(
           (method, isProcessedConcurrently) ->
-              processMethod(method, feedback, isProcessedConcurrently, callGraph, outlineHandler),
+              processMethod(
+                  method, feedback, isProcessedConcurrently, callSiteInformation, outlineHandler),
           this::waveStart,
           this::waveDone,
           executorService);
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessingOrder.java b/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessingOrder.java
new file mode 100644
index 0000000..6d84aba
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessingOrder.java
@@ -0,0 +1,92 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.conversion;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.ir.conversion.CallGraph.Node;
+import com.android.tools.r8.utils.Action;
+import com.android.tools.r8.utils.IROrdering;
+import com.android.tools.r8.utils.ThreadUtils;
+import com.android.tools.r8.utils.ThrowingBiConsumer;
+import com.google.common.collect.Sets;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Deque;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.function.Predicate;
+
+public class MethodProcessingOrder {
+
+  private final Deque<Collection<DexEncodedMethod>> waves;
+
+  MethodProcessingOrder(AppView<?> appView, CallGraph callGraph) {
+    this.waves = createWaves(appView, callGraph);
+  }
+
+  public static Deque<Collection<DexEncodedMethod>> createWaves(
+      AppView<?> appView, CallGraph callGraph) {
+    IROrdering shuffle = appView.options().testing.irOrdering;
+    Deque<Collection<DexEncodedMethod>> waves = new ArrayDeque<>();
+
+    Set<Node> nodes = callGraph.nodes;
+    while (!nodes.isEmpty()) {
+      waves.addLast(shuffle.order(extractLeaves(nodes)));
+    }
+    return waves;
+  }
+
+  private static Set<DexEncodedMethod> extractLeaves(Set<Node> nodes) {
+    Set<DexEncodedMethod> leaves = Sets.newIdentityHashSet();
+    Set<Node> removed = Sets.newIdentityHashSet();
+    Iterator<Node> nodeIterator = nodes.iterator();
+    while (nodeIterator.hasNext()) {
+      Node node = nodeIterator.next();
+      if (node.isLeaf()) {
+        leaves.add(node.method);
+        nodeIterator.remove();
+        removed.add(node);
+      }
+    }
+    removed.forEach(Node::cleanForRemoval);
+    return leaves;
+  }
+
+  /**
+   * Applies the given method to all leaf nodes of the graph.
+   *
+   * <p>As second parameter, a predicate that can be used to decide whether another method is
+   * processed at the same time is passed. This can be used to avoid races in concurrent processing.
+   */
+  public <E extends Exception> void forEachMethod(
+      ThrowingBiConsumer<DexEncodedMethod, Predicate<DexEncodedMethod>, E> consumer,
+      Action waveStart,
+      Action waveDone,
+      ExecutorService executorService)
+      throws ExecutionException {
+    while (!waves.isEmpty()) {
+      Collection<DexEncodedMethod> wave = waves.removeFirst();
+      assert wave.size() > 0;
+      List<Future<?>> futures = new ArrayList<>();
+      waveStart.execute();
+      for (DexEncodedMethod method : wave) {
+        futures.add(
+            executorService.submit(
+                () -> {
+                  consumer.accept(method, wave::contains);
+                  return null; // we want a Callable not a Runnable to be able to throw
+                }));
+      }
+      ThreadUtils.awaitFutures(futures);
+      waveDone.execute();
+    }
+  }
+}