Parallelize post processing of proto methods

Change-Id: I473e7d4a511f54a74c22d42e75887982de9096e0
diff --git a/src/main/java/com/android/tools/r8/R8.java b/src/main/java/com/android/tools/r8/R8.java
index fc97ba6..17c6a0a 100644
--- a/src/main/java/com/android/tools/r8/R8.java
+++ b/src/main/java/com/android/tools/r8/R8.java
@@ -733,7 +733,7 @@
           // TODO(b/112437944): Avoid iterating the entire application to post-process every
           //  dynamicMethod() method.
           appView.withGeneratedMessageLiteShrinker(
-              shrinker -> shrinker.postOptimizeDynamicMethods(converter, timing));
+              shrinker -> shrinker.postOptimizeDynamicMethods(converter, executorService, timing));
 
           // If proto shrinking is enabled, we need to post-process every
           // findLiteExtensionByNumber() method. This ensures that there are no references to dead
@@ -741,7 +741,9 @@
           // TODO(b/112437944): Avoid iterating the entire application to post-process every
           //  findLiteExtensionByNumber() method.
           appView.withGeneratedExtensionRegistryShrinker(
-              shrinker -> shrinker.postOptimizeGeneratedExtensionRegistry(converter, timing));
+              shrinker ->
+                  shrinker.postOptimizeGeneratedExtensionRegistry(
+                      converter, executorService, timing));
         }
       }
 
diff --git a/src/main/java/com/android/tools/r8/graph/AppView.java b/src/main/java/com/android/tools/r8/graph/AppView.java
index 3d12531..3a55054 100644
--- a/src/main/java/com/android/tools/r8/graph/AppView.java
+++ b/src/main/java/com/android/tools/r8/graph/AppView.java
@@ -27,7 +27,6 @@
 import com.google.common.base.Predicates;
 import java.util.IdentityHashMap;
 import java.util.Map;
-import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.Predicate;
 
@@ -266,8 +265,8 @@
     return defaultValue;
   }
 
