Parallelize flow constraint solver in argument propagator

Change-Id: I49a2e3d9262f027e229ecb3e59db301911c586ad
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagator.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagator.java
index 36719be..cabc28f 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagator.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagator.java
@@ -4,8 +4,6 @@
 
 package com.android.tools.r8.optimize.argumentpropagation;
 
-import static com.android.tools.r8.optimize.argumentpropagation.utils.StronglyConnectedProgramClasses.computeStronglyConnectedProgramClasses;
-
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
@@ -18,6 +16,7 @@
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodStateCollectionByReference;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.VirtualRootMethodsAnalysis;
 import com.android.tools.r8.optimize.argumentpropagation.reprocessingcriteria.ArgumentPropagatorReprocessingCriteriaCollection;
+import com.android.tools.r8.optimize.argumentpropagation.utils.ProgramClassesBidirectedGraph;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.Timing;
@@ -74,7 +73,8 @@
     ImmediateProgramSubtypingInfo immediateSubtypingInfo =
         ImmediateProgramSubtypingInfo.create(appView);
     List<Set<DexProgramClass>> stronglyConnectedProgramClasses =
-        computeStronglyConnectedProgramClasses(appView, immediateSubtypingInfo);
+        new ProgramClassesBidirectedGraph(appView, immediateSubtypingInfo)
+            .computeStronglyConnectedComponents();
     ThreadUtils.processItems(
         stronglyConnectedProgramClasses,
         classes -> {
@@ -132,7 +132,8 @@
     ImmediateProgramSubtypingInfo immediateSubtypingInfo =
         ImmediateProgramSubtypingInfo.create(appView);
     List<Set<DexProgramClass>> stronglyConnectedProgramComponents =
-        computeStronglyConnectedProgramClasses(appView, immediateSubtypingInfo);
+        new ProgramClassesBidirectedGraph(appView, immediateSubtypingInfo)
+            .computeStronglyConnectedComponents();
     timing.end();
 
     // Set the optimization info on each method.
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InParameterFlowPropagator.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InParameterFlowPropagator.java
index b551b68..1f8637a 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InParameterFlowPropagator.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InParameterFlowPropagator.java
@@ -20,6 +20,7 @@
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodStateCollectionByReference;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.NonEmptyParameterState;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.ParameterState;
+import com.android.tools.r8.optimize.argumentpropagation.utils.BidirectedGraph;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.Action;
 import com.android.tools.r8.utils.ThreadUtils;
@@ -52,16 +53,22 @@
     // must be included in the argument information for p'.
     FlowGraph flowGraph = new FlowGraph(appView.appInfo().classes());
 
+    List<Set<ParameterNode>> stronglyConnectedComponents =
+        flowGraph.computeStronglyConnectedComponents();
+    ThreadUtils.processItems(stronglyConnectedComponents, this::process, executorService);
+
+    // The algorithm only changes the parameter states of each monomorphic method state. In case any
+    // of these method states have effectively become unknown, we replace them by the canonicalized
+    // unknown method state.
+    postProcessMethodStates(executorService);
+  }
+
+  private void process(Set<ParameterNode> stronglyConnectedComponent) {
     // Build a worklist containing all the parameter nodes.
-    Deque<ParameterNode> worklist = new ArrayDeque<>();
-    flowGraph.forEachNode(worklist::add);
+    Deque<ParameterNode> worklist = new ArrayDeque<>(stronglyConnectedComponent);
 
     // Repeatedly propagate argument information through edges in the flow graph until there are no
     // more changes.
-    // TODO(b/190154391): Consider parallelizing the flow propagation. There are a few scenarios
-    //  that need to be covered, such as (i) two threads could race to update the same parameter
-    //  state, (ii) a thread may try to propagate a parameter state to its successors while
-    //  another thread is trying to update the state of the parameter itself.
     // TODO(b/190154391): Consider a path p1 -> p2 -> p3 in the graph. If we process p2 first, then
     //  p3, and then p1, then the processing of p1 could cause p2 to change, which means that we
     //  need to reprocess p2 and then p3. If we always process leaves in the graph first, we would
@@ -83,11 +90,6 @@
             }
           });
     }
