Optimize processing of virtual calls when classes are instantiated

Bug: b/378464445
Change-Id: I174655f1bbd4b8aa963020957a5c9f6f7a394ec1
diff --git a/src/main/java/com/android/tools/r8/graph/LookupResult.java b/src/main/java/com/android/tools/r8/graph/LookupResult.java
index a23a5a6..a3f8879 100644
--- a/src/main/java/com/android/tools/r8/graph/LookupResult.java
+++ b/src/main/java/com/android/tools/r8/graph/LookupResult.java
@@ -43,6 +43,8 @@
   public abstract void forEachFailureDependency(
       Consumer<? super DexEncodedMethod> methodCausingFailureConsumer);
 
+  public abstract boolean hasFailureDependencies();
+
   public static LookupResultSuccess createResult(
       Map<DexMethod, LookupMethodTarget> methodTargets,
       List<LookupLambdaTarget> lambdaTargets,
@@ -114,6 +116,11 @@
       methodsCausingFailure.forEach(methodCausingFailureConsumer);
     }
 
+    @Override
+    public boolean hasFailureDependencies() {
+      return !methodsCausingFailure.isEmpty();
+    }
+
     public boolean contains(DexEncodedMethod method) {
       // Containment of a method in the lookup results only pertains to the method targets.
       return methodTargets.containsKey(method.getReference());
@@ -230,5 +237,10 @@
         Consumer<? super DexEncodedMethod> methodCausingFailureConsumer) {
       // TODO: record and emit failure dependencies.
     }
+
+    @Override
+    public boolean hasFailureDependencies() {
+      return false;
+    }
   }
 }
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index 5c79367..7b78440 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -75,6 +75,7 @@
 import com.android.tools.r8.graph.LookupLambdaTarget;
 import com.android.tools.r8.graph.LookupMethodTarget;
 import com.android.tools.r8.graph.LookupResult;
+import com.android.tools.r8.graph.LookupResult.LookupResultSuccess;
 import com.android.tools.r8.graph.LookupTarget;
 import com.android.tools.r8.graph.MethodResolutionResult;
 import com.android.tools.r8.graph.MethodResolutionResult.FailedResolutionResult;
@@ -156,10 +157,12 @@
 import com.android.tools.r8.shaking.rules.KeepAnnotationMatcher;
 import com.android.tools.r8.synthesis.SyntheticItems.SynthesizingContextOracle;
 import com.android.tools.r8.utils.Action;
+import com.android.tools.r8.utils.BooleanUtils;
 import com.android.tools.r8.utils.Box;
 import com.android.tools.r8.utils.DescriptorUtils;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.IteratorUtils;
+import com.android.tools.r8.utils.ListUtils;
 import com.android.tools.r8.utils.OptionalBool;
 import com.android.tools.r8.utils.Pair;
 import com.android.tools.r8.utils.SetUtils;
@@ -188,7 +191,6 @@
 import java.util.List;
 import java.util.ListIterator;
 import java.util.Map;
-import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutionException;
@@ -2972,7 +2974,7 @@
   private void transitionItemsDueToNewlyInstantiatedClass(DexProgramClass clazz, Timing timing) {
     // For all methods of the class, if we have seen a call, mark the method live.
     // We only do this for virtual calls, as the other ones will be done directly.
-    timing.time("Transition methods", () -> transitionMethodsForInstantiatedClass(clazz));
+    timing.time("Transition methods", () -> transitionMethodsForInstantiatedClass(clazz, timing));
     // For all instance fields visible from the class, mark them live if we have seen a read.
     timing.time("Transition fields", () -> transitionFieldsForInstantiatedClass(clazz));
     // Add all dependent instance members to the workqueue.
@@ -3047,14 +3049,17 @@
 
   private void transitionMethodsForInstantiatedLambda(LambdaDescriptor lambda) {
     transitionMethodsForInstantiatedObject(
-        InstantiatedObject.of(lambda), appInfo.dexItemFactory().objectType, lambda.interfaces);
+        InstantiatedObject.of(lambda),
+        appInfo.dexItemFactory().objectType,
+        lambda.interfaces,
+        Timing.empty());
   }
 
