Don't remove arguments for interface invokes dispatching to library

Bug: 202074964
Change-Id: Idbc706002d69f3330ee259fea83ef3b89ba9b5ff
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagator.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagator.java
index 645bac0..080e691 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagator.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagator.java
@@ -5,6 +5,7 @@
 package com.android.tools.r8.optimize.argumentpropagation;
 
 import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexMethodSignature;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
 import com.android.tools.r8.graph.ProgramMethod;
@@ -20,11 +21,15 @@
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.Timing;
+import com.android.tools.r8.utils.collections.DexMethodSignatureSet;
 import com.google.common.collect.Sets;
+import java.util.IdentityHashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
+import java.util.function.BiConsumer;
 
 /** Optimization that propagates information about arguments from call sites to method entries. */
 public class ArgumentPropagator {
@@ -144,14 +149,26 @@
     timing.end();
 
     // Set the optimization info on each method.
+    Map<Set<DexProgramClass>, DexMethodSignatureSet> interfaceDispatchOutsideProgram =
+        new IdentityHashMap<>();
     populateParameterOptimizationInfo(
-        immediateSubtypingInfo, stronglyConnectedProgramComponents, executorService, timing);
+        immediateSubtypingInfo,
+        stronglyConnectedProgramComponents,
+        (stronglyConnectedProgramComponent, signature) -> {
+          interfaceDispatchOutsideProgram
+              .computeIfAbsent(
+                  stronglyConnectedProgramComponent, (unused) -> DexMethodSignatureSet.create())
+              .add(signature);
+        },
+        executorService,
+        timing);
 
     // Using the computed optimization info, build a graph lens that describes the mapping from
     // methods with constant parameters to methods with the constant parameters removed.
     Set<DexProgramClass> affectedClasses = Sets.newConcurrentHashSet();
     ArgumentPropagatorGraphLens graphLens =
-        new ArgumentPropagatorProgramOptimizer(appView, immediateSubtypingInfo)
+        new ArgumentPropagatorProgramOptimizer(
+                appView, immediateSubtypingInfo, interfaceDispatchOutsideProgram)
             .run(stronglyConnectedProgramComponents, affectedClasses::add, executorService, timing);
 
     // Find all the code objects that need reprocessing.
