Optimize processing of virtual calls when classes are instantiated
Bug: b/378464445
Change-Id: I174655f1bbd4b8aa963020957a5c9f6f7a394ec1
diff --git a/src/main/java/com/android/tools/r8/graph/LookupResult.java b/src/main/java/com/android/tools/r8/graph/LookupResult.java
index a23a5a6..a3f8879 100644
--- a/src/main/java/com/android/tools/r8/graph/LookupResult.java
+++ b/src/main/java/com/android/tools/r8/graph/LookupResult.java
@@ -43,6 +43,8 @@
public abstract void forEachFailureDependency(
Consumer<? super DexEncodedMethod> methodCausingFailureConsumer);
+ public abstract boolean hasFailureDependencies();
+
public static LookupResultSuccess createResult(
Map<DexMethod, LookupMethodTarget> methodTargets,
List<LookupLambdaTarget> lambdaTargets,
@@ -114,6 +116,11 @@
methodsCausingFailure.forEach(methodCausingFailureConsumer);
}
+ @Override
+ public boolean hasFailureDependencies() {
+ return !methodsCausingFailure.isEmpty();
+ }
+
public boolean contains(DexEncodedMethod method) {
// Containment of a method in the lookup results only pertains to the method targets.
return methodTargets.containsKey(method.getReference());
@@ -230,5 +237,10 @@
Consumer<? super DexEncodedMethod> methodCausingFailureConsumer) {
// TODO: record and emit failure dependencies.
}
+
+ @Override
+ public boolean hasFailureDependencies() {
+ return false;
+ }
}
}
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index 5c79367..7b78440 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -75,6 +75,7 @@
import com.android.tools.r8.graph.LookupLambdaTarget;
import com.android.tools.r8.graph.LookupMethodTarget;
import com.android.tools.r8.graph.LookupResult;
+import com.android.tools.r8.graph.LookupResult.LookupResultSuccess;
import com.android.tools.r8.graph.LookupTarget;
import com.android.tools.r8.graph.MethodResolutionResult;
import com.android.tools.r8.graph.MethodResolutionResult.FailedResolutionResult;
@@ -156,10 +157,12 @@
import com.android.tools.r8.shaking.rules.KeepAnnotationMatcher;
import com.android.tools.r8.synthesis.SyntheticItems.SynthesizingContextOracle;
import com.android.tools.r8.utils.Action;
+import com.android.tools.r8.utils.BooleanUtils;
import com.android.tools.r8.utils.Box;
import com.android.tools.r8.utils.DescriptorUtils;
import com.android.tools.r8.utils.InternalOptions;
import com.android.tools.r8.utils.IteratorUtils;
+import com.android.tools.r8.utils.ListUtils;
import com.android.tools.r8.utils.OptionalBool;
import com.android.tools.r8.utils.Pair;
import com.android.tools.r8.utils.SetUtils;
@@ -188,7 +191,6 @@
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
-import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
@@ -2972,7 +2974,7 @@
private void transitionItemsDueToNewlyInstantiatedClass(DexProgramClass clazz, Timing timing) {
// For all methods of the class, if we have seen a call, mark the method live.
// We only do this for virtual calls, as the other ones will be done directly.
- timing.time("Transition methods", () -> transitionMethodsForInstantiatedClass(clazz));
+ timing.time("Transition methods", () -> transitionMethodsForInstantiatedClass(clazz, timing));
// For all instance fields visible from the class, mark them live if we have seen a read.
timing.time("Transition fields", () -> transitionFieldsForInstantiatedClass(clazz));
// Add all dependent instance members to the workqueue.
@@ -3047,14 +3049,17 @@
private void transitionMethodsForInstantiatedLambda(LambdaDescriptor lambda) {
transitionMethodsForInstantiatedObject(
- InstantiatedObject.of(lambda), appInfo.dexItemFactory().objectType, lambda.interfaces);
+ InstantiatedObject.of(lambda),
+ appInfo.dexItemFactory().objectType,
+ lambda.interfaces,
+ Timing.empty());
}
- private void transitionMethodsForInstantiatedClass(DexProgramClass clazz) {
+ private void transitionMethodsForInstantiatedClass(DexProgramClass clazz, Timing timing) {
assert !clazz.isAnnotation();
assert !clazz.isInterface();
transitionMethodsForInstantiatedObject(
- InstantiatedObject.of(clazz), clazz.type, Collections.emptyList());
+ InstantiatedObject.of(clazz), clazz.type, Collections.emptyList(), timing);
}
/**
@@ -3065,7 +3070,7 @@
* methods are considered reachable.
*/
private void transitionMethodsForInstantiatedObject(
- InstantiatedObject instantiation, DexType type, List<DexType> interfaces) {
+ InstantiatedObject instantiation, DexType type, List<DexType> interfaces, Timing timing) {
WorkList<DexType> worklist = WorkList.newIdentityWorkList(type);
worklist.addIfNotSeen(interfaces);
while (worklist.hasNext()) {
@@ -3074,9 +3079,13 @@
classResolutionResult.forEachClassResolutionResult(
clazz -> {
if (clazz.isProgramClass()) {
- markProgramMethodOverridesAsLive(instantiation, clazz.asProgramClass());
+ timing.time(
+ "Program",
+ () -> markProgramMethodOverridesAsLive(instantiation, clazz.asProgramClass()));
} else {
- markLibraryAndClasspathMethodOverridesAsLive(instantiation, clazz);
+ timing.time(
+ "Library",
+ () -> markLibraryAndClasspathMethodOverridesAsLive(instantiation, clazz));
}
if (clazz.superType != null) {
worklist.addIfNotSeen(clazz.superType);
@@ -3097,62 +3106,78 @@
|| appInfo.isSubtype(instantiation.asClass().getType(), currentClass.type);
getReachableVirtualTargets(currentClass)
.forEach(
- (resolutionSearchKey, contexts) -> {
- Map<DexProgramClass, List<ProgramMethod>> contextsByClass = new IdentityHashMap<>();
- for (ProgramMethod context : contexts) {
- contextsByClass
- .computeIfAbsent(context.getHolder(), ignoreKey(ArrayList::new))
- .add(context);
+ (resolutionSearchKey, contexts) ->
+ resolutionSearchKey
+ .resolve(appInfo)
+ .forEachMethodResolutionResult(
+ resolutionResult -> {
+ SingleResolutionResult<?> singleResolution =
+ resolutionResult.asSingleResolution();
+ if (singleResolution == null) {
+ assert false : "Should not be null";
+ return;
+ }
+ // Lookup virtual dispatch targets only uses the given context to judge
+ // accessibility. Therefore, we simply find the first successful
+ // resolution result and then disregard all other contextual lookups.
+ LookupResultSuccess lookupResultSuccess = null;
+ for (ProgramMethod context : contexts) {
+ LookupResult lookupResult =
+ singleResolution.lookupVirtualDispatchTargets(
+ context,
+ appView,
+ instantiation,
+ definition -> keepInfo.isPinned(definition, options, appInfo));
+ if (lookupResult.isLookupResultSuccess()) {
+ lookupResultSuccess = lookupResult.asLookupResultSuccess();
+ break;
+ } else {
+ assert lookupResult.isLookupResultFailure();
+ assert !lookupResult.hasFailureDependencies();
+ }
+ }
+ if (lookupResultSuccess != null) {
+ handleVirtualDispatchTargets(lookupResultSuccess, singleResolution);
+ handleVirtualDispatchFailureDependencies(lookupResultSuccess, contexts);
+ }
+ }));
+ }
+
+ private void handleVirtualDispatchTargets(
+ LookupResult lookupResult, SingleResolutionResult<?> resolutionResult) {
+ lookupResult.forEach(
+ target ->
+ markVirtualDispatchTargetAsLive(
+ target,
+ programMethod ->
+ graphReporter.reportReachableMethodAsLive(
+ resolutionResult.getResolvedMethod().getReference(), programMethod)));
+ }
+
+ private void handleVirtualDispatchFailureDependencies(
+ LookupResultSuccess lookupResult, ProgramMethodSet contexts) {
+ if (!lookupResult.hasFailureDependencies()) {
+ return;
+ }
+ Map<DexProgramClass, List<ProgramMethod>> contextsByClass = new IdentityHashMap<>();
+ for (ProgramMethod context : contexts) {
+ contextsByClass.computeIfAbsent(context.getHolder(), ignoreKey(ArrayList::new)).add(context);
+ }
+ lookupResult.forEachFailureDependency(
+ method -> {
+ for (List<ProgramMethod> contextsWithSameHolder : contextsByClass.values()) {
+ ProgramMethod representativeContext = ListUtils.first(contextsWithSameHolder);
+ DexProgramClass clazz =
+ getProgramClassOrNull(method.getHolderType(), representativeContext);
+ if (clazz != null) {
+ failedMethodResolutionTargets.add(method.getReference());
+ for (ProgramMethod context : contextsWithSameHolder) {
+ markMethodAsTargeted(
+ new ProgramMethod(clazz, method), KeepReason.invokedFrom(context));
}
- appInfo
- .resolveMethodLegacy(resolutionSearchKey.method, resolutionSearchKey.isInterface)
- .forEachMethodResolutionResult(
- resolutionResult -> {
- SingleResolutionResult<?> singleResolution =
- resolutionResult.asSingleResolution();
- if (singleResolution == null) {
- assert false : "Should not be null";
- return;
- }
- contextsByClass.forEach(
- (contextHolder, contextsInHolder) -> {
- LookupResult lookupResult =
- singleResolution.lookupVirtualDispatchTargets(
- contextHolder,
- appView,
- (type, subTypeConsumer, lambdaConsumer) -> {
- assert appInfo.isSubtype(currentClass.type, type);
- instantiation.apply(subTypeConsumer, lambdaConsumer);
- },
- definition ->
- keepInfo.isPinned(definition, options, appInfo));
- lookupResult.forEach(
- target ->
- markVirtualDispatchTargetAsLive(
- target,
- programMethod ->
- graphReporter.reportReachableMethodAsLive(
- singleResolution
- .getResolvedMethod()
- .getReference(),
- programMethod)));
- lookupResult.forEachFailureDependency(
- method -> {
- DexProgramClass clazz =
- getProgramClassOrNull(
- method.getHolderType(), contextHolder);
- if (clazz != null) {
- failedMethodResolutionTargets.add(method.getReference());
- for (ProgramMethod context : contextsInHolder) {
- markMethodAsTargeted(
- new ProgramMethod(clazz, method),
- KeepReason.invokedFrom(context));
- }
- }
- });
- });
- });
- });
+ }
+ }
+ });
}
@SuppressWarnings("ReferenceEquality")
@@ -3653,21 +3678,15 @@
.computeIfAbsent(resolutionSearchKey, ignoreArgument(ProgramMethodSet::create))
.add(context);
- resolution
- .lookupVirtualDispatchTargets(
+ LookupResult lookupResult =
+ resolution.lookupVirtualDispatchTargets(
context,
appView,
(type, subTypeConsumer, lambdaConsumer) ->
objectAllocationInfoCollection.forEachInstantiatedSubType(
type, subTypeConsumer, lambdaConsumer, appInfo),
- definition -> keepInfo.isPinned(definition, options, appInfo))
- .forEach(
- target ->
- markVirtualDispatchTargetAsLive(
- target,
- programMethod ->
- graphReporter.reportReachableMethodAsLive(
- resolvedMethod.getReference(), programMethod)));
+ definition -> keepInfo.isPinned(definition, options, appInfo));
+ handleVirtualDispatchTargets(lookupResult, resolution);
});
return resolutionResults;
}
@@ -5912,24 +5931,35 @@
private final DexMethod method;
private final boolean isInterface;
+ private MethodResolutionResult cachedResolution;
+
private ResolutionSearchKey(DexMethod method, boolean isInterface) {
this.method = method;
this.isInterface = isInterface;
}
+ public MethodResolutionResult resolve(AppInfoWithClassHierarchy appInfo) {
+ if (cachedResolution == null) {
+ cachedResolution = appInfo.resolveMethodLegacy(method, isInterface);
+ }
+ return cachedResolution;
+ }
+
@Override
- @SuppressWarnings({"EqualsGetClass", "ReferenceEquality"})
public boolean equals(Object o) {
- if (o == null || getClass() != o.getClass()) {
+ if (this == o) {
+ return true;
+ }
+ if (!(o instanceof ResolutionSearchKey)) {
return false;
}
ResolutionSearchKey that = (ResolutionSearchKey) o;
- return method == that.method && isInterface == that.isInterface;
+ return method.isIdenticalTo(that.method) && isInterface == that.isInterface;
}
@Override
public int hashCode() {
- return Objects.hash(method, isInterface);
+ return (method.hashCode() << 1) | BooleanUtils.intValue(isInterface);
}
}
}
diff --git a/src/main/java/com/android/tools/r8/shaking/InstantiatedObject.java b/src/main/java/com/android/tools/r8/shaking/InstantiatedObject.java
index 9c92725..e705f70 100644
--- a/src/main/java/com/android/tools/r8/shaking/InstantiatedObject.java
+++ b/src/main/java/com/android/tools/r8/shaking/InstantiatedObject.java
@@ -4,10 +4,12 @@
package com.android.tools.r8.shaking;
import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.InstantiatedSubTypeInfo;
import com.android.tools.r8.ir.desugar.LambdaDescriptor;
import java.util.function.Consumer;
-public abstract class InstantiatedObject {
+public abstract class InstantiatedObject implements InstantiatedSubTypeInfo {
public static InstantiatedObject of(DexProgramClass clazz) {
return new InstantiatedClass(clazz);
@@ -52,6 +54,14 @@
}
@Override
+ public void forEachInstantiatedSubType(
+ DexType type,
+ Consumer<DexProgramClass> subTypeConsumer,
+ Consumer<LambdaDescriptor> lambdaConsumer) {
+ subTypeConsumer.accept(clazz);
+ }
+
+ @Override
public boolean isClass() {
return true;
}
@@ -70,6 +80,14 @@
}
@Override
+ public void forEachInstantiatedSubType(
+ DexType type,
+ Consumer<DexProgramClass> subTypeConsumer,
+ Consumer<LambdaDescriptor> lambdaConsumer) {
+ lambdaConsumer.accept(lambdaDescriptor);
+ }
+
+ @Override
public boolean isLambda() {
return true;
}