-  public void withGeneratedExtensionRegistryShrinker(
-      Consumer<GeneratedExtensionRegistryShrinker> consumer) {
+  public <E extends Throwable> void withGeneratedExtensionRegistryShrinker(
+      ThrowingConsumer<GeneratedExtensionRegistryShrinker, E> consumer) throws E {
     if (protoShrinker != null && protoShrinker.generatedExtensionRegistryShrinker != null) {
       consumer.accept(protoShrinker.generatedExtensionRegistryShrinker);
     }
@@ -281,7 +280,8 @@
     return defaultValue;
   }
 
-  public void withGeneratedMessageLiteShrinker(Consumer<GeneratedMessageLiteShrinker> consumer) {
+  public <E extends Throwable> void withGeneratedMessageLiteShrinker(
+      ThrowingConsumer<GeneratedMessageLiteShrinker, E> consumer) throws E {
     if (protoShrinker != null && protoShrinker.generatedMessageLiteShrinker != null) {
       consumer.accept(protoShrinker.generatedMessageLiteShrinker);
     }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedExtensionRegistryShrinker.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedExtensionRegistryShrinker.java
index 08b29ce..b4e025c 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedExtensionRegistryShrinker.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedExtensionRegistryShrinker.java
@@ -4,6 +4,7 @@
 
 package com.android.tools.r8.ir.analysis.proto;
 
+import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
 import static com.google.common.base.Predicates.not;
 
 import com.android.tools.r8.graph.AppView;
@@ -32,6 +33,7 @@
 import com.android.tools.r8.shaking.TreePrunerConfiguration;
 import com.android.tools.r8.utils.DescriptorUtils;
 import com.android.tools.r8.utils.FileUtils;
+import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.Timing;
 import com.google.common.base.Predicates;
 import com.google.common.collect.Sets;
@@ -40,6 +42,8 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
 import java.util.function.Consumer;
 import java.util.function.Predicate;
 import java.util.stream.Collectors;
@@ -156,26 +160,29 @@
     return removedExtensionFields.contains(field);
   }
 
-  public void postOptimizeGeneratedExtensionRegistry(IRConverter converter, Timing timing) {
+  public void postOptimizeGeneratedExtensionRegistry(
+      IRConverter converter, ExecutorService executorService, Timing timing)
+      throws ExecutionException {
     timing.begin("[Proto] Post optimize generated extension registry");
-    forEachFindLiteExtensionByNumberMethod(
+    ThreadUtils.processItems(
+        this::forEachFindLiteExtensionByNumberMethod,
         method ->
             converter.processMethod(
                 method,
                 OptimizationFeedbackIgnore.getInstance(),
-                OneTimeMethodProcessor.getInstance()));
-    timing.end(); // [Proto] Post optimize generated extension registry
+                OneTimeMethodProcessor.getInstance()),
+        executorService);
+    timing.end();
   }
 
   private void forEachFindLiteExtensionByNumberMethod(Consumer<DexEncodedMethod> consumer) {
-    for (DexProgramClass clazz : appView.appInfo().classes()) {
-      if (clazz.superType != references.extensionRegistryLiteType) {
-        continue;
-      }
-
-      for (DexEncodedMethod method : clazz.methods()) {
-        if (references.isFindLiteExtensionByNumberMethod(method.method)) {
-          consumer.accept(method);
+    for (DexType type : appView.appInfo().subtypes(references.extensionRegistryLiteType)) {
+      DexProgramClass clazz = asProgramClassOrNull(appView.definitionFor(type));
+      if (clazz != null) {
+        for (DexEncodedMethod method : clazz.methods()) {
+          if (references.isFindLiteExtensionByNumberMethod(method.method)) {
+            consumer.accept(method);
+          }
         }
       }
     }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java
index b2d813a..40b8cb5 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java
@@ -4,6 +4,7 @@
 
 package com.android.tools.r8.ir.analysis.proto;
 
+import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
 import static com.android.tools.r8.ir.analysis.proto.ProtoUtils.getInfoValueFromMessageInfoConstructionInvoke;
 import static com.android.tools.r8.ir.analysis.proto.ProtoUtils.getObjectsValueFromMessageInfoConstructionInvoke;
 import static com.android.tools.r8.ir.analysis.proto.ProtoUtils.setObjectsValueForMessageInfoConstructionInvoke;
@@ -11,7 +12,10 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexClass;
 import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.ir.analysis.proto.schema.ProtoMessageInfo;
 import com.android.tools.r8.ir.analysis.proto.schema.ProtoObject;
 import com.android.tools.r8.ir.analysis.type.Nullability;
@@ -32,8 +36,11 @@
 import com.android.tools.r8.ir.conversion.OneTimeMethodProcessor;
 import com.android.tools.r8.ir.optimize.info.OptimizationFeedbackIgnore;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.Timing;
 import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
 import java.util.function.Consumer;
 
 public class GeneratedMessageLiteShrinker {
@@ -70,22 +77,33 @@
     }
   }
 
-  public void postOptimizeDynamicMethods(IRConverter converter, Timing timing) {
+  public void postOptimizeDynamicMethods(
+      IRConverter converter, ExecutorService executorService, Timing timing)
+      throws ExecutionException {
     timing.begin("[Proto] Post optimize dynamic methods");
-    forEachDynamicMethod(
+    ThreadUtils.processItems(
+        this::forEachDynamicMethod,
         method ->
             converter.processMethod(
                 method,
                 OptimizationFeedbackIgnore.getInstance(),
-                OneTimeMethodProcessor.getInstance()));
+                OneTimeMethodProcessor.getInstance()),
+        executorService);
     timing.end();
   }
 
   private void forEachDynamicMethod(Consumer<DexEncodedMethod> consumer) {
-    for (DexProgramClass clazz : appView.appInfo().classes()) {
-      DexEncodedMethod dynamicMethod = clazz.lookupVirtualMethod(references::isDynamicMethod);
-      if (dynamicMethod != null) {
-        consumer.accept(dynamicMethod);
+    DexItemFactory dexItemFactory = appView.dexItemFactory();
+    for (DexType type : appView.appInfo().subtypes(references.generatedMessageLiteType)) {
+      DexProgramClass clazz = asProgramClassOrNull(appView.definitionFor(type));
+      if (clazz != null) {
+        DexMethod dynamicMethod =
+            dexItemFactory.createMethod(
+                type, references.dynamicMethodProto, references.dynamicMethodName);
+        DexEncodedMethod encodedDynamicMethod = clazz.lookupVirtualMethod(dynamicMethod);
+        if (encodedDynamicMethod != null) {
+          consumer.accept(encodedDynamicMethod);
+        }
       }
     }
   }
diff --git a/src/main/java/com/android/tools/r8/utils/ForEachable.java b/src/main/java/com/android/tools/r8/utils/ForEachable.java
new file mode 100644
index 0000000..4c8ae53
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/utils/ForEachable.java
@@ -0,0 +1,11 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.utils;
+
+import java.util.function.Consumer;
+
+public interface ForEachable<T> {
+
+  void forEach(Consumer<T> consumer);
+}
diff --git a/src/main/java/com/android/tools/r8/utils/ThreadUtils.java b/src/main/java/com/android/tools/r8/utils/ThreadUtils.java
index 84aa0d0..e6076ca 100644
--- a/src/main/java/com/android/tools/r8/utils/ThreadUtils.java
+++ b/src/main/java/com/android/tools/r8/utils/ThreadUtils.java
@@ -21,32 +21,37 @@
   public static <T, R, E extends Exception> Collection<R> processItemsWithResults(
       Iterable<T> items, ThrowingFunction<T, R, E> consumer, ExecutorService executorService)
       throws ExecutionException {
+    return processItemsWithResults(items::forEach, consumer, executorService);
+  }
+
+  public static <T, R, E extends Exception> Collection<R> processItemsWithResults(
+      ForEachable<T> items, ThrowingFunction<T, R, E> consumer, ExecutorService executorService)
+      throws ExecutionException {
     List<Future<R>> futures = new ArrayList<>();
-    for (T item : items) {
-      futures.add(executorService.submit(() -> consumer.apply(item)));
-    }
+    items.forEach(item -> futures.add(executorService.submit(() -> consumer.apply(item))));
     return awaitFuturesWithResults(futures);
   }
 
   public static <T, E extends Exception> void processItems(
       Iterable<T> items, ThrowingConsumer<T, E> consumer, ExecutorService executorService)
       throws ExecutionException {
-    processItemsWithResults(
-        items,
-        arg -> {
-          consumer.accept(arg);
-          return null;
-        },
-        executorService);
+    processItems(items::forEach, consumer, executorService);
   }
 
   public static <T, U, E extends Exception> void processItems(
       Map<T, U> items, ThrowingBiConsumer<T, U, E> consumer, ExecutorService executorService)
       throws ExecutionException {
+    processItems(
+        items.entrySet(), arg -> consumer.accept(arg.getKey(), arg.getValue()), executorService);
+  }
+
+  public static <T, E extends Exception> void processItems(
+      ForEachable<T> items, ThrowingConsumer<T, E> consumer, ExecutorService executorService)
+      throws ExecutionException {
     processItemsWithResults(
-        items.entrySet(),
+        items,
         arg -> {
-          consumer.accept(arg.getKey(), arg.getValue());
+          consumer.accept(arg);
           return null;
         },
         executorService);