@@ -174,6 +191,7 @@
   private void populateParameterOptimizationInfo(
       ImmediateProgramSubtypingInfo immediateSubtypingInfo,
       List<Set<DexProgramClass>> stronglyConnectedProgramComponents,
+      BiConsumer<Set<DexProgramClass>, DexMethodSignature> interfaceDispatchOutsideProgram,
       ExecutorService executorService,
       Timing timing)
       throws ExecutionException {
@@ -189,7 +207,8 @@
             immediateSubtypingInfo,
             codeScannerResult,
             reprocessingCriteriaCollection,
-            stronglyConnectedProgramComponents)
+            stronglyConnectedProgramComponents,
+            interfaceDispatchOutsideProgram)
         .populateOptimizationInfo(executorService, timing);
     reprocessingCriteriaCollection = null;
     timing.end();
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorOptimizationInfoPopulator.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorOptimizationInfoPopulator.java
index 7118043..5470ed5 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorOptimizationInfoPopulator.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorOptimizationInfoPopulator.java
@@ -7,6 +7,7 @@
 
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexMethodSignature;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
@@ -42,6 +43,7 @@
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
+import java.util.function.BiConsumer;
 import java.util.stream.IntStream;
 
 /**
@@ -58,17 +60,22 @@
   private final ImmediateProgramSubtypingInfo immediateSubtypingInfo;
   private final List<Set<DexProgramClass>> stronglyConnectedProgramComponents;
 
+  private final BiConsumer<Set<DexProgramClass>, DexMethodSignature>
+      interfaceDispatchOutsideProgram;
+
   ArgumentPropagatorOptimizationInfoPopulator(
       AppView<AppInfoWithLiveness> appView,
       ImmediateProgramSubtypingInfo immediateSubtypingInfo,
       MethodStateCollectionByReference methodStates,
       ArgumentPropagatorReprocessingCriteriaCollection reprocessingCriteriaCollection,
-      List<Set<DexProgramClass>> stronglyConnectedProgramComponents) {
+      List<Set<DexProgramClass>> stronglyConnectedProgramComponents,
+      BiConsumer<Set<DexProgramClass>, DexMethodSignature> interfaceDispatchOutsideProgram) {
     this.appView = appView;
     this.immediateSubtypingInfo = immediateSubtypingInfo;
     this.methodStates = methodStates;
     this.reprocessingCriteriaCollection = reprocessingCriteriaCollection;
     this.stronglyConnectedProgramComponents = stronglyConnectedProgramComponents;
+    this.interfaceDispatchOutsideProgram = interfaceDispatchOutsideProgram;
   }
 
   /**
@@ -114,7 +121,12 @@
     //
     // To handle this we first propagate any argument information stored for I.m() to A.m() by doing
     // a top-down traversal over the interfaces in the strongly connected component.
-    new InterfaceMethodArgumentPropagator(appView, immediateSubtypingInfo, methodStates)
+    new InterfaceMethodArgumentPropagator(
+            appView,
+            immediateSubtypingInfo,
+            methodStates,
+            signature ->
+                interfaceDispatchOutsideProgram.accept(stronglyConnectedComponent, signature))
         .run(stronglyConnectedComponent);
 
     // Now all the argument information for a given method is guaranteed to be stored on a supertype
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorProgramOptimizer.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorProgramOptimizer.java
index 57611a3..51120e1 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorProgramOptimizer.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorProgramOptimizer.java
@@ -100,14 +100,18 @@
 
   private final AppView<AppInfoWithLiveness> appView;
   private final ImmediateProgramSubtypingInfo immediateSubtypingInfo;
+  private final Map<Set<DexProgramClass>, DexMethodSignatureSet> interfaceDispatchOutsideProgram;
 
   private final Map<DexClass, DexMethodSignatureSet> libraryVirtualMethods =
       new ConcurrentHashMap<>();
 
   public ArgumentPropagatorProgramOptimizer(
-      AppView<AppInfoWithLiveness> appView, ImmediateProgramSubtypingInfo immediateSubtypingInfo) {
+      AppView<AppInfoWithLiveness> appView,
+      ImmediateProgramSubtypingInfo immediateSubtypingInfo,
+      Map<Set<DexProgramClass>, DexMethodSignatureSet> interfaceDispatchOutsideProgram) {
     this.appView = appView;
     this.immediateSubtypingInfo = immediateSubtypingInfo;
+    this.interfaceDispatchOutsideProgram = interfaceDispatchOutsideProgram;
   }
 
   public ArgumentPropagatorGraphLens run(
@@ -121,7 +125,12 @@
         ThreadUtils.processItemsWithResults(
             stronglyConnectedProgramComponents,
             classes ->
-                new StronglyConnectedComponentOptimizer().optimize(classes, affectedClassConsumer),
+                new StronglyConnectedComponentOptimizer()
+                    .optimize(
+                        classes,
+                        interfaceDispatchOutsideProgram.getOrDefault(
+                            classes, DexMethodSignatureSet.empty()),
+                        affectedClassConsumer),
             executorService);
     timing.end();
 
@@ -204,8 +213,9 @@
     //  similarly to the way we deal with call chains in argument propagation. If a field is only
     //  assigned the parameter of a given method, we would add the flow constraint "parameter p ->
     //  field f".
-    public ArgumentPropagatorGraphLens.Builder optimize(
+    private ArgumentPropagatorGraphLens.Builder optimize(
         Set<DexProgramClass> stronglyConnectedProgramClasses,
+        DexMethodSignatureSet interfaceDispatchOutsideProgram,
         Consumer<DexProgramClass> affectedClassConsumer) {
       // First reserve pinned method signatures.
       reservePinnedMethodSignatures(stronglyConnectedProgramClasses);
@@ -213,7 +223,8 @@
       // To ensure that we preserve the overriding relationships between methods, we only remove a
       // constant or unused parameter from a virtual method when it can be removed from all other
       // virtual methods in the component with the same method signature.
-      computePrototypeChangesForVirtualMethods(stronglyConnectedProgramClasses);
+      computePrototypeChangesForVirtualMethods(
+          stronglyConnectedProgramClasses, interfaceDispatchOutsideProgram);
 
       // Build a graph lens while visiting the classes in the component.
       // TODO(b/190154391): Consider visiting the interfaces first, and then processing the
@@ -225,7 +236,7 @@
       stronglyConnectedProgramClassesWithDeterministicOrder.sort(
           Comparator.comparing(DexClass::getType));
       for (DexProgramClass clazz : stronglyConnectedProgramClassesWithDeterministicOrder) {
-        if (visitClass(clazz, partialGraphLensBuilder)) {
+        if (visitClass(clazz, interfaceDispatchOutsideProgram, partialGraphLensBuilder)) {
           affectedClassConsumer.accept(clazz);
         }
       }
@@ -276,7 +287,8 @@
     }
 
     private void computePrototypeChangesForVirtualMethods(
-        Set<DexProgramClass> stronglyConnectedProgramClasses) {
+        Set<DexProgramClass> stronglyConnectedProgramClasses,
+        DexMethodSignatureSet interfaceDispatchOutsideProgram) {
       // Group the virtual methods in the component by their signatures.
       Map<DexMethodSignature, ProgramMethodSet> virtualMethodsBySignature =
           computeVirtualMethodsBySignature(stronglyConnectedProgramClasses);
@@ -284,7 +296,9 @@
           (signature, methods) -> {
             // Check that there are no keep rules that prohibit prototype changes from any of the
             // methods.
-            if (Iterables.any(methods, method -> !isPrototypeChangesAllowed(method))) {
+            if (Iterables.any(
+                methods,
+                method -> !isPrototypeChangesAllowed(method, interfaceDispatchOutsideProgram))) {
               return;
             }
 
@@ -340,11 +354,13 @@
       return virtualMethodsBySignature;
     }
 
-    private boolean isPrototypeChangesAllowed(ProgramMethod method) {
+    private boolean isPrototypeChangesAllowed(
+        ProgramMethod method, DexMethodSignatureSet interfaceDispatchOutsideProgram) {
       return appView.getKeepInfo(method).isParameterRemovalAllowed(options)
           && !method.getDefinition().isLibraryMethodOverride().isPossiblyTrue()
           && !appView.appInfo().isBootstrapMethod(method)
-          && !appView.appInfo().isMethodTargetedByInvokeDynamic(method);
+          && !appView.appInfo().isMethodTargetedByInvokeDynamic(method)
+          && !interfaceDispatchOutsideProgram.contains(method);
     }
 
     private SingleValue getReturnValueForVirtualMethods(
@@ -416,7 +432,9 @@
 
     // Returns true if the class was changed as a result of argument propagation.
     private boolean visitClass(
-        DexProgramClass clazz, ArgumentPropagatorGraphLens.Builder partialGraphLensBuilder) {
+        DexProgramClass clazz,
+        DexMethodSignatureSet interfaceDispatchOutsideProgram,
+        ArgumentPropagatorGraphLens.Builder partialGraphLensBuilder) {
       BooleanBox affected = new BooleanBox();
       DexMethodSignatureSet instanceInitializerSignatures = DexMethodSignatureSet.create();
       clazz.forEachProgramInstanceInitializer(instanceInitializerSignatures::add);
@@ -424,7 +442,8 @@
           method -> {
             RewrittenPrototypeDescription prototypeChanges =
                 method.getDefinition().belongsToDirectPool()
-                    ? computePrototypeChangesForDirectMethod(method, instanceInitializerSignatures)
+                    ? computePrototypeChangesForDirectMethod(
+                        method, interfaceDispatchOutsideProgram, instanceInitializerSignatures)
                     : computePrototypeChangesForVirtualMethod(method);
             DexMethod newMethodSignature = getNewMethodSignature(method, prototypeChanges);
             if (newMethodSignature != method.getReference()) {
@@ -518,9 +537,11 @@
     }
 
     private RewrittenPrototypeDescription computePrototypeChangesForDirectMethod(
-        ProgramMethod method, DexMethodSignatureSet instanceInitializerSignatures) {
+        ProgramMethod method,
+        DexMethodSignatureSet interfaceDispatchOutsideProgram,
+        DexMethodSignatureSet instanceInitializerSignatures) {
       assert method.getDefinition().belongsToDirectPool();
-      if (!isPrototypeChangesAllowed(method)) {
+      if (!isPrototypeChangesAllowed(method, interfaceDispatchOutsideProgram)) {
         return RewrittenPrototypeDescription.none();
       }
       // TODO(b/199864962): Allow parameter removal from check-not-null classified methods.
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InterfaceMethodArgumentPropagator.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InterfaceMethodArgumentPropagator.java
index e1e5f63..0b617b1 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InterfaceMethodArgumentPropagator.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InterfaceMethodArgumentPropagator.java
@@ -6,6 +6,7 @@
 
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexMethodSignature;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
 import com.android.tools.r8.graph.MethodResolutionResult;
@@ -43,12 +44,15 @@
   // methods) on the seen but not finished interfaces.
   final Map<DexProgramClass, MethodStateCollectionBySignature> methodStatesToPropagate =
       new IdentityHashMap<>();
+  final Consumer<DexMethodSignature> interfaceDispatchOutsideProgram;
 
   public InterfaceMethodArgumentPropagator(
       AppView<AppInfoWithLiveness> appView,
       ImmediateProgramSubtypingInfo immediateSubtypingInfo,
-      MethodStateCollectionByReference methodStates) {
+      MethodStateCollectionByReference methodStates,
+      Consumer<DexMethodSignature> interfaceDispatchOutsideProgram) {
     super(appView, immediateSubtypingInfo, methodStates);
+    this.interfaceDispatchOutsideProgram = interfaceDispatchOutsideProgram;
   }
 
   @Override
@@ -135,6 +139,12 @@
                     return;
                   }
 
+                  assert resolutionResult.isSingleResolution();
+                  if (!resolutionResult.getResolutionPair().isProgramMethod()) {
+                    interfaceDispatchOutsideProgram.accept(interfaceMethod);
+                    return;
+                  }
+
                   ProgramMethod resolvedMethod = resolutionResult.getResolvedProgramMethod();
                   if (resolvedMethod == null
                       || resolvedMethod.getHolder().isInterface()
diff --git a/src/test/java/com/android/tools/r8/optimize/argumentpropagation/UpwardsInterfacePropagationToLibraryOrClasspathMethodTest.java b/src/test/java/com/android/tools/r8/optimize/argumentpropagation/UpwardsInterfacePropagationToLibraryOrClasspathMethodTest.java
index f6c0782..aa57ae6 100644
--- a/src/test/java/com/android/tools/r8/optimize/argumentpropagation/UpwardsInterfacePropagationToLibraryOrClasspathMethodTest.java
+++ b/src/test/java/com/android/tools/r8/optimize/argumentpropagation/UpwardsInterfacePropagationToLibraryOrClasspathMethodTest.java
@@ -4,8 +4,13 @@
 
 package com.android.tools.r8.optimize.argumentpropagation;
 
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+
 import com.android.tools.r8.NeverClassInline;
 import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NoHorizontalClassMerging;
 import com.android.tools.r8.NoVerticalClassMerging;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
@@ -48,10 +53,17 @@
   }
 
   private static final String EXPECTED_OUTPUT =
-      StringUtils.lines("LibraryClass::libraryMethod(false)");
+      StringUtils.lines("LibraryClass::libraryMethod(false)", "ProgramClass2::libraryMethod(true)");
   private static final List<Class<?>> LIBRARY_CLASSES = ImmutableList.of(LibraryClass.class);
   private static final List<Class<?>> PROGRAM_CLASSES =
-      ImmutableList.of(ProgramClass.class, Delegate.class, Delegater.class, TestClass.class);
+      ImmutableList.of(
+          ProgramClass.class,
+          Delegate.class,
+          Delegater.class,
+          AnotherProgramClass.class,
+          AnotherDelegate.class,
+          AnotherDelegator.class,
+          TestClass.class);
 
   @Test
   public void testRuntime() throws Exception {
@@ -78,6 +90,7 @@
         .addKeepMainRule(TestClass.class)
         .setMinApi(parameters.getApiLevel())
         .enableNoVerticalClassMergingAnnotations()
+        .enableNoHorizontalClassMergingAnnotations()
         .enableNeverClassInliningAnnotations()
         .enableInliningAnnotations()
         .addHorizontallyMergedClassesInspector(
@@ -85,8 +98,18 @@
         .compile()
         .addRunClasspathClasses(LibraryClass.class)
         .run(parameters.getRuntime(), TestClass.class)
-        // TODO(b/202074964): This should not fail.
-        .assertFailureWithErrorThatThrows(AbstractMethodError.class);
+        .inspect(
+            inspector -> {
+              assertThat(
+                  inspector.clazz(Delegate.class).method("void", "libraryMethod", "boolean"),
+                  isPresent());
+              // Check that boolean argument to libraryMethod was removed for AnotherProgramClass.
+              inspector
+                  .clazz(AnotherProgramClass.class)
+                  .forAllMethods(
+                      method -> assertEquals(method.getFinalSignature().toDescriptor(), "()V"));
+            })
+        .assertSuccessWithOutput(EXPECTED_OUTPUT);
   }
 
   public static class LibraryClass {
@@ -115,6 +138,7 @@
   }
 
   @NoVerticalClassMerging
+  @NoHorizontalClassMerging
   @NeverClassInline
   public static class Delegater {
     Delegate delegate;
@@ -128,10 +152,51 @@
     }
   }
 
+  @NoVerticalClassMerging
+  @NoHorizontalClassMerging
+  public interface AnotherDelegate {
+    void libraryMethod(boolean visible);
+  }
+
+  @NeverClassInline
+  @NoHorizontalClassMerging
+  public static class AnotherProgramClass implements AnotherDelegate {
+    AnotherDelegator delegater;
+
+    public AnotherProgramClass() {
+      delegater = new AnotherDelegator(this);
+    }
+
+    @NeverInline
+    public void libraryMethod(boolean visible) {
+      System.out.println("ProgramClass2::libraryMethod(" + visible + ")");
+    }
+
+    @NeverInline
+    public void m() {
+      delegater.m();
+    }
+  }
+
+  @NoVerticalClassMerging
+  @NeverClassInline
+  public static class AnotherDelegator {
+    AnotherDelegate delegate;
+
+    AnotherDelegator(AnotherDelegate delegate) {
+      this.delegate = delegate;
+    }
+
+    public void m() {
+      delegate.libraryMethod(true);
+    }
+  }
+
   public static class TestClass {
 
     public static void main(String[] args) {
       new ProgramClass().m();
+      new AnotherProgramClass().m();
     }
   }
 }