-
-    // The algorithm only changes the parameter states of each monomorphic method state. In case any
-    // of these method states have effectively become unknown, we replace them by the canonicalized
-    // unknown method state.
-    postProcessMethodStates(executorService);
   }
 
   private void propagate(
@@ -133,7 +135,7 @@
     }
   }
 
-  private class FlowGraph {
+  public class FlowGraph extends BidirectedGraph<ParameterNode> {
 
     private final Map<DexMethod, Int2ReferenceMap<ParameterNode>> nodes = new IdentityHashMap<>();
 
@@ -141,7 +143,14 @@
       classes.forEach(this::add);
     }
 
-    void forEachNode(Consumer<? super ParameterNode> consumer) {
+    @Override
+    public void forEachNeighbor(ParameterNode node, Consumer<? super ParameterNode> consumer) {
+      node.getPredecessors().forEach(consumer);
+      node.getSuccessors().forEach(consumer);
+    }
+
+    @Override
+    public void forEachNode(Consumer<? super ParameterNode> consumer) {
       nodes.values().forEach(nodesForMethod -> nodesForMethod.values().forEach(consumer));
     }
 
@@ -278,6 +287,10 @@
       predecessors.clear();
     }
 
+    Set<ParameterNode> getPredecessors() {
+      return predecessors;
+    }
+
     ParameterState getState() {
       return methodState.getParameterState(parameterIndex);
     }
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/utils/BidirectedGraph.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/utils/BidirectedGraph.java
new file mode 100644
index 0000000..8e76917
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/utils/BidirectedGraph.java
@@ -0,0 +1,47 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.optimize.argumentpropagation.utils;
+
+import com.android.tools.r8.utils.WorkList;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.function.Consumer;
+
+public abstract class BidirectedGraph<T> {
+
+  public abstract void forEachNeighbor(T node, Consumer<? super T> consumer);
+
+  public abstract void forEachNode(Consumer<? super T> consumer);
+
+  /**
+   * Computes the strongly connected components in the current bidirectional graph (i.e., each
+   * strongly connected component can be found using a breadth first search).
+   */
+  public List<Set<T>> computeStronglyConnectedComponents() {
+    Set<T> seen = new HashSet<>();
+    List<Set<T>> stronglyConnectedComponents = new ArrayList<>();
+    forEachNode(
+        node -> {
+          if (seen.contains(node)) {
+            return;
+          }
+          Set<T> stronglyConnectedComponent = internalComputeStronglyConnectedProgramClasses(node);
+          stronglyConnectedComponents.add(stronglyConnectedComponent);
+          seen.addAll(stronglyConnectedComponent);
+        });
+    return stronglyConnectedComponents;
+  }
+
+  private Set<T> internalComputeStronglyConnectedProgramClasses(T node) {
+    WorkList<T> worklist = WorkList.newEqualityWorkList(node);
+    while (worklist.hasNext()) {
+      T current = worklist.next();
+      forEachNeighbor(current, worklist::addIfNotSeen);
+    }
+    return worklist.getSeenSet();
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/utils/ProgramClassesBidirectedGraph.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/utils/ProgramClassesBidirectedGraph.java
new file mode 100644
index 0000000..474c445
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/utils/ProgramClassesBidirectedGraph.java
@@ -0,0 +1,34 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.optimize.argumentpropagation.utils;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import java.util.function.Consumer;
+
+public class ProgramClassesBidirectedGraph extends BidirectedGraph<DexProgramClass> {
+
+  private final AppView<AppInfoWithLiveness> appView;
+  private final ImmediateProgramSubtypingInfo immediateSubtypingInfo;
+
+  public ProgramClassesBidirectedGraph(
+      AppView<AppInfoWithLiveness> appView, ImmediateProgramSubtypingInfo immediateSubtypingInfo) {
+    this.appView = appView;
+    this.immediateSubtypingInfo = immediateSubtypingInfo;
+  }
+
+  @Override
+  public void forEachNeighbor(DexProgramClass node, Consumer<? super DexProgramClass> consumer) {
+    immediateSubtypingInfo.forEachImmediateProgramSuperClass(node, consumer);
+    immediateSubtypingInfo.getSubclasses(node).forEach(consumer);
+  }
+
+  @Override
+  public void forEachNode(Consumer<? super DexProgramClass> consumer) {
+    appView.appInfo().classes().forEach(consumer);
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/utils/StronglyConnectedProgramClasses.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/utils/StronglyConnectedProgramClasses.java
deleted file mode 100644
index 25d28cb..0000000
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/utils/StronglyConnectedProgramClasses.java
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
-// for details. All rights reserved. Use of this source code is governed by a
-// BSD-style license that can be found in the LICENSE file.
-
-package com.android.tools.r8.optimize.argumentpropagation.utils;
-
-import com.android.tools.r8.graph.AppView;
-import com.android.tools.r8.graph.DexProgramClass;
-import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
-import com.android.tools.r8.shaking.AppInfoWithLiveness;
-import com.android.tools.r8.utils.WorkList;
-import com.google.common.collect.Sets;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Set;
-
-public class StronglyConnectedProgramClasses {
-
-  /**
-   * Computes the strongly connected components in the program class hierarchy (where extends and
-   * implements edges are treated as bidirectional).
-   */
-  public static List<Set<DexProgramClass>> computeStronglyConnectedProgramClasses(
-      AppView<AppInfoWithLiveness> appView, ImmediateProgramSubtypingInfo immediateSubtypingInfo) {
-    Set<DexProgramClass> seen = Sets.newIdentityHashSet();
-    List<Set<DexProgramClass>> stronglyConnectedComponents = new ArrayList<>();
-    for (DexProgramClass clazz : appView.appInfo().classes()) {
-      if (seen.contains(clazz)) {
-        continue;
-      }
-      Set<DexProgramClass> stronglyConnectedComponent =
-          internalComputeStronglyConnectedProgramClasses(clazz, immediateSubtypingInfo);
-      stronglyConnectedComponents.add(stronglyConnectedComponent);
-      seen.addAll(stronglyConnectedComponent);
-    }
-    return stronglyConnectedComponents;
-  }
-
-  private static Set<DexProgramClass> internalComputeStronglyConnectedProgramClasses(
-      DexProgramClass clazz, ImmediateProgramSubtypingInfo immediateSubtypingInfo) {
-    WorkList<DexProgramClass> worklist = WorkList.newIdentityWorkList(clazz);
-    while (worklist.hasNext()) {
-      DexProgramClass current = worklist.next();
-      immediateSubtypingInfo.forEachImmediateProgramSuperClass(current, worklist::addIfNotSeen);
-      worklist.addIfNotSeen(immediateSubtypingInfo.getSubclasses(current));
-    }
-    return worklist.getSeenSet();
-  }
-}
diff --git a/src/main/java/com/android/tools/r8/utils/WorkList.java b/src/main/java/com/android/tools/r8/utils/WorkList.java
index 67e4148..3177689 100644
--- a/src/main/java/com/android/tools/r8/utils/WorkList.java
+++ b/src/main/java/com/android/tools/r8/utils/WorkList.java
@@ -17,11 +17,17 @@
   private final Set<T> seen;
 
   public static <T> WorkList<T> newEqualityWorkList() {
-    return new WorkList<T>(EqualityTest.HASH);
+    return new WorkList<T>(EqualityTest.EQUALS);
+  }
+
+  public static <T> WorkList<T> newEqualityWorkList(T item) {
+    WorkList<T> workList = new WorkList<>(EqualityTest.EQUALS);
+    workList.addIfNotSeen(item);
+    return workList;
   }
 
   public static <T> WorkList<T> newEqualityWorkList(Iterable<T> items) {
-    WorkList<T> workList = new WorkList<>(EqualityTest.HASH);
+    WorkList<T> workList = new WorkList<>(EqualityTest.EQUALS);
     workList.addIfNotSeen(items);
     return workList;
   }
@@ -53,7 +59,7 @@
   }
 
   private WorkList(EqualityTest equalityTest) {
-    this(equalityTest == EqualityTest.HASH ? new HashSet<>() : Sets.newIdentityHashSet());
+    this(equalityTest == EqualityTest.EQUALS ? new HashSet<>() : Sets.newIdentityHashSet());
   }
 
   private WorkList(Set<T> seen) {
@@ -120,7 +126,7 @@
   }
 
   public enum EqualityTest {
-    HASH,
+    EQUALS,
     IDENTITY
   }
 }