Parallelize post processing of proto methods
Change-Id: I473e7d4a511f54a74c22d42e75887982de9096e0
diff --git a/src/main/java/com/android/tools/r8/R8.java b/src/main/java/com/android/tools/r8/R8.java
index fc97ba6..17c6a0a 100644
--- a/src/main/java/com/android/tools/r8/R8.java
+++ b/src/main/java/com/android/tools/r8/R8.java
@@ -733,7 +733,7 @@
// TODO(b/112437944): Avoid iterating the entire application to post-process every
// dynamicMethod() method.
appView.withGeneratedMessageLiteShrinker(
- shrinker -> shrinker.postOptimizeDynamicMethods(converter, timing));
+ shrinker -> shrinker.postOptimizeDynamicMethods(converter, executorService, timing));
// If proto shrinking is enabled, we need to post-process every
// findLiteExtensionByNumber() method. This ensures that there are no references to dead
@@ -741,7 +741,9 @@
// TODO(b/112437944): Avoid iterating the entire application to post-process every
// findLiteExtensionByNumber() method.
appView.withGeneratedExtensionRegistryShrinker(
- shrinker -> shrinker.postOptimizeGeneratedExtensionRegistry(converter, timing));
+ shrinker ->
+ shrinker.postOptimizeGeneratedExtensionRegistry(
+ converter, executorService, timing));
}
}
diff --git a/src/main/java/com/android/tools/r8/graph/AppView.java b/src/main/java/com/android/tools/r8/graph/AppView.java
index 3d12531..3a55054 100644
--- a/src/main/java/com/android/tools/r8/graph/AppView.java
+++ b/src/main/java/com/android/tools/r8/graph/AppView.java
@@ -27,7 +27,6 @@
import com.google.common.base.Predicates;
import java.util.IdentityHashMap;
import java.util.Map;
-import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
@@ -266,8 +265,8 @@
return defaultValue;
}
- public void withGeneratedExtensionRegistryShrinker(
- Consumer<GeneratedExtensionRegistryShrinker> consumer) {
+ public <E extends Throwable> void withGeneratedExtensionRegistryShrinker(
+ ThrowingConsumer<GeneratedExtensionRegistryShrinker, E> consumer) throws E {
if (protoShrinker != null && protoShrinker.generatedExtensionRegistryShrinker != null) {
consumer.accept(protoShrinker.generatedExtensionRegistryShrinker);
}
@@ -281,7 +280,8 @@
return defaultValue;
}
- public void withGeneratedMessageLiteShrinker(Consumer<GeneratedMessageLiteShrinker> consumer) {
+ public <E extends Throwable> void withGeneratedMessageLiteShrinker(
+ ThrowingConsumer<GeneratedMessageLiteShrinker, E> consumer) throws E {
if (protoShrinker != null && protoShrinker.generatedMessageLiteShrinker != null) {
consumer.accept(protoShrinker.generatedMessageLiteShrinker);
}
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedExtensionRegistryShrinker.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedExtensionRegistryShrinker.java
index 08b29ce..b4e025c 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedExtensionRegistryShrinker.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedExtensionRegistryShrinker.java
@@ -4,6 +4,7 @@
package com.android.tools.r8.ir.analysis.proto;
+import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
import static com.google.common.base.Predicates.not;
import com.android.tools.r8.graph.AppView;
@@ -32,6 +33,7 @@
import com.android.tools.r8.shaking.TreePrunerConfiguration;
import com.android.tools.r8.utils.DescriptorUtils;
import com.android.tools.r8.utils.FileUtils;
+import com.android.tools.r8.utils.ThreadUtils;
import com.android.tools.r8.utils.Timing;
import com.google.common.base.Predicates;
import com.google.common.collect.Sets;
@@ -40,6 +42,8 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.stream.Collectors;
@@ -156,26 +160,29 @@
return removedExtensionFields.contains(field);
}
- public void postOptimizeGeneratedExtensionRegistry(IRConverter converter, Timing timing) {
+ public void postOptimizeGeneratedExtensionRegistry(
+ IRConverter converter, ExecutorService executorService, Timing timing)
+ throws ExecutionException {
timing.begin("[Proto] Post optimize generated extension registry");
- forEachFindLiteExtensionByNumberMethod(
+ ThreadUtils.processItems(
+ this::forEachFindLiteExtensionByNumberMethod,
method ->
converter.processMethod(
method,
OptimizationFeedbackIgnore.getInstance(),
- OneTimeMethodProcessor.getInstance()));
- timing.end(); // [Proto] Post optimize generated extension registry
+ OneTimeMethodProcessor.getInstance()),
+ executorService);
+ timing.end();
}
private void forEachFindLiteExtensionByNumberMethod(Consumer<DexEncodedMethod> consumer) {
- for (DexProgramClass clazz : appView.appInfo().classes()) {
- if (clazz.superType != references.extensionRegistryLiteType) {
- continue;
- }
-
- for (DexEncodedMethod method : clazz.methods()) {
- if (references.isFindLiteExtensionByNumberMethod(method.method)) {
- consumer.accept(method);
+ for (DexType type : appView.appInfo().subtypes(references.extensionRegistryLiteType)) {
+ DexProgramClass clazz = asProgramClassOrNull(appView.definitionFor(type));
+ if (clazz != null) {
+ for (DexEncodedMethod method : clazz.methods()) {
+ if (references.isFindLiteExtensionByNumberMethod(method.method)) {
+ consumer.accept(method);
+ }
}
}
}
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java
index b2d813a..40b8cb5 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java
@@ -4,6 +4,7 @@
package com.android.tools.r8.ir.analysis.proto;
+import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
import static com.android.tools.r8.ir.analysis.proto.ProtoUtils.getInfoValueFromMessageInfoConstructionInvoke;
import static com.android.tools.r8.ir.analysis.proto.ProtoUtils.getObjectsValueFromMessageInfoConstructionInvoke;
import static com.android.tools.r8.ir.analysis.proto.ProtoUtils.setObjectsValueForMessageInfoConstructionInvoke;
@@ -11,7 +12,10 @@
import com.android.tools.r8.graph.AppView;
import com.android.tools.r8.graph.DexClass;
import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexMethod;
import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexType;
import com.android.tools.r8.ir.analysis.proto.schema.ProtoMessageInfo;
import com.android.tools.r8.ir.analysis.proto.schema.ProtoObject;
import com.android.tools.r8.ir.analysis.type.Nullability;
@@ -32,8 +36,11 @@
import com.android.tools.r8.ir.conversion.OneTimeMethodProcessor;
import com.android.tools.r8.ir.optimize.info.OptimizationFeedbackIgnore;
import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.ThreadUtils;
import com.android.tools.r8.utils.Timing;
import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
public class GeneratedMessageLiteShrinker {
@@ -70,22 +77,33 @@
}
}
- public void postOptimizeDynamicMethods(IRConverter converter, Timing timing) {
+ public void postOptimizeDynamicMethods(
+ IRConverter converter, ExecutorService executorService, Timing timing)
+ throws ExecutionException {
timing.begin("[Proto] Post optimize dynamic methods");
- forEachDynamicMethod(
+ ThreadUtils.processItems(
+ this::forEachDynamicMethod,
method ->
converter.processMethod(
method,
OptimizationFeedbackIgnore.getInstance(),
- OneTimeMethodProcessor.getInstance()));
+ OneTimeMethodProcessor.getInstance()),
+ executorService);
timing.end();
}
private void forEachDynamicMethod(Consumer<DexEncodedMethod> consumer) {
- for (DexProgramClass clazz : appView.appInfo().classes()) {
- DexEncodedMethod dynamicMethod = clazz.lookupVirtualMethod(references::isDynamicMethod);
- if (dynamicMethod != null) {
- consumer.accept(dynamicMethod);
+ DexItemFactory dexItemFactory = appView.dexItemFactory();
+ for (DexType type : appView.appInfo().subtypes(references.generatedMessageLiteType)) {
+ DexProgramClass clazz = asProgramClassOrNull(appView.definitionFor(type));
+ if (clazz != null) {
+ DexMethod dynamicMethod =
+ dexItemFactory.createMethod(
+ type, references.dynamicMethodProto, references.dynamicMethodName);
+ DexEncodedMethod encodedDynamicMethod = clazz.lookupVirtualMethod(dynamicMethod);
+ if (encodedDynamicMethod != null) {
+ consumer.accept(encodedDynamicMethod);
+ }
}
}
}
diff --git a/src/main/java/com/android/tools/r8/utils/ForEachable.java b/src/main/java/com/android/tools/r8/utils/ForEachable.java
new file mode 100644
index 0000000..4c8ae53
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/utils/ForEachable.java
@@ -0,0 +1,11 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.utils;
+
+import java.util.function.Consumer;
+
+public interface ForEachable<T> {
+
+ void forEach(Consumer<T> consumer);
+}
diff --git a/src/main/java/com/android/tools/r8/utils/ThreadUtils.java b/src/main/java/com/android/tools/r8/utils/ThreadUtils.java
index 84aa0d0..e6076ca 100644
--- a/src/main/java/com/android/tools/r8/utils/ThreadUtils.java
+++ b/src/main/java/com/android/tools/r8/utils/ThreadUtils.java
@@ -21,32 +21,37 @@
public static <T, R, E extends Exception> Collection<R> processItemsWithResults(
Iterable<T> items, ThrowingFunction<T, R, E> consumer, ExecutorService executorService)
throws ExecutionException {
+ return processItemsWithResults(items::forEach, consumer, executorService);
+ }
+
+ public static <T, R, E extends Exception> Collection<R> processItemsWithResults(
+ ForEachable<T> items, ThrowingFunction<T, R, E> consumer, ExecutorService executorService)
+ throws ExecutionException {
List<Future<R>> futures = new ArrayList<>();
- for (T item : items) {
- futures.add(executorService.submit(() -> consumer.apply(item)));
- }
+ items.forEach(item -> futures.add(executorService.submit(() -> consumer.apply(item))));
return awaitFuturesWithResults(futures);
}
public static <T, E extends Exception> void processItems(
Iterable<T> items, ThrowingConsumer<T, E> consumer, ExecutorService executorService)
throws ExecutionException {
- processItemsWithResults(
- items,
- arg -> {
- consumer.accept(arg);
- return null;
- },
- executorService);
+ processItems(items::forEach, consumer, executorService);
}
public static <T, U, E extends Exception> void processItems(
Map<T, U> items, ThrowingBiConsumer<T, U, E> consumer, ExecutorService executorService)
throws ExecutionException {
+ processItems(
+ items.entrySet(), arg -> consumer.accept(arg.getKey(), arg.getValue()), executorService);
+ }
+
+ public static <T, E extends Exception> void processItems(
+ ForEachable<T> items, ThrowingConsumer<T, E> consumer, ExecutorService executorService)
+ throws ExecutionException {
processItemsWithResults(
- items.entrySet(),
+ items,
arg -> {
- consumer.accept(arg.getKey(), arg.getValue());
+ consumer.accept(arg);
return null;
},
executorService);