Keep singler caller inlinees after inlining if reprocessing

Bug: b/285021603
Change-Id: I4f48c26c39f7396a0ad615b87e524e6bcbdf938f
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/callgraph/CallSiteInformation.java b/src/main/java/com/android/tools/r8/ir/conversion/callgraph/CallSiteInformation.java
index c2e13af..234eeee 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/callgraph/CallSiteInformation.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/callgraph/CallSiteInformation.java
@@ -13,6 +13,8 @@
 import com.android.tools.r8.utils.classhierarchy.MethodOverridesCollector;
 import com.android.tools.r8.utils.collections.ProgramMethodSet;
 import com.google.common.collect.Sets;
+import java.util.HashMap;
+import java.util.Map;
 import java.util.Set;
 
 public abstract class CallSiteInformation {
@@ -23,6 +25,14 @@
    * <p>For pinned methods (methods kept through Proguard keep rules) this will always answer <code>
    * false</code>.
    */
+  public abstract boolean hasSingleCallSite(ProgramMethod context, ProgramMethod method);
+
+  /**
+   * Checks if the given method only has a single call without considering context.
+   *
+   * <p>For pinned methods (methods kept through Proguard keep rules) and methods that override a
+   * library method this always returns false.
+   */
   public abstract boolean hasSingleCallSite(ProgramMethod method);
 
   public abstract boolean isMultiCallerInlineCandidate(ProgramMethod method);
@@ -38,6 +48,11 @@
     private static final EmptyCallSiteInformation EMPTY_INFO = new EmptyCallSiteInformation();
 
     @Override
+    public boolean hasSingleCallSite(ProgramMethod context, ProgramMethod method) {
+      return false;
+    }
+
+    @Override
     public boolean hasSingleCallSite(ProgramMethod method) {
       return false;
     }
@@ -55,7 +70,9 @@
 
   static class CallGraphBasedCallSiteInformation extends CallSiteInformation {
 
-    private final Set<DexMethod> singleCallerMethods = Sets.newIdentityHashSet();
+    // Single callers track their calling context to ensure that the predicate is stable after
+    // inlining of the caller.
+    private final Map<DexMethod, DexMethod> singleCallerMethods = new HashMap<>();
     private final Set<DexMethod> multiCallerInlineCandidates = Sets.newIdentityHashSet();
 
     CallGraphBasedCallSiteInformation(AppView<AppInfoWithLiveness> appView, CallGraph graph) {
@@ -94,7 +111,10 @@
 
         int numberOfCallSites = node.getNumberOfCallSites();
         if (numberOfCallSites == 1) {
-          singleCallerMethods.add(reference);
+          Node caller = node.getCallersWithDeterministicOrder().iterator().next();
+          DexMethod existing =
+              singleCallerMethods.put(reference, caller.getMethod().getReference());
+          assert existing == null;
         } else if (numberOfCallSites > 1) {
           multiCallerInlineCandidates.add(reference);
         }
@@ -102,14 +122,25 @@
     }
 
     /**
-     * Checks if the given method only has a single call site.
+     * Checks if the given method only has a single call site with the given context.
+     *
+     * <p>For pinned methods (methods kept through Proguard keep rules) and methods that override a
+     * library method this always returns false.
+     */
+    @Override
+    public boolean hasSingleCallSite(ProgramMethod context, ProgramMethod method) {
+      return singleCallerMethods.get(method.getReference()) == context.getReference();
+    }
+
+    /**
+     * Checks if the given method only has a single call without considering context.
      *
      * <p>For pinned methods (methods kept through Proguard keep rules) and methods that override a
      * library method this always returns false.
      */
     @Override
     public boolean hasSingleCallSite(ProgramMethod method) {
-      return singleCallerMethods.contains(method.getReference());
+      return singleCallerMethods.containsKey(method.getReference());
     }
 
     /**
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/MultiCallerInliner.java b/src/main/java/com/android/tools/r8/ir/optimize/MultiCallerInliner.java
index 54c8f61..4b5f0c9 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/MultiCallerInliner.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/MultiCallerInliner.java
@@ -147,12 +147,11 @@
     // We track up to n call sites, where n is the size of multiCallerInliningInstructionLimits.
     if (callers.size() > multiCallerInliningInstructionLimits.length) {
       stopTrackingCallSitesForMethodIfDefinitelyIneligibleForMultiCallerInlining(
-          method, singleTarget, methodProcessor, callers);
+          singleTarget, methodProcessor, callers);
     }
   }
 
   private void stopTrackingCallSitesForMethodIfDefinitelyIneligibleForMultiCallerInlining(
-      ProgramMethod method,
       ProgramMethod singleTarget,
       MethodProcessor methodProcessor,
       ProgramMethodMultiset callers) {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/inliner/DefaultInliningReasonStrategy.java b/src/main/java/com/android/tools/r8/ir/optimize/inliner/DefaultInliningReasonStrategy.java
index 537150e..1d3fb28 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/inliner/DefaultInliningReasonStrategy.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/inliner/DefaultInliningReasonStrategy.java
@@ -53,7 +53,7 @@
       // program.
       return Reason.SIMPLE;
     }
-    if (isSingleCallerInliningTarget(target)) {
+    if (isSingleCallerInliningTarget(context, target)) {
       return Reason.SINGLE_CALLER;
     }
     if (isMultiCallerInlineCandidate(invoke, target, oracle, methodProcessor)) {
@@ -64,8 +64,8 @@
     return Reason.SIMPLE;
   }
 
-  private boolean isSingleCallerInliningTarget(ProgramMethod method) {
-    if (!callSiteInformation.hasSingleCallSite(method)) {
+  private boolean isSingleCallerInliningTarget(ProgramMethod context, ProgramMethod method) {
+    if (!callSiteInformation.hasSingleCallSite(context, method)) {
       return false;
     }
     if (appView.appInfo().isNeverInlineDueToSingleCallerMethod(method)) {
diff --git a/src/test/java/com/android/tools/r8/startup/SingleCallerBridgeStartupTest.java b/src/test/java/com/android/tools/r8/startup/SingleCallerBridgeStartupTest.java
index 4a27e0f..1cb68de 100644
--- a/src/test/java/com/android/tools/r8/startup/SingleCallerBridgeStartupTest.java
+++ b/src/test/java/com/android/tools/r8/startup/SingleCallerBridgeStartupTest.java
@@ -4,6 +4,9 @@
 
 package com.android.tools.r8.startup;
 
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
@@ -12,6 +15,7 @@
 import com.android.tools.r8.startup.profile.ExternalStartupItem;
 import com.android.tools.r8.startup.profile.ExternalStartupMethod;
 import com.android.tools.r8.startup.utils.StartupTestingUtils;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
 import com.google.common.collect.ImmutableList;
 import java.util.Collection;
 import org.junit.Test;
@@ -55,9 +59,16 @@
                     (appView, inlinee, inliningDepth) ->
                         inlinee.getMethodReference().equals(barMethod))
         .setMinApi(parameters)
+        .compile()
+        .inspect(
+            inspector -> {
+              // Assert that foo is not inlined.
+              ClassSubject A = inspector.clazz(A.class);
+              assertThat(A, isPresent());
+              assertThat(A.uniqueMethodWithOriginalName("foo"), isPresent());
+            })
         .run(parameters.getRuntime(), Main.class)
-        // TODO(b/285021603): We should not fail here.
-        .assertFailureWithErrorThatThrows(NoSuchMethodError.class);
+        .assertSuccessWithOutputLines("A::foo", "A::foo");
   }
 
   static class Main {