Pin types referenced by pinned program members
Bug: 169724051
Bug: 163311975
Change-Id: I732a166394c661ba62643e301b15df19302e8f63
diff --git a/src/main/java/com/android/tools/r8/graph/DexField.java b/src/main/java/com/android/tools/r8/graph/DexField.java
index 012030b..87adbf7 100644
--- a/src/main/java/com/android/tools/r8/graph/DexField.java
+++ b/src/main/java/com/android/tools/r8/graph/DexField.java
@@ -8,6 +8,7 @@
import com.android.tools.r8.naming.NamingLens;
import com.android.tools.r8.references.FieldReference;
import com.android.tools.r8.references.Reference;
+import java.util.Collections;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
@@ -116,6 +117,11 @@
}
@Override
+ public Iterable<DexType> getReferencedTypes() {
+ return Collections.singleton(type);
+ }
+
+ @Override
public int slowCompareTo(DexField other) {
int result = holder.slowCompareTo(other.holder);
if (result != 0) {
diff --git a/src/main/java/com/android/tools/r8/graph/DexMember.java b/src/main/java/com/android/tools/r8/graph/DexMember.java
index a1767c3..5c34e83 100644
--- a/src/main/java/com/android/tools/r8/graph/DexMember.java
+++ b/src/main/java/com/android/tools/r8/graph/DexMember.java
@@ -3,6 +3,8 @@
// BSD-style license that can be found in the LICENSE file.
package com.android.tools.r8.graph;
+import com.google.common.collect.Iterables;
+
public abstract class DexMember<D extends DexEncodedMember<D, R>, R extends DexMember<D, R>>
extends DexReference implements PresortedComparable<R> {
@@ -33,4 +35,10 @@
public DexMember<D, R> asDexMember() {
return this;
}
+
+ public abstract Iterable<DexType> getReferencedTypes();
+
+ public Iterable<DexType> getReferencedBaseTypes(DexItemFactory dexItemFactory) {
+ return Iterables.transform(getReferencedTypes(), type -> type.toBaseType(dexItemFactory));
+ }
}
diff --git a/src/main/java/com/android/tools/r8/graph/DexMethod.java b/src/main/java/com/android/tools/r8/graph/DexMethod.java
index 212f5ca..9757af4 100644
--- a/src/main/java/com/android/tools/r8/graph/DexMethod.java
+++ b/src/main/java/com/android/tools/r8/graph/DexMethod.java
@@ -142,6 +142,11 @@
}
@Override
+ public Iterable<DexType> getReferencedTypes() {
+ return proto.getTypes();
+ }
+
+ @Override
public int computeHashCode() {
return holder.hashCode()
+ proto.hashCode() * 7
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoKeepRules.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoKeepRules.java
index 6181c41..969f43a 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoKeepRules.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoKeepRules.java
@@ -4,27 +4,50 @@
package com.android.tools.r8.horizontalclassmerging.policies;
+
import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexEncodedMember;
+import com.android.tools.r8.graph.DexMember;
import com.android.tools.r8.graph.DexProgramClass;
import com.android.tools.r8.graph.DexType;
import com.android.tools.r8.horizontalclassmerging.SingleClassPolicy;
import com.android.tools.r8.shaking.AppInfoWithLiveness;
import com.google.common.collect.Iterables;
+import com.google.common.collect.Sets;
+import java.util.Set;
public class NoKeepRules extends SingleClassPolicy {
private final AppView<AppInfoWithLiveness> appView;
+ private final Set<DexType> dontMergeTypes = Sets.newIdentityHashSet();
public NoKeepRules(AppView<AppInfoWithLiveness> appView) {
this.appView = appView;
+
+ appView.appInfo().classes().forEach(this::processClass);
+ }
+
+ private void processClass(DexProgramClass programClass) {
+ DexType type = programClass.getType();
+ boolean pinProgramClass = appView.appInfo().isPinned(type);
+
+ for (DexEncodedMember<?, ?> member : programClass.members()) {
+ DexMember<?, ?> reference = member.toReference();
+ if (appView.appInfo().isPinned(reference)) {
+ pinProgramClass = true;
+ Iterables.addAll(
+ dontMergeTypes,
+ Iterables.filter(
+ reference.getReferencedBaseTypes(appView.dexItemFactory()), DexType::isClassType));
+ }
+ }
+
+ if (pinProgramClass) {
+ dontMergeTypes.add(type);
+ }
}
@Override
public boolean canMerge(DexProgramClass program) {
- DexType type = program.getType();
- boolean anyPinned =
- appView.appInfo().isPinned(type)
- || Iterables.any(
- program.members(), member -> appView.appInfo().isPinned(member.toReference()));
- return !anyPinned;
+ return !dontMergeTypes.contains(program.getType());
}
}
diff --git a/src/test/java/com/android/tools/r8/R8RunExamplesAndroidPTest.java b/src/test/java/com/android/tools/r8/R8RunExamplesAndroidPTest.java
index 8edbd69..9ac01a0 100644
--- a/src/test/java/com/android/tools/r8/R8RunExamplesAndroidPTest.java
+++ b/src/test/java/com/android/tools/r8/R8RunExamplesAndroidPTest.java
@@ -51,7 +51,6 @@
@Test
public void invokeCustomWithShrinking() throws Throwable {
- expectThrowsWithHorizontalClassMerging();
test("invokecustom-with-shrinking", "invokecustom", "InvokeCustom")
.withMinApiLevel(AndroidApiLevel.P.getLevel())
.withBuilderTransformation(builder ->
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/PinnedClassMemberReferenceTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/PinnedClassMemberReferenceTest.java
new file mode 100644
index 0000000..4a22fe3
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/PinnedClassMemberReferenceTest.java
@@ -0,0 +1,161 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.classmerging.horizontal;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.CoreMatchers.not;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assume.assumeTrue;
+
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.R8FullTestBuilder;
+import com.android.tools.r8.R8TestRunResult;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import org.junit.Test;
+
+public class PinnedClassMemberReferenceTest extends HorizontalClassMergingTestBase {
+ public PinnedClassMemberReferenceTest(
+ TestParameters parameters, boolean enableHorizontalClassMerging) {
+ super(parameters, enableHorizontalClassMerging);
+ }
+
+ private R8FullTestBuilder testCommon() throws Exception {
+ return testForR8(parameters.getBackend())
+ .addInnerClasses(getClass())
+ .addKeepMainRule(Main.class)
+ .addOptionsModification(
+ options -> options.enableHorizontalClassMerging = enableHorizontalClassMerging)
+ .noMinification()
+ .enableInliningAnnotations()
+ .enableNeverClassInliningAnnotations()
+ .setMinApi(parameters.getApiLevel());
+ }
+
+ private R8TestRunResult runAndAssertOutput(R8FullTestBuilder builder) throws Exception {
+ return builder
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutputLines(
+ "a", "b", "foo a: bar", "foo b: baz", "fields a: bar", "fields b: baz");
+ }
+
+ @Test
+ public void testWithoutKeepRules() throws Exception {
+ // This is just a small check ensure that without the keep rules the classes are merged.
+ assumeTrue(enableHorizontalClassMerging);
+ assumeTrue(parameters.isCfRuntime());
+
+ runAndAssertOutput(testCommon())
+ .inspect(
+ codeInspector -> {
+ ClassSubject aClassSubject = codeInspector.clazz(A.class);
+ assertThat(aClassSubject, isPresent());
+
+ assertThat(codeInspector.clazz(B.class), not(isPresent()));
+
+ ClassSubject cClassSubject = codeInspector.clazz(C.class);
+ assertThat(cClassSubject, isPresent());
+ assertThat(cClassSubject.field(aClassSubject.getFinalName(), "a"), isPresent());
+ assertThat(cClassSubject.field(aClassSubject.getFinalName(), "b"), isPresent());
+
+ assertThat(
+ cClassSubject.method("void", "foo", aClassSubject.getFinalName()), isPresent());
+ });
+ }
+
+ @Test
+ public void testWithKeepRules() throws Exception {
+ runAndAssertOutput(
+ testCommon()
+ .addKeepRules(
+ "-keepclassmembers class " + C.class.getTypeName() + " { ",
+ " " + A.class.getTypeName() + " a;",
+ " " + C.class.getTypeName() + " c;",
+ " void foo(" + A.class.getTypeName() + ");",
+ " void foo(" + B.class.getTypeName() + ");",
+ "}"))
+ .inspect(
+ codeInspector -> {
+ ClassSubject aClassSubject = codeInspector.clazz(A.class);
+ assertThat(aClassSubject, isPresent());
+
+ ClassSubject bClassSubject = codeInspector.clazz(B.class);
+ assertThat(bClassSubject, isPresent());
+
+ ClassSubject cClassSubject = codeInspector.clazz(C.class);
+ assertThat(cClassSubject, isPresent());
+ assertThat(cClassSubject.field(aClassSubject.getFinalName(), "a"), isPresent());
+ assertThat(cClassSubject.field(bClassSubject.getFinalName(), "b"), isPresent());
+
+ assertThat(
+ cClassSubject.method("void", "foo", aClassSubject.getFinalName()), isPresent());
+ assertThat(
+ cClassSubject.method("void", "foo", bClassSubject.getFinalName()), isPresent());
+ });
+ }
+
+ @NeverClassInline
+ public static class A {
+ public A() {
+ System.out.println("a");
+ }
+
+ @NeverInline
+ public String bar() {
+ return "bar";
+ }
+ }
+
+ @NeverClassInline
+ public static class B {
+ public B() {
+ System.out.println("b");
+ }
+
+ @NeverInline
+ public String baz() {
+ return "baz";
+ }
+ }
+
+ @NeverClassInline
+ public static class C {
+ A a;
+ B b;
+
+ public C(A a, B b) {
+ this.a = a;
+ this.b = b;
+ }
+
+ @NeverInline
+ public void foo(A a2) {
+ System.out.println("foo a: " + a2.bar());
+ }
+
+ @NeverInline
+ public void foo(B b) {
+ System.out.println("foo b: " + b.baz());
+ }
+
+ @NeverInline
+ public void fields() {
+ System.out.println("fields a: " + a.bar());
+ System.out.println("fields b: " + b.baz());
+ }
+ }
+
+ public static class Main {
+ public static void main(String[] args) {
+ A a = new A();
+ B b = new B();
+ C c = new C(a, b);
+ c.foo(a);
+ c.foo(b);
+ c.fields();
+ }
+ }
+}