Reenable single caller inlining for virtual methods

Change-Id: I1b1563521223eb5cb6e33a7186ff6790a6b19623
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorCodeScanner.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorCodeScanner.java
index 1f95169..7ee35d8 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorCodeScanner.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorCodeScanner.java
@@ -774,6 +774,10 @@
       return resolvedMethod.getReference();
     }
 
+    if (isMonomorphicVirtualMethod(resolvedMethod)) {
+      return resolvedMethod.getReference();
+    }
+
     if (invoke.isInvokeInterface()) {
       assert !isMonomorphicVirtualMethod(resolvedMethod);
       return getVirtualRootMethod(resolvedMethod);
@@ -781,10 +785,6 @@
 
     assert invoke.isInvokeSuper() || invoke.isInvokeVirtual();
 
-    if (isMonomorphicVirtualMethod(resolvedMethod)) {
-      return resolvedMethod.getReference();
-    }
-
     DexMethod rootMethod = getVirtualRootMethod(resolvedMethod);
     assert rootMethod != null;
     assert !isMonomorphicVirtualMethod(resolvedMethod)
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/VirtualRootMethodsAnalysisBase.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/VirtualRootMethodsAnalysisBase.java
index b5f1374..6130e36 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/VirtualRootMethodsAnalysisBase.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/VirtualRootMethodsAnalysisBase.java
@@ -3,9 +3,8 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.optimize.argumentpropagation.codescanner;
 
-import static com.android.tools.r8.utils.MapUtils.ignoreKey;
-
 import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexClassAndMethod;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
@@ -14,8 +13,13 @@
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.collections.DexMethodSignatureMap;
 import com.android.tools.r8.utils.collections.ProgramMethodSet;
+import com.google.common.collect.Sets;
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.IdentityHashMap;
+import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.function.Consumer;
 
 /**
@@ -27,58 +31,89 @@
   protected static class VirtualRootMethod {
 
     private final VirtualRootMethod parent;
-    private final ProgramMethod root;
-    private final ProgramMethodSet overrides = ProgramMethodSet.create();
+    private final ProgramMethod method;
+
+    private Set<VirtualRootMethod> overrides = Collections.emptySet();
+    private List<VirtualRootMethod> siblings = Collections.emptyList();
+    private boolean mayDispatchOutsideProgram = false;
 
     VirtualRootMethod(ProgramMethod root) {
       this(root, null);
     }
 
-    VirtualRootMethod(ProgramMethod root, VirtualRootMethod parent) {
-      assert root != null;
+    VirtualRootMethod(ProgramMethod method, VirtualRootMethod parent) {
+      assert method != null;
       this.parent = parent;
-      this.root = root;
+      this.method = method;
     }
 
-    void addOverride(ProgramMethod override) {
-      assert override.getDefinition() != root.getDefinition();
-      assert override.getMethodSignature().equals(root.getMethodSignature());
+    void addOverride(VirtualRootMethod override) {
+      assert !override.getMethod().isStructurallyEqualTo(method);
+      assert override.getMethod().getReference().match(method.getReference());
+      if (overrides.isEmpty()) {
+        overrides = Sets.newIdentityHashSet();
+      }
       overrides.add(override);
       if (hasParent()) {
         getParent().addOverride(override);
       }
+      for (VirtualRootMethod sibling : siblings) {
+        sibling.addOverride(override);
+      }
+    }
+
+    void addSibling(VirtualRootMethod sibling) {
+      if (siblings.isEmpty()) {
+        siblings = new ArrayList<>(1);
+      }
+      siblings.add(sibling);
+    }
+
+    void setMayDispatchOutsideProgram() {
+      mayDispatchOutsideProgram = true;
     }
 
     boolean hasParent() {
       return parent != null;
     }
 
+    boolean hasSiblings() {
+      return !siblings.isEmpty();
+    }
+
     VirtualRootMethod getParent() {
       return parent;
     }
 
-    ProgramMethod getRoot() {
-      return root;
+    VirtualRootMethod getRoot() {
+      return hasParent() ? getParent().getRoot() : this;
     }
 
-    ProgramMethod getSingleNonAbstractMethod() {
-      ProgramMethod singleNonAbstractMethod = root.getAccessFlags().isAbstract() ? null : root;
-      for (ProgramMethod override : overrides) {
-        if (!override.getAccessFlags().isAbstract()) {
-          if (singleNonAbstractMethod != null) {
+    ProgramMethod getMethod() {
+      return method;
+    }
+
+    VirtualRootMethod getSingleDispatchTarget() {
+      assert !hasParent();
+      if (isMayDispatchOutsideProgramSet()) {
+        return null;
+      }
+      VirtualRootMethod singleDispatchTarget = isAbstract() ? null : this;
+      for (VirtualRootMethod override : overrides) {
+        if (!override.isAbstract()) {
+          if (singleDispatchTarget != null) {
             // Not a single non-abstract method.
             return null;
           }
-          singleNonAbstractMethod = override;
+          singleDispatchTarget = override;
         }
       }
-      assert singleNonAbstractMethod == null
-          || !singleNonAbstractMethod.getAccessFlags().isAbstract();
-      return singleNonAbstractMethod;
+      assert singleDispatchTarget == null || !singleDispatchTarget.isAbstract();
+      return singleDispatchTarget;
     }
 
-    void forEach(Consumer<ProgramMethod> consumer) {
-      consumer.accept(root);
+    void forEach(Consumer<VirtualRootMethod> consumer) {
+      consumer.accept(this);
       overrides.forEach(consumer);
     }
 
@@ -86,10 +121,12 @@
       return !overrides.isEmpty();
     }
 
-    boolean isInterfaceMethodWithSiblings() {
-      // TODO(b/190154391): Conservatively returns true for all interface methods, but should only
-      //  return true for those with siblings.
-      return root.getHolder().isInterface();
+    boolean isAbstract() {
+      return method.getAccessFlags().isAbstract();
+    }
+
+    boolean isMayDispatchOutsideProgramSet() {
+      return mayDispatchOutsideProgram;
     }
   }
 
@@ -122,18 +159,48 @@
           DexMethodSignatureMap<VirtualRootMethod> virtualRootMethodsForSuperclass =
               virtualRootMethodsPerClass.get(superclass);
           virtualRootMethodsForSuperclass.forEach(
-              (signature, info) ->
-                  virtualRootMethodsForClass.computeIfAbsent(
-                      signature, ignoreKey(() -> new VirtualRootMethod(info.getRoot(), info))));
+              (signature, info) -> {
+                virtualRootMethodsForClass.compute(
+                    signature,
+                    (ignore, existing) -> {
+                      if (existing == null || existing == info) {
+                        return info;
+                      } else {
+                        // We iterate the immediate supertypes in-order using
+                        // forEachImmediateProgramSuperClass. Therefore, the current method is
+                        // guaranteed to be an interface method when existing != null.
+                        assert info.getMethod().getHolder().isInterface();
+                        if (!existing.getMethod().getHolder().isInterface()) {
+                          existing.addSibling(info);
+                          info.addOverride(existing);
+                        }
+                        return existing;
+                      }
+                    });
+                if (!clazz.isInterface() && superclass.isInterface()) {
+                  DexClassAndMethod resolvedMethod =
+                      appView.appInfo().resolveMethodOnClass(clazz, signature).getResolutionPair();
+                  if (resolvedMethod != null
+                      && !resolvedMethod.isProgramMethod()
+                      && !resolvedMethod.getAccessFlags().isAbstract()) {
+                    info.setMayDispatchOutsideProgram();
+                  }
+                }
+              });
         });
     clazz.forEachProgramVirtualMethod(
-        method -> {
-          if (virtualRootMethodsForClass.containsKey(method)) {
-            virtualRootMethodsForClass.get(method).getParent().addOverride(method);
-          } else {
-            virtualRootMethodsForClass.put(method, new VirtualRootMethod(method));
-          }
-        });
+        method ->
+            virtualRootMethodsForClass.compute(
+                method,
+                (ignore, parent) -> {
+                  if (parent == null) {
+                    return new VirtualRootMethod(method);
+                  } else {
+                    VirtualRootMethod override = new VirtualRootMethod(method, parent);
+                    parent.addOverride(override);
+                    return override;
+                  }
+                }));
     return virtualRootMethodsForClass;
   }
 
@@ -147,37 +214,48 @@
           VirtualRootMethod virtualRootMethod =
               virtualRootMethodsForClass.remove(rootCandidate.getMethodSignature());
           acceptVirtualMethod(rootCandidate, virtualRootMethod);
-          if (!rootCandidate.isStructurallyEqualTo(virtualRootMethod.getRoot())) {
+          if (virtualRootMethod.hasParent()
+              || !rootCandidate.isStructurallyEqualTo(virtualRootMethod.getMethod())) {
             return;
           }
-          boolean isMonomorphicVirtualMethod =
-              !clazz.isInterface() && !virtualRootMethod.hasOverrides();
-          if (isMonomorphicVirtualMethod) {
+          if (!virtualRootMethod.hasOverrides()
+              && !virtualRootMethod.hasSiblings()
+              && !virtualRootMethod.isMayDispatchOutsideProgramSet()) {
             monomorphicVirtualRootMethods.add(rootCandidate);
           } else {
-            ProgramMethod singleNonAbstractMethod = virtualRootMethod.getSingleNonAbstractMethod();
-            if (singleNonAbstractMethod != null
-                && !virtualRootMethod.isInterfaceMethodWithSiblings()) {
+            VirtualRootMethod singleDispatchTarget = virtualRootMethod.getSingleDispatchTarget();
+            if (singleDispatchTarget != null) {
               virtualRootMethod.forEach(
                   method -> {
                     // Interface methods can have siblings and can therefore not be mapped to their
                     // unique non-abstract implementation, unless the interface method does not have
                     // any siblings.
-                    virtualRootMethods.put(
-                        method.getReference(), singleNonAbstractMethod.getReference());
+                    setRootMethod(method, virtualRootMethod, singleDispatchTarget);
                   });
-              if (!singleNonAbstractMethod.getHolder().isInterface()) {
-                monomorphicVirtualNonRootMethods.add(singleNonAbstractMethod);
+              if (!singleDispatchTarget.getMethod().getHolder().isInterface()) {
+                monomorphicVirtualNonRootMethods.add(singleDispatchTarget.getMethod());
               }
             } else {
               virtualRootMethod.forEach(
-                  method ->
-                      virtualRootMethods.put(method.getReference(), rootCandidate.getReference()));
+                  method -> setRootMethod(method, virtualRootMethod, virtualRootMethod));
             }
           }
         });
   }
 
+  private void setRootMethod(
+      VirtualRootMethod method, VirtualRootMethod currentRoot, VirtualRootMethod root) {
+    // Since the same method can have multiple roots due to interface methods, we only allow
+    // controlling the virtual root of methods that are rooted at the current root. Otherwise, we
+    // would be setting the virtual root of the same method multiple times, which could lead to
+    // non-determinism in the result (i.e., the `virtualRootMethods` map).
+    if (method.getRoot() == currentRoot) {
+      DexMethod rootReference = root.getMethod().getReference();
+      DexMethod previous = virtualRootMethods.put(method.getMethod().getReference(), rootReference);
+      assert previous == null || previous.isIdenticalTo(rootReference);
+    }
+  }
+
   protected void acceptVirtualMethod(ProgramMethod method, VirtualRootMethod virtualRootMethod) {
     // Intentionally empty.
   }
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InterfaceMethodArgumentPropagator.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InterfaceMethodArgumentPropagator.java
index 675bff8..5fcd8db 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InterfaceMethodArgumentPropagator.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InterfaceMethodArgumentPropagator.java
@@ -108,10 +108,12 @@
             return;
           }
 
-          // TODO(b/190154391): We should always have an unknown or polymorphic state, but it would
-          //  be better to use a monomorphic state when the interface method is a default method
-          //  with no overrides (CF backend only). In this case, there is no need to add methodState
-          //  to interfaceState.
+          // If the method state is monomorphic, then this is an interface method with no overrides.
+          // In this case, there is no need to add methodState to interfaceState.
+          if (methodState.isMonomorphic()) {
+            return;
+          }
+
           assert methodState.isUnknown() || methodState.asConcrete().isPolymorphic();
           interfaceState.addMethodState(appView, method, methodState);
         });
diff --git a/src/main/java/com/android/tools/r8/optimize/singlecaller/SingleCallerInliner.java b/src/main/java/com/android/tools/r8/optimize/singlecaller/SingleCallerInliner.java
index a6e1da7..70b50c9 100644
--- a/src/main/java/com/android/tools/r8/optimize/singlecaller/SingleCallerInliner.java
+++ b/src/main/java/com/android/tools/r8/optimize/singlecaller/SingleCallerInliner.java
@@ -68,8 +68,8 @@
   }
 
   public void run(ExecutorService executorService) throws ExecutionException {
-    // TODO(b/335584013): Re-enable monomorphic method analysis.
-    ProgramMethodSet monomorphicVirtualMethods = ProgramMethodSet.empty();
+    ProgramMethodSet monomorphicVirtualMethods =
+        computeMonomorphicVirtualRootMethods(executorService);
     ProgramMethodMap<ProgramMethod> singleCallerMethods =
         new SingleCallerScanner(appView, monomorphicVirtualMethods)
             .getSingleCallerMethods(executorService);
diff --git a/src/main/java/com/android/tools/r8/utils/collections/DexMethodSignatureMap.java b/src/main/java/com/android/tools/r8/utils/collections/DexMethodSignatureMap.java
index 377be48..e4872c7 100644
--- a/src/main/java/com/android/tools/r8/utils/collections/DexMethodSignatureMap.java
+++ b/src/main/java/com/android/tools/r8/utils/collections/DexMethodSignatureMap.java
@@ -8,6 +8,7 @@
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexMethodSignature;
+import com.android.tools.r8.graph.ProgramMethod;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
@@ -143,6 +144,12 @@
     return backing.compute(key, remappingFunction);
   }
 
+  public T compute(
+      ProgramMethod method,
+      BiFunction<? super DexMethodSignature, ? super T, ? extends T> remappingFunction) {
+    return compute(method.getMethodSignature(), remappingFunction);
+  }
+
   @Override
   public T merge(
       DexMethodSignature key,
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/inliner/IllegalSingleCallerInliningOfNonAbstractSiblingTest.java b/src/test/java/com/android/tools/r8/ir/optimize/inliner/IllegalSingleCallerInliningOfNonAbstractSiblingTest.java
new file mode 100644
index 0000000..117b090
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/inliner/IllegalSingleCallerInliningOfNonAbstractSiblingTest.java
@@ -0,0 +1,78 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.optimize.inliner;
+
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.NoHorizontalClassMerging;
+import com.android.tools.r8.NoVerticalClassMerging;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.ir.optimize.Inliner.Reason;
+import com.google.common.collect.ImmutableSet;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class IllegalSingleCallerInliningOfNonAbstractSiblingTest extends TestBase {
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .addOptionsModification(
+            options -> options.testing.validInliningReasons = ImmutableSet.of(Reason.SINGLE_CALLER))
+        .enableNoHorizontalClassMergingAnnotations()
+        .enableNoVerticalClassMergingAnnotations()
+        .enableNeverClassInliningAnnotations()
+        .setMinApi(parameters)
+        .compile()
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("A", "A");
+  }
+
+  static class Main {
+
+    public static void main(String[] args) {
+      I i = System.currentTimeMillis() > 0 ? new B() : new C();
+      i.m();
+
+      new A().m();
+    }
+  }
+
+  interface I {
+
+    default void m() {
+      System.out.println("C");
+    }
+  }
+
+  @NeverClassInline
+  @NoHorizontalClassMerging
+  @NoVerticalClassMerging
+  static class A {
+
+    public void m() {
+      System.out.println("A");
+    }
+  }
+
+  static class B extends A implements I {}
+
+  @NoHorizontalClassMerging
+  static class C implements I {}
+}