Decouple method pools from ClassAndMemberPublicizer

Change-Id: I99fba323217e226bee96d52b8683341cbd78c0d4
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/MethodPoolCollection.java b/src/main/java/com/android/tools/r8/ir/optimize/MethodPoolCollection.java
new file mode 100644
index 0000000..9af3967
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/optimize/MethodPoolCollection.java
@@ -0,0 +1,147 @@
+// Copyright (c) 2018, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize;
+
+import com.android.tools.r8.graph.DexApplication;
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.utils.MethodSignatureEquivalence;
+import com.android.tools.r8.utils.ThreadUtils;
+import com.android.tools.r8.utils.Timing;
+import com.google.common.base.Equivalence;
+import com.google.common.base.Equivalence.Wrapper;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+
+// Per-class collection of method signatures.
+//
+// Example use case: to determine if a certain method can be publicized or not.
+public class MethodPoolCollection {
+
+  private static final Equivalence<DexMethod> equivalence = MethodSignatureEquivalence.get();
+
+  private final DexApplication application;
+  private final Map<DexClass, MethodPool> methodPools = new ConcurrentHashMap<>();
+
+  public MethodPoolCollection(DexApplication application) {
+    this.application = application;
+  }
+
+  public void buildAll(ExecutorService executorService, Timing timing) throws ExecutionException {
+    timing.begin("Building method pool collection");
+    try {
+      List<Future<?>> futures = new ArrayList<>();
+      submitAll(application.classes(), futures, executorService);
+      ThreadUtils.awaitFutures(futures);
+    } finally {
+      timing.end();
+    }
+  }
+
+  public MethodPool get(DexClass clazz) {
+    assert methodPools.containsKey(clazz);
+    return methodPools.get(clazz);
+  }
+
+  public boolean markIfNotSeen(DexClass clazz, DexMethod method) {
+    MethodPool methodPool = get(clazz);
+    Wrapper<DexMethod> key = equivalence.wrap(method);
+    if (methodPool.hasSeen(key)) {
+      return true;
+    }
+    methodPool.seen(key);
+    return false;
+  }
+
+  private void submitAll(
+      Iterable<DexProgramClass> classes, List<Future<?>> futures, ExecutorService executorService) {
+    for (DexProgramClass clazz : classes) {
+      futures.add(executorService.submit(computeMethodPoolPerClass(clazz)));
+    }
+  }
+
+  private Runnable computeMethodPoolPerClass(DexClass clazz) {
+    return () -> {
+      MethodPool methodPool = methodPools.computeIfAbsent(clazz, k -> new MethodPool());
+      clazz.forEachMethod(
+          encodedMethod -> {
+            // We will add private instance methods when we promote them.
+            if (!encodedMethod.isPrivateMethod() || encodedMethod.isStaticMethod()) {
+              methodPool.seen(equivalence.wrap(encodedMethod.method));
+            }
+          });
+      if (clazz.superType != null) {
+        DexClass superClazz = application.definitionFor(clazz.superType);
+        if (superClazz != null) {
+          MethodPool superPool = methodPools.computeIfAbsent(superClazz, k -> new MethodPool());
+          superPool.linkSubtype(methodPool);
+          methodPool.linkSupertype(superPool);
+        }
+      }
+      if (clazz.isInterface()) {
+        clazz.type.forAllImplementsSubtypes(
+            implementer -> {
+              DexClass subClazz = application.definitionFor(implementer);
+              if (subClazz != null) {
+                MethodPool childPool = methodPools.computeIfAbsent(subClazz, k -> new MethodPool());
+                childPool.linkInterface(methodPool);
+              }
+            });
+      }
+    };
+  }
+
+  public static class MethodPool {
+    private MethodPool superType;
+    private final Set<MethodPool> interfaces = new HashSet<>();
+    private final Set<MethodPool> subTypes = new HashSet<>();
+    private final Set<Wrapper<DexMethod>> methodPool = new HashSet<>();
+
+    private MethodPool() {}
+
+    synchronized void linkSupertype(MethodPool superType) {
+      assert this.superType == null;
+      this.superType = superType;
+    }
+
+    synchronized void linkSubtype(MethodPool subType) {
+      boolean added = subTypes.add(subType);
+      assert added;
+    }
+
+    synchronized void linkInterface(MethodPool itf) {
+      boolean added = interfaces.add(itf);
+      assert added;
+    }
+
+    public synchronized void seen(Wrapper<DexMethod> method) {
+      boolean added = methodPool.add(method);
+      assert added;
+    }
+
+    public boolean hasSeen(Wrapper<DexMethod> method) {
+      return hasSeenUpwardRecursive(method) || hasSeenDownwardRecursive(method);
+    }
+
+    private boolean hasSeenUpwardRecursive(Wrapper<DexMethod> method) {
+      return methodPool.contains(method)
+          || (superType != null && superType.hasSeenUpwardRecursive(method))
+          || interfaces.stream().anyMatch(itf -> itf.hasSeenUpwardRecursive(method));
+    }
+
+    private boolean hasSeenDownwardRecursive(Wrapper<DexMethod> method) {
+      return methodPool.contains(method)
+          || subTypes.stream().anyMatch(subType -> subType.hasSeenDownwardRecursive(method));
+    }
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/ClassAndMemberPublicizer.java b/src/main/java/com/android/tools/r8/optimize/ClassAndMemberPublicizer.java
index 73ebf73..b53e1bc 100644
--- a/src/main/java/com/android/tools/r8/optimize/ClassAndMemberPublicizer.java
+++ b/src/main/java/com/android/tools/r8/optimize/ClassAndMemberPublicizer.java
@@ -11,23 +11,16 @@
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.GraphLense;
 import com.android.tools.r8.graph.MethodAccessFlags;
+import com.android.tools.r8.ir.optimize.MethodPoolCollection;
 import com.android.tools.r8.optimize.PublicizerLense.PublicizedLenseBuilder;
 import com.android.tools.r8.shaking.RootSetBuilder.RootSet;
 import com.android.tools.r8.utils.MethodSignatureEquivalence;
-import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.Timing;
 import com.google.common.base.Equivalence;
-import com.google.common.base.Equivalence.Wrapper;
-import java.util.ArrayList;
-import java.util.HashSet;
 import java.util.LinkedHashSet;
-import java.util.List;
-import java.util.Map;
 import java.util.Set;
-import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Future;
 
 public final class ClassAndMemberPublicizer {
   private final DexApplication application;
@@ -36,11 +29,12 @@
   private final PublicizedLenseBuilder lenseBuilder;
 
   private final Equivalence<DexMethod> equivalence = MethodSignatureEquivalence.get();
-  private final Map<DexClass, MethodPool> methodPools = new ConcurrentHashMap<>();
+  private final MethodPoolCollection methodPoolCollection;
 
   private ClassAndMemberPublicizer(DexApplication application, AppView appView, RootSet rootSet) {
     this.application = application;
     this.appView = appView;
+    this.methodPoolCollection = new MethodPoolCollection(application);
     this.rootSet = rootSet;
     lenseBuilder = PublicizerLense.createBuilder();
   }
@@ -64,15 +58,7 @@
   private GraphLense run(ExecutorService executorService, Timing timing)
       throws ExecutionException {
     // Phase 1: Collect methods to check if private instance methods don't have conflicts.
-    timing.begin("Phase 1: collectMethods");
-    try {
-      List<Future<?>> futures = new ArrayList<>();
-      application.classes().forEach(clazz ->
-          futures.add(executorService.submit(computeMethodPoolPerClass(clazz))));
-      ThreadUtils.awaitFutures(futures);
-    } finally {
-      timing.end();
-    }
+    methodPoolCollection.buildAll(executorService, timing);
 
     // Phase 2: Visit classes and promote class/member to public if possible.
     timing.begin("Phase 2: promoteToPublic");
@@ -83,35 +69,6 @@
     return lenseBuilder.build(appView);
   }
 
-  private Runnable computeMethodPoolPerClass(DexClass clazz) {
-    return () -> {
-      MethodPool methodPool = methodPools.computeIfAbsent(clazz, k -> new MethodPool());
-      clazz.forEachMethod(encodedMethod -> {
-        // We will add private instance methods when we promote them.
-        if (!encodedMethod.isPrivateMethod() || encodedMethod.isStaticMethod()) {
-          methodPool.seen(equivalence.wrap(encodedMethod.method));
-        }
-      });
-      if (clazz.superType != null) {
-        DexClass superClazz = application.definitionFor(clazz.superType);
-        if (superClazz != null) {
-          MethodPool superPool = methodPools.computeIfAbsent(superClazz, k -> new MethodPool());
-          superPool.linkSubtype(methodPool);
-          methodPool.linkSupertype(superPool);
-        }
-      }
-      if (clazz.isInterface()) {
-        clazz.type.forAllImplementsSubtypes(implementer -> {
-          DexClass subClazz = application.definitionFor(implementer);
-          if (subClazz != null) {
-            MethodPool childPool = methodPools.computeIfAbsent(subClazz, k -> new MethodPool());
-            childPool.linkInterface(methodPool);
-          }
-        });
-      }
-    };
-  }
-
   private void publicizeType(DexType type) {
     DexClass clazz = application.definitionFor(type);
     if (clazz != null && clazz.isProgramClass()) {
@@ -166,9 +123,8 @@
         return false;
       }
 
-      MethodPool methodPool = methodPools.get(holder);
-      Wrapper<DexMethod> key = equivalence.wrap(encodedMethod.method);
-      if (methodPool.hasSeen(key)) {
+      boolean wasSeen = methodPoolCollection.markIfNotSeen(holder, encodedMethod.method);
+      if (wasSeen) {
         // We can't do anything further because even renaming is not allowed due to the keep rule.
         if (rootSet.noObfuscation.contains(encodedMethod)) {
           return false;
@@ -176,7 +132,6 @@
         // TODO(b/111118390): Renaming will enable more private instance methods to be publicized.
         return false;
       }
-      methodPool.seen(key);
       lenseBuilder.add(encodedMethod.method);
       accessFlags.unsetPrivate();
       accessFlags.setFinal();
@@ -196,51 +151,4 @@
     accessFlags.setPublic();
     return false;
   }
-
-  // Per-class collection of method signatures, which will be used to determine if a certain method
-  // can be publicized or not.
-  static class MethodPool {
-    private MethodPool superType;
-    private final Set<MethodPool> interfaces = new HashSet<>();
-    private final Set<MethodPool> subTypes = new HashSet<>();
-    private final Set<Wrapper<DexMethod>> methodPool = new HashSet<>();
-
-    MethodPool() {
-    }
-
-    synchronized void linkSupertype(MethodPool superType) {
-      assert this.superType == null;
-      this.superType = superType;
-    }
-
-    synchronized void linkSubtype(MethodPool subType) {
-      boolean added = subTypes.add(subType);
-      assert added;
-    }
-
-    synchronized void linkInterface(MethodPool itf) {
-      boolean added = interfaces.add(itf);
-      assert added;
-    }
-
-    synchronized void seen(Wrapper<DexMethod> method) {
-      boolean added = methodPool.add(method);
-      assert added;
-    }
-
-    boolean hasSeen(Wrapper<DexMethod> method) {
-      return hasSeenUpwardRecursive(method) || hasSeenDownwardRecursive(method);
-    }
-
-    private boolean hasSeenUpwardRecursive(Wrapper<DexMethod> method) {
-      return methodPool.contains(method)
-          || (superType != null && superType.hasSeenUpwardRecursive(method))
-          || interfaces.stream().anyMatch(itf -> itf.hasSeenUpwardRecursive(method));
-    }
-
-    private boolean hasSeenDownwardRecursive(Wrapper<DexMethod> method) {
-      return methodPool.contains(method)
-          || subTypes.stream().anyMatch(subType -> subType.hasSeenDownwardRecursive(method));
-    }
-  }
 }