Fix CallSiteOptimization for pinned method overrides
Bug: 189264383
Change-Id: I78e9106a288aa0549e6bea68be4aebeec392255d
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 2d91110..2412c2b 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -717,6 +717,9 @@
// 2) Second inlining pass for dealing with double inline callers.
printPhase("Post optimization pass");
if (appView.callSiteOptimizationInfoPropagator() != null) {
+ appView
+ .callSiteOptimizationInfoPropagator()
+ .abandonCallSitePropagationForPinnedMethodsAndOverrides(executorService);
postMethodProcessorBuilder.put(appView.callSiteOptimizationInfoPropagator());
}
if (inliner != null) {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java b/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java
index 2fdd282..3403e7b 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java
@@ -38,6 +38,7 @@
import com.android.tools.r8.utils.InternalOptions;
import com.android.tools.r8.utils.InternalOptions.CallSiteOptimizationOptions;
import com.android.tools.r8.utils.LazyBox;
+import com.android.tools.r8.utils.ThreadUtils;
import com.android.tools.r8.utils.Timing;
import com.android.tools.r8.utils.collections.ProgramMethodSet;
import com.google.common.collect.Sets;
@@ -45,6 +46,9 @@
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.function.Consumer;
public class CallSiteOptimizationInfoPropagator implements PostOptimization {
@@ -291,6 +295,25 @@
}
}
+ public void abandonCallSitePropagationForPinnedMethodsAndOverrides(
+ ExecutorService executorService) throws ExecutionException {
+ ThreadUtils.processItems(
+ this::forEachPinnedNonPrivateVirtualMethod,
+ this::abandonCallSitePropagationForMethodAndOverrides,
+ executorService);
+ }
+
+ private void forEachPinnedNonPrivateVirtualMethod(Consumer<ProgramMethod> consumer) {
+ for (DexProgramClass clazz : appView.appInfo().classes()) {
+ for (ProgramMethod virtualProgramMethod : clazz.virtualProgramMethods()) {
+ if (virtualProgramMethod.getDefinition().isNonPrivateVirtualMethod()
+ && appView.getKeepInfo().isPinned(virtualProgramMethod.getReference(), appView)) {
+ consumer.accept(virtualProgramMethod);
+ }
+ }
+ }
+ }
+
private void abandonCallSitePropagationForMethodAndOverrides(ProgramMethod method) {
Set<ProgramMethod> abandonSet = Sets.newIdentityHashSet();
if (method.getDefinition().isNonPrivateVirtualMethod()) {
diff --git a/src/test/java/com/android/tools/r8/desugar/desugaredlibrary/SimpleStreamTest.java b/src/test/java/com/android/tools/r8/desugar/desugaredlibrary/SimpleStreamTest.java
new file mode 100644
index 0000000..aec1045
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/desugar/desugaredlibrary/SimpleStreamTest.java
@@ -0,0 +1,84 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.desugar.desugaredlibrary;
+
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.utils.BooleanUtils;
+import com.android.tools.r8.utils.StringUtils;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class SimpleStreamTest extends DesugaredLibraryTestBase {
+
+ private static final String EXPECTED_RESULT = StringUtils.lines("3");
+
+ private final TestParameters parameters;
+ private final boolean shrinkDesugaredLibrary;
+
+ @Parameters(name = "{1}, shrinkDesugaredLibrary: {0}")
+ public static List<Object[]> data() {
+ return buildParameters(
+ BooleanUtils.values(), getTestParameters().withDexRuntimes().withAllApiLevels().build());
+ }
+
+ public SimpleStreamTest(boolean shrinkDesugaredLibrary, TestParameters parameters) {
+ this.shrinkDesugaredLibrary = shrinkDesugaredLibrary;
+ this.parameters = parameters;
+ }
+
+ @Test
+ public void testStreamD8() throws Exception {
+ KeepRuleConsumer keepRuleConsumer = createKeepRuleConsumer(parameters);
+ testForD8()
+ .addInnerClasses(SimpleStreamTest.class)
+ .setMinApi(parameters.getApiLevel())
+ .enableCoreLibraryDesugaring(parameters.getApiLevel(), keepRuleConsumer)
+ .compile()
+ .addDesugaredCoreLibraryRunClassPath(
+ this::buildDesugaredLibrary,
+ parameters.getApiLevel(),
+ keepRuleConsumer.get(),
+ shrinkDesugaredLibrary)
+ .run(parameters.getRuntime(), Executor.class)
+ .assertSuccessWithOutput(EXPECTED_RESULT);
+ }
+
+ @Test
+ public void testStreamR8() throws Exception {
+ KeepRuleConsumer keepRuleConsumer = createKeepRuleConsumer(parameters);
+ testForR8(Backend.DEX)
+ .addInnerClasses(SimpleStreamTest.class)
+ .setMinApi(parameters.getApiLevel())
+ .addKeepClassAndMembersRules(Executor.class)
+ .enableCoreLibraryDesugaring(parameters.getApiLevel(), keepRuleConsumer)
+ .compile()
+ .addDesugaredCoreLibraryRunClassPath(
+ this::buildDesugaredLibrary,
+ parameters.getApiLevel(),
+ keepRuleConsumer.get(),
+ shrinkDesugaredLibrary)
+ .run(parameters.getRuntime(), Executor.class)
+ .assertSuccessWithOutput(EXPECTED_RESULT);
+ }
+
+ @SuppressWarnings("unchecked")
+ static class Executor {
+
+ public static void main(String[] args) {
+ ArrayList<Integer> integers = new ArrayList<>();
+ integers.add(1);
+ integers.add(2);
+ integers.add(3);
+ List<Integer> collectedList = integers.stream().map(i -> i + 3).collect(Collectors.toList());
+ System.out.println(collectedList.size());
+ }
+ }
+}
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationPinnedMethodOverridePropagationTest.java b/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationPinnedMethodOverridePropagationTest.java
new file mode 100644
index 0000000..9e42eb6
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationPinnedMethodOverridePropagationTest.java
@@ -0,0 +1,166 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.callsites;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NeverPropagateValue;
+import com.android.tools.r8.NoHorizontalClassMerging;
+import com.android.tools.r8.NoVerticalClassMerging;
+import com.android.tools.r8.R8TestCompileResult;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.google.common.collect.ImmutableList;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class CallSiteOptimizationPinnedMethodOverridePropagationTest extends TestBase {
+
+ private static final String CLASS_PREFIX =
+ "com.android.tools.r8.ir.optimize.callsites.CallSiteOptimizationPinnedMethodOverridePropagationTest$";
+ private final TestParameters parameters;
+
+ @Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withDexRuntimes().withAllApiLevels().build();
+ }
+
+ public CallSiteOptimizationPinnedMethodOverridePropagationTest(TestParameters parameters) {
+ this.parameters = parameters;
+ }
+
+ @Test
+ public void test() throws Exception {
+ R8TestCompileResult compiled =
+ testForR8(parameters.getBackend())
+ .addProgramClasses(
+ Arg.class, Arg1.class, Arg2.class, Call.class, CallImpl.class, Main2.class)
+ .addKeepRules(
+ ImmutableList.of(
+ "-keep interface " + CLASS_PREFIX + "Arg",
+ "-keep interface "
+ + CLASS_PREFIX
+ + "Call { \npublic void print("
+ + CLASS_PREFIX
+ + "Arg); \n}",
+ "-keep class "
+ + CLASS_PREFIX
+ + "Main2 { \npublic static void main(java.lang.String[]); \npublic static "
+ + CLASS_PREFIX
+ + "Arg getArg1(); \npublic static "
+ + CLASS_PREFIX
+ + "Arg getArg2(); \npublic static "
+ + CLASS_PREFIX
+ + "Call getCaller(); \n}"))
+ .enableNoVerticalClassMergingAnnotations()
+ .enableNoHorizontalClassMergingAnnotations()
+ .enableInliningAnnotations()
+ .enableMemberValuePropagationAnnotations()
+ .setMinApi(parameters.getApiLevel())
+ .compile();
+ CodeInspector inspector = compiled.inspector();
+ compiled.run(parameters.getRuntime(), Main2.class).assertSuccessWithOutputLines("Arg1");
+ testForD8()
+ .addProgramClasses(Main.class)
+ .setMinApi(parameters.getApiLevel())
+ .compile()
+ .addRunClasspathFiles(compiled.writeToZip())
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutputLines("Arg1", "Arg2");
+ }
+
+ // Kept
+ @NoVerticalClassMerging
+ interface Arg {
+
+ @NeverInline
+ @NeverPropagateValue
+ String getString();
+ }
+
+ @NoVerticalClassMerging
+ @NoHorizontalClassMerging
+ static class Arg1 implements Arg {
+
+ @Override
+ @NeverInline
+ @NeverPropagateValue
+ public String getString() {
+ return "Arg1";
+ }
+ }
+
+ @NoVerticalClassMerging
+ @NoHorizontalClassMerging
+ static class Arg2 implements Arg {
+
+ @Override
+ @NeverInline
+ @NeverPropagateValue
+ public String getString() {
+ return "Arg2";
+ }
+ }
+
+ @NoVerticalClassMerging
+ interface Call {
+
+ // Kept.
+ @NeverInline
+ @NeverPropagateValue
+ void print(Arg arg);
+ }
+
+ @NoVerticalClassMerging
+ static class CallImpl implements Call {
+
+ @Override
+ @NeverInline
+ @NeverPropagateValue
+ public void print(Arg arg) {
+ System.out.println(arg.getString());
+ }
+ }
+
+ @NoVerticalClassMerging
+ static class Main2 {
+
+ // Kept.
+ public static void main(String[] args) {
+ // This would propagate Arg1 to print while it should not.
+ getCaller().print(new Arg1());
+ }
+
+ // Kept.
+ public static Arg getArg1() {
+ return new Arg1();
+ }
+
+ // Kept.
+ public static Arg getArg2() {
+ return new Arg2();
+ }
+
+ // Kept.
+ public static Call getCaller() {
+ return new CallSiteOptimizationPinnedMethodOverridePropagationTest.CallImpl();
+ }
+ }
+
+ static class Main {
+
+ public static void main(String[] args) {
+ Arg arg1 = Main2.getArg1();
+ Arg arg2 = Main2.getArg2();
+ Call caller = Main2.getCaller();
+ caller.print(arg1);
+ caller.print(arg2);
+ }
+ }
+}