Fix CallSiteOptimization for pinned method overrides
Bug: 189264383
Change-Id: I78e9106a288aa0549e6bea68be4aebeec392255d
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 8ae52ef..3ef120b 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -722,6 +722,9 @@
// 2) Second inlining pass for dealing with double inline callers.
printPhase("Post optimization pass");
if (appView.callSiteOptimizationInfoPropagator() != null) {
+ appView
+ .callSiteOptimizationInfoPropagator()
+ .abandonCallSitePropagationForPinnedMethodsAndOverrides(executorService);
postMethodProcessorBuilder.put(appView.callSiteOptimizationInfoPropagator());
}
if (inliner != null) {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java b/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java
index 2fdd282..3403e7b 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java
@@ -38,6 +38,7 @@
import com.android.tools.r8.utils.InternalOptions;
import com.android.tools.r8.utils.InternalOptions.CallSiteOptimizationOptions;
import com.android.tools.r8.utils.LazyBox;
+import com.android.tools.r8.utils.ThreadUtils;
import com.android.tools.r8.utils.Timing;
import com.android.tools.r8.utils.collections.ProgramMethodSet;
import com.google.common.collect.Sets;
@@ -45,6 +46,9 @@
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.function.Consumer;
public class CallSiteOptimizationInfoPropagator implements PostOptimization {
@@ -291,6 +295,25 @@
}
}
+ public void abandonCallSitePropagationForPinnedMethodsAndOverrides(
+ ExecutorService executorService) throws ExecutionException {
+ ThreadUtils.processItems(
+ this::forEachPinnedNonPrivateVirtualMethod,
+ this::abandonCallSitePropagationForMethodAndOverrides,
+ executorService);
+ }
+
+ private void forEachPinnedNonPrivateVirtualMethod(Consumer<ProgramMethod> consumer) {
+ for (DexProgramClass clazz : appView.appInfo().classes()) {
+ for (ProgramMethod virtualProgramMethod : clazz.virtualProgramMethods()) {
+ if (virtualProgramMethod.getDefinition().isNonPrivateVirtualMethod()
+ && appView.getKeepInfo().isPinned(virtualProgramMethod.getReference(), appView)) {
+ consumer.accept(virtualProgramMethod);
+ }
+ }
+ }
+ }
+
private void abandonCallSitePropagationForMethodAndOverrides(ProgramMethod method) {
Set<ProgramMethod> abandonSet = Sets.newIdentityHashSet();
if (method.getDefinition().isNonPrivateVirtualMethod()) {
diff --git a/src/test/java/com/android/tools/r8/desugar/desugaredlibrary/SimpleStreamTest.java b/src/test/java/com/android/tools/r8/desugar/desugaredlibrary/SimpleStreamTest.java
index d454fe7..aec1045 100644
--- a/src/test/java/com/android/tools/r8/desugar/desugaredlibrary/SimpleStreamTest.java
+++ b/src/test/java/com/android/tools/r8/desugar/desugaredlibrary/SimpleStreamTest.java
@@ -10,7 +10,6 @@
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
-import org.junit.Assume;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -37,7 +36,6 @@
@Test
public void testStreamD8() throws Exception {
- Assume.assumeFalse("TODO(b/189264383): fix", shrinkDesugaredLibrary);
KeepRuleConsumer keepRuleConsumer = createKeepRuleConsumer(parameters);
testForD8()
.addInnerClasses(SimpleStreamTest.class)
@@ -55,7 +53,6 @@
@Test
public void testStreamR8() throws Exception {
- Assume.assumeFalse("TODO(b/189264383): fix", shrinkDesugaredLibrary);
KeepRuleConsumer keepRuleConsumer = createKeepRuleConsumer(parameters);
testForR8(Backend.DEX)
.addInnerClasses(SimpleStreamTest.class)
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationPinnedMethodOverridePropagationTest.java b/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationPinnedMethodOverridePropagationTest.java
new file mode 100644
index 0000000..9e42eb6
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationPinnedMethodOverridePropagationTest.java
@@ -0,0 +1,166 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.callsites;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NeverPropagateValue;
+import com.android.tools.r8.NoHorizontalClassMerging;
+import com.android.tools.r8.NoVerticalClassMerging;
+import com.android.tools.r8.R8TestCompileResult;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.google.common.collect.ImmutableList;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class CallSiteOptimizationPinnedMethodOverridePropagationTest extends TestBase {
+
+ private static final String CLASS_PREFIX =
+ "com.android.tools.r8.ir.optimize.callsites.CallSiteOptimizationPinnedMethodOverridePropagationTest$";
+ private final TestParameters parameters;
+
+ @Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withDexRuntimes().withAllApiLevels().build();
+ }
+
+ public CallSiteOptimizationPinnedMethodOverridePropagationTest(TestParameters parameters) {
+ this.parameters = parameters;
+ }
+
+ @Test
+ public void test() throws Exception {
+ R8TestCompileResult compiled =
+ testForR8(parameters.getBackend())
+ .addProgramClasses(
+ Arg.class, Arg1.class, Arg2.class, Call.class, CallImpl.class, Main2.class)
+ .addKeepRules(
+ ImmutableList.of(
+ "-keep interface " + CLASS_PREFIX + "Arg",
+ "-keep interface "
+ + CLASS_PREFIX
+ + "Call { \npublic void print("
+ + CLASS_PREFIX
+ + "Arg); \n}",
+ "-keep class "
+ + CLASS_PREFIX
+ + "Main2 { \npublic static void main(java.lang.String[]); \npublic static "
+ + CLASS_PREFIX
+ + "Arg getArg1(); \npublic static "
+ + CLASS_PREFIX
+ + "Arg getArg2(); \npublic static "
+ + CLASS_PREFIX
+ + "Call getCaller(); \n}"))
+ .enableNoVerticalClassMergingAnnotations()
+ .enableNoHorizontalClassMergingAnnotations()
+ .enableInliningAnnotations()
+ .enableMemberValuePropagationAnnotations()
+ .setMinApi(parameters.getApiLevel())
+ .compile();
+ CodeInspector inspector = compiled.inspector();
+ compiled.run(parameters.getRuntime(), Main2.class).assertSuccessWithOutputLines("Arg1");
+ testForD8()
+ .addProgramClasses(Main.class)
+ .setMinApi(parameters.getApiLevel())
+ .compile()
+ .addRunClasspathFiles(compiled.writeToZip())
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutputLines("Arg1", "Arg2");
+ }
+
+ // Kept
+ @NoVerticalClassMerging
+ interface Arg {
+
+ @NeverInline
+ @NeverPropagateValue
+ String getString();
+ }
+
+ @NoVerticalClassMerging
+ @NoHorizontalClassMerging
+ static class Arg1 implements Arg {
+
+ @Override
+ @NeverInline
+ @NeverPropagateValue
+ public String getString() {
+ return "Arg1";
+ }
+ }
+
+ @NoVerticalClassMerging
+ @NoHorizontalClassMerging
+ static class Arg2 implements Arg {
+
+ @Override
+ @NeverInline
+ @NeverPropagateValue
+ public String getString() {
+ return "Arg2";
+ }
+ }
+
+ @NoVerticalClassMerging
+ interface Call {
+
+ // Kept.
+ @NeverInline
+ @NeverPropagateValue
+ void print(Arg arg);
+ }
+
+ @NoVerticalClassMerging
+ static class CallImpl implements Call {
+
+ @Override
+ @NeverInline
+ @NeverPropagateValue
+ public void print(Arg arg) {
+ System.out.println(arg.getString());
+ }
+ }
+
+ @NoVerticalClassMerging
+ static class Main2 {
+
+ // Kept.
+ public static void main(String[] args) {
+ // This would propagate Arg1 to print while it should not.
+ getCaller().print(new Arg1());
+ }
+
+ // Kept.
+ public static Arg getArg1() {
+ return new Arg1();
+ }
+
+ // Kept.
+ public static Arg getArg2() {
+ return new Arg2();
+ }
+
+ // Kept.
+ public static Call getCaller() {
+ return new CallSiteOptimizationPinnedMethodOverridePropagationTest.CallImpl();
+ }
+ }
+
+ static class Main {
+
+ public static void main(String[] args) {
+ Arg arg1 = Main2.getArg1();
+ Arg arg2 = Main2.getArg2();
+ Call caller = Main2.getCaller();
+ caller.print(arg1);
+ caller.print(arg2);
+ }
+ }
+}