-  private void transitionMethodsForInstantiatedClass(DexProgramClass clazz) {
+  private void transitionMethodsForInstantiatedClass(DexProgramClass clazz, Timing timing) {
     assert !clazz.isAnnotation();
     assert !clazz.isInterface();
     transitionMethodsForInstantiatedObject(
-        InstantiatedObject.of(clazz), clazz.type, Collections.emptyList());
+        InstantiatedObject.of(clazz), clazz.type, Collections.emptyList(), timing);
   }
 
   /**
@@ -3065,7 +3070,7 @@
    * methods are considered reachable.
    */
   private void transitionMethodsForInstantiatedObject(
-      InstantiatedObject instantiation, DexType type, List<DexType> interfaces) {
+      InstantiatedObject instantiation, DexType type, List<DexType> interfaces, Timing timing) {
     WorkList<DexType> worklist = WorkList.newIdentityWorkList(type);
     worklist.addIfNotSeen(interfaces);
     while (worklist.hasNext()) {
@@ -3074,9 +3079,13 @@
       classResolutionResult.forEachClassResolutionResult(
           clazz -> {
             if (clazz.isProgramClass()) {
-              markProgramMethodOverridesAsLive(instantiation, clazz.asProgramClass());
+              timing.time(
+                  "Program",
+                  () -> markProgramMethodOverridesAsLive(instantiation, clazz.asProgramClass()));
             } else {
-              markLibraryAndClasspathMethodOverridesAsLive(instantiation, clazz);
+              timing.time(
+                  "Library",
+                  () -> markLibraryAndClasspathMethodOverridesAsLive(instantiation, clazz));
             }
             if (clazz.superType != null) {
               worklist.addIfNotSeen(clazz.superType);
@@ -3097,62 +3106,78 @@
         || appInfo.isSubtype(instantiation.asClass().getType(), currentClass.type);
     getReachableVirtualTargets(currentClass)
         .forEach(
-            (resolutionSearchKey, contexts) -> {
-              Map<DexProgramClass, List<ProgramMethod>> contextsByClass = new IdentityHashMap<>();
-              for (ProgramMethod context : contexts) {
-                contextsByClass
-                    .computeIfAbsent(context.getHolder(), ignoreKey(ArrayList::new))
-                    .add(context);
+            (resolutionSearchKey, contexts) ->
+                resolutionSearchKey
+                    .resolve(appInfo)
+                    .forEachMethodResolutionResult(
+                        resolutionResult -> {
+                          SingleResolutionResult<?> singleResolution =
+                              resolutionResult.asSingleResolution();
+                          if (singleResolution == null) {
+                            assert false : "Should not be null";
+                            return;
+                          }
+                          // Lookup virtual dispatch targets only uses the given context to judge
+                          // accessibility. Therefore, we simply find the first successful
+                          // resolution result and then disregard all other contextual lookups.
+                          LookupResultSuccess lookupResultSuccess = null;
+                          for (ProgramMethod context : contexts) {
+                            LookupResult lookupResult =
+                                singleResolution.lookupVirtualDispatchTargets(
+                                    context,
+                                    appView,
+                                    instantiation,
+                                    definition -> keepInfo.isPinned(definition, options, appInfo));
+                            if (lookupResult.isLookupResultSuccess()) {
+                              lookupResultSuccess = lookupResult.asLookupResultSuccess();
+                              break;
+                            } else {
+                              assert lookupResult.isLookupResultFailure();
+                              assert !lookupResult.hasFailureDependencies();
+                            }
+                          }
+                          if (lookupResultSuccess != null) {
+                            handleVirtualDispatchTargets(lookupResultSuccess, singleResolution);
+                            handleVirtualDispatchFailureDependencies(lookupResultSuccess, contexts);
+                          }
+                        }));
+  }
+
+  private void handleVirtualDispatchTargets(
+      LookupResult lookupResult, SingleResolutionResult<?> resolutionResult) {
+    lookupResult.forEach(
+        target ->
+            markVirtualDispatchTargetAsLive(
+                target,
+                programMethod ->
+                    graphReporter.reportReachableMethodAsLive(
+                        resolutionResult.getResolvedMethod().getReference(), programMethod)));
+  }
+
+  private void handleVirtualDispatchFailureDependencies(
+      LookupResultSuccess lookupResult, ProgramMethodSet contexts) {
+    if (!lookupResult.hasFailureDependencies()) {
+      return;
+    }
+    Map<DexProgramClass, List<ProgramMethod>> contextsByClass = new IdentityHashMap<>();
+    for (ProgramMethod context : contexts) {
+      contextsByClass.computeIfAbsent(context.getHolder(), ignoreKey(ArrayList::new)).add(context);
+    }
+    lookupResult.forEachFailureDependency(
+        method -> {
+          for (List<ProgramMethod> contextsWithSameHolder : contextsByClass.values()) {
+            ProgramMethod representativeContext = ListUtils.first(contextsWithSameHolder);
+            DexProgramClass clazz =
+                getProgramClassOrNull(method.getHolderType(), representativeContext);
+            if (clazz != null) {
+              failedMethodResolutionTargets.add(method.getReference());
+              for (ProgramMethod context : contextsWithSameHolder) {
+                markMethodAsTargeted(
+                    new ProgramMethod(clazz, method), KeepReason.invokedFrom(context));
               }
-              appInfo
-                  .resolveMethodLegacy(resolutionSearchKey.method, resolutionSearchKey.isInterface)
-                  .forEachMethodResolutionResult(
-                      resolutionResult -> {
-                        SingleResolutionResult<?> singleResolution =
-                            resolutionResult.asSingleResolution();
-                        if (singleResolution == null) {
-                          assert false : "Should not be null";
-                          return;
-                        }
-                        contextsByClass.forEach(
-                            (contextHolder, contextsInHolder) -> {
-                              LookupResult lookupResult =
-                                  singleResolution.lookupVirtualDispatchTargets(
-                                      contextHolder,
-                                      appView,
-                                      (type, subTypeConsumer, lambdaConsumer) -> {
-                                        assert appInfo.isSubtype(currentClass.type, type);
-                                        instantiation.apply(subTypeConsumer, lambdaConsumer);
-                                      },
-                                      definition ->
-                                          keepInfo.isPinned(definition, options, appInfo));
-                              lookupResult.forEach(
-                                  target ->
-                                      markVirtualDispatchTargetAsLive(
-                                          target,
-                                          programMethod ->
-                                              graphReporter.reportReachableMethodAsLive(
-                                                  singleResolution
-                                                      .getResolvedMethod()
-                                                      .getReference(),
-                                                  programMethod)));
-                              lookupResult.forEachFailureDependency(
-                                  method -> {
-                                    DexProgramClass clazz =
-                                        getProgramClassOrNull(
-                                            method.getHolderType(), contextHolder);
-                                    if (clazz != null) {
-                                      failedMethodResolutionTargets.add(method.getReference());
-                                      for (ProgramMethod context : contextsInHolder) {
-                                        markMethodAsTargeted(
-                                            new ProgramMethod(clazz, method),
-                                            KeepReason.invokedFrom(context));
-                                      }
-                                    }
-                                  });
-                            });
-                      });
-            });
+            }
+          }
+        });
   }
 
   @SuppressWarnings("ReferenceEquality")
@@ -3653,21 +3678,15 @@
               .computeIfAbsent(resolutionSearchKey, ignoreArgument(ProgramMethodSet::create))
               .add(context);
 
-          resolution
-              .lookupVirtualDispatchTargets(
+          LookupResult lookupResult =
+              resolution.lookupVirtualDispatchTargets(
                   context,
                   appView,
                   (type, subTypeConsumer, lambdaConsumer) ->
                       objectAllocationInfoCollection.forEachInstantiatedSubType(
                           type, subTypeConsumer, lambdaConsumer, appInfo),
-                  definition -> keepInfo.isPinned(definition, options, appInfo))
-              .forEach(
-                  target ->
-                      markVirtualDispatchTargetAsLive(
-                          target,
-                          programMethod ->
-                              graphReporter.reportReachableMethodAsLive(
-                                  resolvedMethod.getReference(), programMethod)));
+                  definition -> keepInfo.isPinned(definition, options, appInfo));
+          handleVirtualDispatchTargets(lookupResult, resolution);
         });
     return resolutionResults;
   }
@@ -5912,24 +5931,35 @@
     private final DexMethod method;
     private final boolean isInterface;
 
+    private MethodResolutionResult cachedResolution;
+
     private ResolutionSearchKey(DexMethod method, boolean isInterface) {
       this.method = method;
       this.isInterface = isInterface;
     }
 
+    public MethodResolutionResult resolve(AppInfoWithClassHierarchy appInfo) {
+      if (cachedResolution == null) {
+        cachedResolution = appInfo.resolveMethodLegacy(method, isInterface);
+      }
+      return cachedResolution;
+    }
+
     @Override
-    @SuppressWarnings({"EqualsGetClass", "ReferenceEquality"})
     public boolean equals(Object o) {
-      if (o == null || getClass() != o.getClass()) {
+      if (this == o) {
+        return true;
+      }
+      if (!(o instanceof ResolutionSearchKey)) {
         return false;
       }
       ResolutionSearchKey that = (ResolutionSearchKey) o;
-      return method == that.method && isInterface == that.isInterface;
+      return method.isIdenticalTo(that.method) && isInterface == that.isInterface;
     }
 
     @Override
     public int hashCode() {
-      return Objects.hash(method, isInterface);
+      return (method.hashCode() << 1) | BooleanUtils.intValue(isInterface);
     }
   }
 }
