Do not single caller inline overrides of kept methods

Bug: 130721661
Change-Id: I6c213d9543e68d488b510c25ee6655b631fcf832
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/CallSiteInformation.java b/src/main/java/com/android/tools/r8/ir/conversion/CallSiteInformation.java
index 162e0bb..92a127d 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/CallSiteInformation.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/CallSiteInformation.java
@@ -5,9 +5,12 @@
 
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.ir.conversion.CallGraph.Node;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.classhierarchy.MethodOverridesCollector;
+import com.android.tools.r8.utils.collections.ProgramMethodSet;
 import com.google.common.collect.Sets;
 import java.util.Set;
 
@@ -55,13 +58,22 @@
     private final Set<DexMethod> doubleCallSite = Sets.newIdentityHashSet();
 
     CallGraphBasedCallSiteInformation(AppView<AppInfoWithLiveness> appView, CallGraph graph) {
+      ProgramMethodSet pinned =
+          MethodOverridesCollector.findAllMethodsAndOverridesThatMatches(
+              appView,
+              ImmediateProgramSubtypingInfo.create(appView),
+              appView.appInfo().classes(),
+              method ->
+                  appView.getKeepInfo(method).isPinned(appView.options())
+                      || appView.appInfo().isMethodTargetedByInvokeDynamic(method));
+
       for (Node node : graph.nodes) {
         ProgramMethod method = node.getProgramMethod();
         DexMethod reference = method.getReference();
 
         // For non-pinned methods and methods that override library methods we do not know the exact
         // number of call sites.
-        if (appView.appInfo().isPinned(reference)) {
+        if (pinned.contains(method)) {
           continue;
         }
 
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorUnoptimizableMethods.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorUnoptimizableMethods.java
index 0e643f6..7ed397c 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorUnoptimizableMethods.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorUnoptimizableMethods.java
@@ -4,29 +4,18 @@
 
 package com.android.tools.r8.optimize.argumentpropagation;
 
-import static com.android.tools.r8.utils.MapUtils.ignoreKey;
-
 import com.android.tools.r8.graph.AppView;
-import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
-import com.android.tools.r8.graph.MethodResolutionResult.SingleResolutionResult;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodStateCollectionByReference;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.UnknownMethodState;
-import com.android.tools.r8.optimize.argumentpropagation.utils.DepthFirstTopDownClassHierarchyTraversal;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.MethodSignatureEquivalence;
-import com.android.tools.r8.utils.collections.DexMethodSignatureSet;
+import com.android.tools.r8.utils.classhierarchy.MethodOverridesCollector;
 import com.android.tools.r8.utils.collections.ProgramMethodSet;
-import com.google.common.base.Equivalence.Wrapper;
-import com.google.common.collect.Sets;
 import java.util.Collection;
-import java.util.IdentityHashMap;
-import java.util.Map;
-import java.util.Set;
-import java.util.function.Consumer;
 
 public class ArgumentPropagatorUnoptimizableMethods {
 
@@ -49,196 +38,39 @@
   //  class.
   public void disableArgumentPropagationForUnoptimizableMethods(
       Collection<DexProgramClass> stronglyConnectedComponent) {
-    ProgramMethodSet unoptimizableClassRootMethods = ProgramMethodSet.create();
-    ProgramMethodSet unoptimizableInterfaceRootMethods = ProgramMethodSet.create();
-    forEachUnoptimizableMethod(
-        stronglyConnectedComponent,
-        method -> {
-          if (method.getDefinition().belongsToVirtualPool()
-              && !method.getHolder().isFinal()
-              && !method.getAccessFlags().isFinal()) {
-            if (method.getHolder().isInterface()) {
-              unoptimizableInterfaceRootMethods.add(method);
-            } else {
-              unoptimizableClassRootMethods.add(method);
-            }
-          } else {
-            disableArgumentPropagationForMethod(method);
-          }
-        });
-
-    // Disable argument propagation for all overrides of the root methods. Since interface methods
-    // may be implemented by classes that are not a subtype of the interface that declares the
-    // interface method, we first mark the interface method overrides on such classes as ineligible
-    // for argument propagation.
-    if (!unoptimizableInterfaceRootMethods.isEmpty()) {
-      new UnoptimizableInterfaceMethodPropagator(
-              unoptimizableClassRootMethods, unoptimizableInterfaceRootMethods)
-          .run(stronglyConnectedComponent);
-    }
-
-    // At this point we can mark all overrides by a simple top-down traversal over the class
-    // hierarchy.
-    new UnoptimizableClassMethodPropagator(
-            unoptimizableClassRootMethods, unoptimizableInterfaceRootMethods)
-        .run(stronglyConnectedComponent);
+    ProgramMethodSet unoptimizableVirtualMethods =
+        MethodOverridesCollector.findAllMethodsAndOverridesThatMatches(
+            appView,
+            immediateSubtypingInfo,
+            stronglyConnectedComponent,
+            method -> {
+              if (isUnoptimizableMethod(method)) {
+                if (method.getDefinition().belongsToVirtualPool()
+                    && !method.getHolder().isFinal()
+                    && !method.getAccessFlags().isFinal()) {
+                  return true;
+                } else {
+                  disableArgumentPropagationForMethod(method);
+                }
+              }
+              return false;
+            });
+    unoptimizableVirtualMethods.forEach(this::disableArgumentPropagationForMethod);
   }
 
   private void disableArgumentPropagationForMethod(ProgramMethod method) {
     methodStates.set(method, UnknownMethodState.get());
   }
 
-  private void forEachUnoptimizableMethod(
-      Collection<DexProgramClass> stronglyConnectedComponent, Consumer<ProgramMethod> consumer) {
+  private boolean isUnoptimizableMethod(ProgramMethod method) {
+    assert !method.getDefinition().belongsToVirtualPool()
+            || !method.getDefinition().isLibraryMethodOverride().isUnknown()
+        : "Unexpected virtual method without library method override information: "
+            + method.toSourceString();
     AppInfoWithLiveness appInfo = appView.appInfo();
     InternalOptions options = appView.options();
-    for (DexProgramClass clazz : stronglyConnectedComponent) {
-      clazz.forEachProgramMethod(
-          method -> {
-            assert !method.getDefinition().belongsToVirtualPool()
-                    || !method.getDefinition().isLibraryMethodOverride().isUnknown()
-                : "Unexpected virtual method without library method override information: "
-                    + method.toSourceString();
-            if (method.getDefinition().isLibraryMethodOverride().isPossiblyTrue()
-                || appInfo.isMethodTargetedByInvokeDynamic(method)
-                || !appInfo
-                    .getKeepInfo()
-                    .getMethodInfo(method)
-                    .isArgumentPropagationAllowed(options)) {
-              consumer.accept(method);
-            }
-          });
-    }
-  }
-
-  private class UnoptimizableInterfaceMethodPropagator
-      extends DepthFirstTopDownClassHierarchyTraversal {
-
-    private final ProgramMethodSet unoptimizableClassRootMethods;
-    private final Map<DexProgramClass, Set<Wrapper<DexMethod>>> unoptimizableInterfaceMethods =
-        new IdentityHashMap<>();
-
-    UnoptimizableInterfaceMethodPropagator(
-        ProgramMethodSet unoptimizableClassRootMethods,
-        ProgramMethodSet unoptimizableInterfaceRootMethods) {
-      super(
-          ArgumentPropagatorUnoptimizableMethods.this.appView,
-          ArgumentPropagatorUnoptimizableMethods.this.immediateSubtypingInfo);
-      this.unoptimizableClassRootMethods = unoptimizableClassRootMethods;
-      unoptimizableInterfaceRootMethods.forEach(this::addUnoptimizableRootMethod);
-    }
-
-    private void addUnoptimizableRootMethod(ProgramMethod method) {
-      unoptimizableInterfaceMethods
-          .computeIfAbsent(method.getHolder(), ignoreKey(Sets::newIdentityHashSet))
-          .add(equivalence.wrap(method.getReference()));
-    }
-
-    @Override
-    public void visit(DexProgramClass clazz) {
-      Set<Wrapper<DexMethod>> unoptimizableInterfaceMethodsForClass =
-          unoptimizableInterfaceMethods.computeIfAbsent(clazz, ignoreKey(Sets::newIdentityHashSet));
-
-      // Add the unoptimizable interface methods from the parent interfaces.
-      immediateSubtypingInfo.forEachImmediateProgramSuperClass(
-          clazz,
-          superclass ->
-              unoptimizableInterfaceMethodsForClass.addAll(
-                  unoptimizableInterfaceMethods.get(superclass)));
-
-      // Propagate the unoptimizable interface methods of this interface to all immediate
-      // (non-interface) subclasses.
-      for (DexProgramClass implementer : immediateSubtypingInfo.getSubclasses(clazz)) {
-        if (implementer.isInterface()) {
-          continue;
-        }
-
-        for (Wrapper<DexMethod> unoptimizableInterfaceMethod :
-            unoptimizableInterfaceMethodsForClass) {
-          SingleResolutionResult resolutionResult =
-              appView
-                  .appInfo()
-                  .resolveMethodOnClass(unoptimizableInterfaceMethod.get(), implementer)
-                  .asSingleResolution();
-          if (resolutionResult == null || !resolutionResult.getResolvedHolder().isProgramClass()) {
-            continue;
-          }
-
-          ProgramMethod resolvedMethod = resolutionResult.getResolvedProgramMethod();
-          if (resolvedMethod.getHolder().isInterface()
-              || resolvedMethod.getHolder() == implementer) {
-            continue;
-          }
-
-          unoptimizableClassRootMethods.add(resolvedMethod);
-        }
-      }
-    }
-
-    @Override
-    public void prune(DexProgramClass clazz) {
-      unoptimizableInterfaceMethods.remove(clazz);
-    }
-
-    @Override
-    public boolean isRoot(DexProgramClass clazz) {
-      return clazz.isInterface() && super.isRoot(clazz);
-    }
-
-    @Override
-    public void forEachSubClass(DexProgramClass clazz, Consumer<DexProgramClass> consumer) {
-      for (DexProgramClass subclass : immediateSubtypingInfo.getSubclasses(clazz)) {
-        if (subclass.isInterface()) {
-          consumer.accept(subclass);
-        }
-      }
-    }
-  }
-
-  private class UnoptimizableClassMethodPropagator
-      extends DepthFirstTopDownClassHierarchyTraversal {
-
-    private final Map<DexProgramClass, DexMethodSignatureSet> unoptimizableMethods =
-        new IdentityHashMap<>();
-
-    UnoptimizableClassMethodPropagator(
-        ProgramMethodSet unoptimizableClassRootMethods,
-        ProgramMethodSet unoptimizableInterfaceRootMethods) {
-      super(
-          ArgumentPropagatorUnoptimizableMethods.this.appView,
-          ArgumentPropagatorUnoptimizableMethods.this.immediateSubtypingInfo);
-      unoptimizableClassRootMethods.forEach(this::addUnoptimizableRootMethod);
-      unoptimizableInterfaceRootMethods.forEach(this::addUnoptimizableRootMethod);
-    }
-
-    private void addUnoptimizableRootMethod(ProgramMethod method) {
-      unoptimizableMethods
-          .computeIfAbsent(method.getHolder(), ignoreKey(DexMethodSignatureSet::create))
-          .add(method);
-    }
-
-    @Override
-    public void visit(DexProgramClass clazz) {
-      DexMethodSignatureSet unoptimizableMethodsForClass =
-          unoptimizableMethods.computeIfAbsent(clazz, ignoreKey(DexMethodSignatureSet::create));
-
-      // Add the unoptimizable methods from the parent classes.
-      immediateSubtypingInfo.forEachImmediateProgramSuperClass(
-          clazz,
-          superclass -> unoptimizableMethodsForClass.addAll(unoptimizableMethods.get(superclass)));
-
-      // Disable argument propagation for the unoptimizable methods of this class.
-      clazz.forEachProgramVirtualMethod(
-          method -> {
-            if (unoptimizableMethodsForClass.contains(method)) {
-              disableArgumentPropagationForMethod(method);
-            }
-          });
-    }
-
-    @Override
-    public void prune(DexProgramClass clazz) {
-      unoptimizableMethods.remove(clazz);
-    }
+    return method.getDefinition().isLibraryMethodOverride().isPossiblyTrue()
+        || appInfo.isMethodTargetedByInvokeDynamic(method)
+        || !appInfo.getKeepInfo().getMethodInfo(method).isArgumentPropagationAllowed(options);
   }
 }
diff --git a/src/main/java/com/android/tools/r8/utils/classhierarchy/MethodOverridesCollector.java b/src/main/java/com/android/tools/r8/utils/classhierarchy/MethodOverridesCollector.java
new file mode 100644
index 0000000..bd63c2a
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/utils/classhierarchy/MethodOverridesCollector.java
@@ -0,0 +1,203 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.utils.classhierarchy;
+
+import static com.android.tools.r8.utils.MapUtils.ignoreKey;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexMethodSignature;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
+import com.android.tools.r8.graph.MethodResolutionResult.SingleResolutionResult;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.optimize.argumentpropagation.utils.DepthFirstTopDownClassHierarchyTraversal;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.collections.DexMethodSignatureSet;
+import com.android.tools.r8.utils.collections.ProgramMethodSet;
+import java.util.Collection;
+import java.util.IdentityHashMap;
+import java.util.Map;
+import java.util.function.Consumer;
+import java.util.function.Predicate;
+
+/**
+ * Given a predicate, finds all methods that satisfies their methods including their overrides and
+ * siblings.
+ */
+public class MethodOverridesCollector {
+
+  public static ProgramMethodSet findAllMethodsAndOverridesThatMatches(
+      AppView<AppInfoWithLiveness> appView,
+      ImmediateProgramSubtypingInfo immediateSubtypingInfo,
+      Collection<DexProgramClass> stronglyConnectedComponent,
+      Predicate<ProgramMethod> predicate) {
+    ProgramMethodSet classRootMethods = ProgramMethodSet.create();
+    ProgramMethodSet interfaceRootMethods = ProgramMethodSet.create();
+
+    for (DexProgramClass clazz : stronglyConnectedComponent) {
+      clazz.forEachProgramMethod(
+          method -> {
+            if (predicate.test(method)) {
+              if (clazz.isInterface()) {
+                interfaceRootMethods.add(method);
+              } else {
+                classRootMethods.add(method);
+              }
+            }
+          });
+    }
+
+    // Since interface methods may be implemented by classes that are not a subtype of the interface
+    // that declares the interface method, we first add the interface method overrides on such
+    // classes to the classRootMethods set.
+    if (!interfaceRootMethods.isEmpty()) {
+      new InterfaceMethodToClassSiblingPropagator(
+              appView, immediateSubtypingInfo, classRootMethods, interfaceRootMethods)
+          .run(stronglyConnectedComponent);
+    }
+
+    // Mark all overrides by a simple top-down traversal over the class hierarchy.
+    TopDownClassHierarchyPropagator topDownClassHierarchyPropagator =
+        new TopDownClassHierarchyPropagator(
+            appView, immediateSubtypingInfo, classRootMethods, interfaceRootMethods);
+    topDownClassHierarchyPropagator.run(stronglyConnectedComponent);
+    return topDownClassHierarchyPropagator.getResult();
+  }
+
+  private static class InterfaceMethodToClassSiblingPropagator
+      extends DepthFirstTopDownClassHierarchyTraversal {
+
+    private final ProgramMethodSet classRootMethods;
+    private final Map<DexProgramClass, DexMethodSignatureSet> interfaceMethodsOfInterest =
+        new IdentityHashMap<>();
+
+    InterfaceMethodToClassSiblingPropagator(
+        AppView<AppInfoWithLiveness> appView,
+        ImmediateProgramSubtypingInfo immediateSubtypingInfo,
+        ProgramMethodSet classRootMethods,
+        ProgramMethodSet interfaceRootMethods) {
+      super(appView, immediateSubtypingInfo);
+      this.classRootMethods = classRootMethods;
+      for (ProgramMethod method : interfaceRootMethods) {
+        interfaceMethodsOfInterest
+            .computeIfAbsent(method.getHolder(), ignoreKey(DexMethodSignatureSet::create))
+            .add(method);
+      }
+    }
+
+    @Override
+    public void visit(DexProgramClass clazz) {
+      DexMethodSignatureSet interfaceMethodsOfInterestForClass =
+          interfaceMethodsOfInterest.computeIfAbsent(
+              clazz, ignoreKey(DexMethodSignatureSet::create));
+
+      // Add the interface methods from the parent interfaces that satisfies the given predicate.
+      immediateSubtypingInfo.forEachImmediateProgramSuperClass(
+          clazz,
+          superclass ->
+              interfaceMethodsOfInterestForClass.addAll(
+                  interfaceMethodsOfInterest.get(superclass)));
+
+      // Propagate the interface methods of interest from this interface to all immediate
+      // (non-interface) subclasses.
+      for (DexProgramClass implementer : immediateSubtypingInfo.getSubclasses(clazz)) {
+        if (implementer.isInterface()) {
+          continue;
+        }
+
+        for (DexMethodSignature interfaceMethod : interfaceMethodsOfInterestForClass) {
+          SingleResolutionResult resolutionResult =
+              appView
+                  .appInfo()
+                  .resolveMethodOnClass(interfaceMethod, implementer)
+                  .asSingleResolution();
+          if (resolutionResult == null || !resolutionResult.getResolvedHolder().isProgramClass()) {
+            continue;
+          }
+
+          ProgramMethod resolvedMethod = resolutionResult.getResolvedProgramMethod();
+          if (resolvedMethod.getHolder().isInterface()
+              || resolvedMethod.getHolder() == implementer) {
+            continue;
+          }
+
+          classRootMethods.add(resolvedMethod);
+        }
+      }
+    }
+
+    @Override
+    public void prune(DexProgramClass clazz) {
+      interfaceMethodsOfInterest.remove(clazz);
+    }
+
+    @Override
+    public boolean isRoot(DexProgramClass clazz) {
+      return clazz.isInterface() && super.isRoot(clazz);
+    }
+
+    @Override
+    public void forEachSubClass(DexProgramClass clazz, Consumer<DexProgramClass> consumer) {
+      for (DexProgramClass subclass : immediateSubtypingInfo.getSubclasses(clazz)) {
+        if (subclass.isInterface()) {
+          consumer.accept(subclass);
+        }
+      }
+    }
+  }
+
+  private static class TopDownClassHierarchyPropagator
+      extends DepthFirstTopDownClassHierarchyTraversal {
+
+    private final Map<DexProgramClass, DexMethodSignatureSet> methodsOfInterest =
+        new IdentityHashMap<>();
+
+    private final ProgramMethodSet result = ProgramMethodSet.create();
+
+    TopDownClassHierarchyPropagator(
+        AppView<AppInfoWithLiveness> appView,
+        ImmediateProgramSubtypingInfo immediateSubtypingInfo,
+        ProgramMethodSet classRootMethods,
+        ProgramMethodSet interfaceRootMethods) {
+      super(appView, immediateSubtypingInfo);
+      classRootMethods.forEach(this::addRootMethod);
+      interfaceRootMethods.forEach(this::addRootMethod);
+    }
+
+    private void addRootMethod(ProgramMethod method) {
+      methodsOfInterest
+          .computeIfAbsent(method.getHolder(), ignoreKey(DexMethodSignatureSet::create))
+          .add(method);
+    }
+
+    ProgramMethodSet getResult() {
+      return result;
+    }
+
+    @Override
+    public void visit(DexProgramClass clazz) {
+      DexMethodSignatureSet methodsOfInterestForClass =
+          methodsOfInterest.computeIfAbsent(clazz, ignoreKey(DexMethodSignatureSet::create));
+
+      // Add the methods of interest from the parent classes.
+      immediateSubtypingInfo.forEachImmediateProgramSuperClass(
+          clazz, superclass -> methodsOfInterestForClass.addAll(methodsOfInterest.get(superclass)));
+
+      // For each method on the current class that is classified as a method of interest, add the
+      // method to the result.
+      clazz.forEachProgramMethod(
+          method -> {
+            if (methodsOfInterestForClass.contains(method)) {
+              result.add(method);
+            }
+          });
+    }
+
+    @Override
+    public void prune(DexProgramClass clazz) {
+      methodsOfInterest.remove(clazz);
+    }
+  }
+}