Fix CallSiteOptimization for pinned method overrides

Bug: 189264383
Change-Id: I78e9106a288aa0549e6bea68be4aebeec392255d
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 8ae52ef..3ef120b 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -722,6 +722,9 @@
     //   2) Second inlining pass for dealing with double inline callers.
     printPhase("Post optimization pass");
     if (appView.callSiteOptimizationInfoPropagator() != null) {
+      appView
+          .callSiteOptimizationInfoPropagator()
+          .abandonCallSitePropagationForPinnedMethodsAndOverrides(executorService);
       postMethodProcessorBuilder.put(appView.callSiteOptimizationInfoPropagator());
     }
     if (inliner != null) {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java b/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java
index 2fdd282..3403e7b 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java
@@ -38,6 +38,7 @@
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.InternalOptions.CallSiteOptimizationOptions;
 import com.android.tools.r8.utils.LazyBox;
+import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.Timing;
 import com.android.tools.r8.utils.collections.ProgramMethodSet;
 import com.google.common.collect.Sets;
@@ -45,6 +46,9 @@
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.function.Consumer;
 
 public class CallSiteOptimizationInfoPropagator implements PostOptimization {
 
@@ -291,6 +295,25 @@
     }
   }
 
+  public void abandonCallSitePropagationForPinnedMethodsAndOverrides(
+      ExecutorService executorService) throws ExecutionException {
+    ThreadUtils.processItems(
+        this::forEachPinnedNonPrivateVirtualMethod,
+        this::abandonCallSitePropagationForMethodAndOverrides,
+        executorService);
+  }
+
+  private void forEachPinnedNonPrivateVirtualMethod(Consumer<ProgramMethod> consumer) {
+    for (DexProgramClass clazz : appView.appInfo().classes()) {
+      for (ProgramMethod virtualProgramMethod : clazz.virtualProgramMethods()) {
+        if (virtualProgramMethod.getDefinition().isNonPrivateVirtualMethod()
+            && appView.getKeepInfo().isPinned(virtualProgramMethod.getReference(), appView)) {
+          consumer.accept(virtualProgramMethod);
+        }
+      }
+    }
+  }
+
   private void abandonCallSitePropagationForMethodAndOverrides(ProgramMethod method) {
     Set<ProgramMethod> abandonSet = Sets.newIdentityHashSet();
     if (method.getDefinition().isNonPrivateVirtualMethod()) {
diff --git a/src/test/java/com/android/tools/r8/desugar/desugaredlibrary/SimpleStreamTest.java b/src/test/java/com/android/tools/r8/desugar/desugaredlibrary/SimpleStreamTest.java
index d454fe7..aec1045 100644
--- a/src/test/java/com/android/tools/r8/desugar/desugaredlibrary/SimpleStreamTest.java
+++ b/src/test/java/com/android/tools/r8/desugar/desugaredlibrary/SimpleStreamTest.java
@@ -10,7 +10,6 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.stream.Collectors;
-import org.junit.Assume;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -37,7 +36,6 @@
 
   @Test
   public void testStreamD8() throws Exception {
-    Assume.assumeFalse("TODO(b/189264383): fix", shrinkDesugaredLibrary);
     KeepRuleConsumer keepRuleConsumer = createKeepRuleConsumer(parameters);
     testForD8()
         .addInnerClasses(SimpleStreamTest.class)
@@ -55,7 +53,6 @@
 
   @Test
   public void testStreamR8() throws Exception {
-    Assume.assumeFalse("TODO(b/189264383): fix", shrinkDesugaredLibrary);
     KeepRuleConsumer keepRuleConsumer = createKeepRuleConsumer(parameters);
     testForR8(Backend.DEX)
         .addInnerClasses(SimpleStreamTest.class)
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationPinnedMethodOverridePropagationTest.java b/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationPinnedMethodOverridePropagationTest.java
new file mode 100644
index 0000000..9e42eb6
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationPinnedMethodOverridePropagationTest.java
@@ -0,0 +1,166 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.callsites;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NeverPropagateValue;
+import com.android.tools.r8.NoHorizontalClassMerging;
+import com.android.tools.r8.NoVerticalClassMerging;
+import com.android.tools.r8.R8TestCompileResult;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.google.common.collect.ImmutableList;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class CallSiteOptimizationPinnedMethodOverridePropagationTest extends TestBase {
+
+  private static final String CLASS_PREFIX =
+      "com.android.tools.r8.ir.optimize.callsites.CallSiteOptimizationPinnedMethodOverridePropagationTest$";
+  private final TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withDexRuntimes().withAllApiLevels().build();
+  }
+
+  public CallSiteOptimizationPinnedMethodOverridePropagationTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void test() throws Exception {
+    R8TestCompileResult compiled =
+        testForR8(parameters.getBackend())
+            .addProgramClasses(
+                Arg.class, Arg1.class, Arg2.class, Call.class, CallImpl.class, Main2.class)
+            .addKeepRules(
+                ImmutableList.of(
+                    "-keep interface " + CLASS_PREFIX + "Arg",
+                    "-keep interface "
+                        + CLASS_PREFIX
+                        + "Call { \npublic void print("
+                        + CLASS_PREFIX
+                        + "Arg); \n}",
+                    "-keep class "
+                        + CLASS_PREFIX
+                        + "Main2 { \npublic static void main(java.lang.String[]); \npublic static "
+                        + CLASS_PREFIX
+                        + "Arg getArg1(); \npublic static "
+                        + CLASS_PREFIX
+                        + "Arg getArg2(); \npublic static "
+                        + CLASS_PREFIX
+                        + "Call getCaller(); \n}"))
+            .enableNoVerticalClassMergingAnnotations()
+            .enableNoHorizontalClassMergingAnnotations()
+            .enableInliningAnnotations()
+            .enableMemberValuePropagationAnnotations()
+            .setMinApi(parameters.getApiLevel())
+            .compile();
+    CodeInspector inspector = compiled.inspector();
+    compiled.run(parameters.getRuntime(), Main2.class).assertSuccessWithOutputLines("Arg1");
+    testForD8()
+        .addProgramClasses(Main.class)
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .addRunClasspathFiles(compiled.writeToZip())
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("Arg1", "Arg2");
+  }
+
+  // Kept
+  @NoVerticalClassMerging
+  interface Arg {
+
+    @NeverInline
+    @NeverPropagateValue
+    String getString();
+  }
+
+  @NoVerticalClassMerging
+  @NoHorizontalClassMerging
+  static class Arg1 implements Arg {
+
+    @Override
+    @NeverInline
+    @NeverPropagateValue
+    public String getString() {
+      return "Arg1";
+    }
+  }
+
+  @NoVerticalClassMerging
+  @NoHorizontalClassMerging
+  static class Arg2 implements Arg {
+
+    @Override
+    @NeverInline
+    @NeverPropagateValue
+    public String getString() {
+      return "Arg2";
+    }
+  }
+
+  @NoVerticalClassMerging
+  interface Call {
+
+    // Kept.
+    @NeverInline
+    @NeverPropagateValue
+    void print(Arg arg);
+  }
+
+  @NoVerticalClassMerging
+  static class CallImpl implements Call {
+
+    @Override
+    @NeverInline
+    @NeverPropagateValue
+    public void print(Arg arg) {
+      System.out.println(arg.getString());
+    }
+  }
+
+  @NoVerticalClassMerging
+  static class Main2 {
+
+    // Kept.
+    public static void main(String[] args) {
+      // This would propagate Arg1 to print while it should not.
+      getCaller().print(new Arg1());
+    }
+
+    // Kept.
+    public static Arg getArg1() {
+      return new Arg1();
+    }
+
+    // Kept.
+    public static Arg getArg2() {
+      return new Arg2();
+    }
+
+    // Kept.
+    public static Call getCaller() {
+      return new CallSiteOptimizationPinnedMethodOverridePropagationTest.CallImpl();
+    }
+  }
+
+  static class Main {
+
+    public static void main(String[] args) {
+      Arg arg1 = Main2.getArg1();
+      Arg arg2 = Main2.getArg2();
+      Call caller = Main2.getCaller();
+      caller.print(arg1);
+      caller.print(arg2);
+    }
+  }
+}