diff --git a/src/main/java/com/android/tools/r8/shaking/InstantiatedObject.java b/src/main/java/com/android/tools/r8/shaking/InstantiatedObject.java
index 9c92725..e705f70 100644
--- a/src/main/java/com/android/tools/r8/shaking/InstantiatedObject.java
+++ b/src/main/java/com/android/tools/r8/shaking/InstantiatedObject.java
@@ -4,10 +4,12 @@
 package com.android.tools.r8.shaking;
 
 import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.InstantiatedSubTypeInfo;
 import com.android.tools.r8.ir.desugar.LambdaDescriptor;
 import java.util.function.Consumer;
 
-public abstract class InstantiatedObject {
+public abstract class InstantiatedObject implements InstantiatedSubTypeInfo {
 
   public static InstantiatedObject of(DexProgramClass clazz) {
     return new InstantiatedClass(clazz);
@@ -52,6 +54,14 @@
     }
 
     @Override
+    public void forEachInstantiatedSubType(
+        DexType type,
+        Consumer<DexProgramClass> subTypeConsumer,
+        Consumer<LambdaDescriptor> lambdaConsumer) {
+      subTypeConsumer.accept(clazz);
+    }
+
+    @Override
     public boolean isClass() {
       return true;
     }
@@ -70,6 +80,14 @@
     }
 
     @Override
+    public void forEachInstantiatedSubType(
+        DexType type,
+        Consumer<DexProgramClass> subTypeConsumer,
+        Consumer<LambdaDescriptor> lambdaConsumer) {
+      lambdaConsumer.accept(lambdaDescriptor);
+    }
+
+    @Override
     public boolean isLambda() {
       return true;
     }