Propagate unknown arguments for methods called by lambdas

Bug: 186729231
Change-Id: Iff586818912f1bc02b7abfb0fbe58ccc19645440
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java b/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java
index 5eb23bb..75735fa 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java
@@ -94,6 +94,11 @@
     if (mode != Mode.COLLECT) {
       return;
     }
+
+    if (appView.appInfo().isMethodTargetedByInvokeDynamic(code.context().getReference())) {
+      abandonCallSitePropagationForMethodAndOverrides(code.context());
+    }
+
     ProgramMethod context = code.context();
     for (Instruction instruction : code.instructions()) {
       if (instruction.isInvokeMethod()) {
@@ -286,6 +291,31 @@
     }
   }
 
+  private void abandonCallSitePropagationForMethodAndOverrides(ProgramMethod method) {
+    Set<ProgramMethod> abandonSet = Sets.newIdentityHashSet();
+    if (method.getDefinition().isNonPrivateVirtualMethod()) {
+      SingleResolutionResult resolutionResult =
+          new SingleResolutionResult(
+              method.getHolder(), method.getHolder(), method.getDefinition());
+      resolutionResult
+          .lookupVirtualDispatchTargets(method.getHolder(), appView.appInfo())
+          .forEach(
+              methodTarget -> {
+                if (methodTarget.isProgramMethod()) {
+                  abandonSet.add(methodTarget.asProgramMethod());
+                }
+              },
+              lambdaTarget -> {
+                if (lambdaTarget.getImplementationMethod().isProgramMethod()) {
+                  abandonSet.add(lambdaTarget.getImplementationMethod().asProgramMethod());
+                }
+              });
+    } else {
+      abandonSet.add(method);
+    }
+    abandonCallSitePropagation(abandonSet::forEach);
+  }
+
   private CallSiteOptimizationInfo computeCallSiteOptimizationInfoFromArguments(
       InvokeMethod invoke, ProgramMethod context, Timing timing) {
     timing.begin("Compute argument info");
diff --git a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
index f78dacd..2785412 100644
--- a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
+++ b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
@@ -615,6 +615,10 @@
     return clazz == null || !clazz.isProgramClass();
   }
 
+  public boolean isMethodTargetedByInvokeDynamic(DexMethod method) {
+    return methodsTargetedByInvokeDynamic.contains(method);
+  }
+
   public void forEachReachableInterface(Consumer<DexClass> consumer) {
     forEachReachableInterface(consumer, ImmutableList.of());
   }
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationLambdaPropagationTest.java b/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationLibraryLambdaPropagationTest.java
similarity index 70%
copy from src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationLambdaPropagationTest.java
copy to src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationLibraryLambdaPropagationTest.java
index 858d4cc..774267d 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationLambdaPropagationTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationLibraryLambdaPropagationTest.java
@@ -9,57 +9,57 @@
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.AndroidApiLevel;
+import java.util.function.Consumer;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 import org.junit.runners.Parameterized.Parameters;
 
 @RunWith(Parameterized.class)
-public class CallSiteOptimizationLambdaPropagationTest extends TestBase {
+public class CallSiteOptimizationLibraryLambdaPropagationTest extends TestBase {
 
   private final TestParameters parameters;
 
   @Parameters(name = "{0}")
   public static TestParametersCollection data() {
-    return getTestParameters().withAllRuntimesAndApiLevels().build();
+    return getTestParameters()
+        .withCfRuntimes()
+        .withDexRuntimes()
+        .withApiLevelsStartingAtIncluding(AndroidApiLevel.N)
+        .build();
   }
 
-  public CallSiteOptimizationLambdaPropagationTest(TestParameters parameters) {
+  public CallSiteOptimizationLibraryLambdaPropagationTest(TestParameters parameters) {
     this.parameters = parameters;
   }
 
   @Test
   public void test() throws Exception {
     testForR8(parameters.getBackend())
-        .addInnerClasses(CallSiteOptimizationLambdaPropagationTest.class)
+        .addInnerClasses(CallSiteOptimizationLibraryLambdaPropagationTest.class)
         .addKeepMainRule(TestClass.class)
         .enableInliningAnnotations()
         .enableNeverClassInliningAnnotations()
         .setMinApi(parameters.getApiLevel())
         .run(parameters.getRuntime(), TestClass.class)
-        // TODO(b/186729231): Should succeed with "A", "B".
-        .assertSuccessWithOutputLines("A", parameters.isCfRuntime() ? "A" : "B");
+        .assertSuccessWithOutputLines("A", "B");
   }
 
   static class TestClass {
 
     public static void main(String[] args) {
       add(new A());
-      Consumer consumer = TestClass::add;
+      Consumer<Object> consumer = TestClass::add;
       consumer.accept("B");
     }
 
-    // TODO(b/186729231): Incorrectly fail to propagate "B" as an argument to add().
     @NeverInline
     static void add(Object o) {
       System.out.println(o.toString());
     }
   }
 
-  interface Consumer {
-    void accept(Object o);
-  }
-
   @NeverClassInline
   static class A {
 
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationLambdaPropagationTest.java b/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationProgramLambdaPropagationTest.java
similarity index 78%
rename from src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationLambdaPropagationTest.java
rename to src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationProgramLambdaPropagationTest.java
index 858d4cc..86145c6 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationLambdaPropagationTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationProgramLambdaPropagationTest.java
@@ -15,7 +15,7 @@
 import org.junit.runners.Parameterized.Parameters;
 
 @RunWith(Parameterized.class)
-public class CallSiteOptimizationLambdaPropagationTest extends TestBase {
+public class CallSiteOptimizationProgramLambdaPropagationTest extends TestBase {
 
   private final TestParameters parameters;
 
@@ -24,21 +24,20 @@
     return getTestParameters().withAllRuntimesAndApiLevels().build();
   }
 
-  public CallSiteOptimizationLambdaPropagationTest(TestParameters parameters) {
+  public CallSiteOptimizationProgramLambdaPropagationTest(TestParameters parameters) {
     this.parameters = parameters;
   }
 
   @Test
   public void test() throws Exception {
     testForR8(parameters.getBackend())
-        .addInnerClasses(CallSiteOptimizationLambdaPropagationTest.class)
+        .addInnerClasses(CallSiteOptimizationProgramLambdaPropagationTest.class)
         .addKeepMainRule(TestClass.class)
         .enableInliningAnnotations()
         .enableNeverClassInliningAnnotations()
         .setMinApi(parameters.getApiLevel())
         .run(parameters.getRuntime(), TestClass.class)
-        // TODO(b/186729231): Should succeed with "A", "B".
-        .assertSuccessWithOutputLines("A", parameters.isCfRuntime() ? "A" : "B");
+        .assertSuccessWithOutputLines("A", "B");
   }
 
   static class TestClass {
@@ -49,7 +48,6 @@
       consumer.accept("B");
     }
 
-    // TODO(b/186729231): Incorrectly fail to propagate "B" as an argument to add().
     @NeverInline
     static void add(Object o) {
       System.out.println(o.toString());