Pin types referenced by pinned program members

Bug: 169724051
Bug: 163311975
Change-Id: I732a166394c661ba62643e301b15df19302e8f63
diff --git a/src/main/java/com/android/tools/r8/graph/DexField.java b/src/main/java/com/android/tools/r8/graph/DexField.java
index 012030b..87adbf7 100644
--- a/src/main/java/com/android/tools/r8/graph/DexField.java
+++ b/src/main/java/com/android/tools/r8/graph/DexField.java
@@ -8,6 +8,7 @@
 import com.android.tools.r8.naming.NamingLens;
 import com.android.tools.r8.references.FieldReference;
 import com.android.tools.r8.references.Reference;
+import java.util.Collections;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.Function;
@@ -116,6 +117,11 @@
   }
 
   @Override
+  public Iterable<DexType> getReferencedTypes() {
+    return Collections.singleton(type);
+  }
+
+  @Override
   public int slowCompareTo(DexField other) {
     int result = holder.slowCompareTo(other.holder);
     if (result != 0) {
diff --git a/src/main/java/com/android/tools/r8/graph/DexMember.java b/src/main/java/com/android/tools/r8/graph/DexMember.java
index a1767c3..5c34e83 100644
--- a/src/main/java/com/android/tools/r8/graph/DexMember.java
+++ b/src/main/java/com/android/tools/r8/graph/DexMember.java
@@ -3,6 +3,8 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.graph;
 
+import com.google.common.collect.Iterables;
+
 public abstract class DexMember<D extends DexEncodedMember<D, R>, R extends DexMember<D, R>>
     extends DexReference implements PresortedComparable<R> {
 
@@ -33,4 +35,10 @@
   public DexMember<D, R> asDexMember() {
     return this;
   }
+
+  public abstract Iterable<DexType> getReferencedTypes();
+
+  public Iterable<DexType> getReferencedBaseTypes(DexItemFactory dexItemFactory) {
+    return Iterables.transform(getReferencedTypes(), type -> type.toBaseType(dexItemFactory));
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/graph/DexMethod.java b/src/main/java/com/android/tools/r8/graph/DexMethod.java
index 212f5ca..9757af4 100644
--- a/src/main/java/com/android/tools/r8/graph/DexMethod.java
+++ b/src/main/java/com/android/tools/r8/graph/DexMethod.java
@@ -142,6 +142,11 @@
   }
 
   @Override
+  public Iterable<DexType> getReferencedTypes() {
+    return proto.getTypes();
+  }
+
+  @Override
   public int computeHashCode() {
     return holder.hashCode()
         + proto.hashCode() * 7
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoKeepRules.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoKeepRules.java
index 6181c41..969f43a 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoKeepRules.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoKeepRules.java
@@ -4,27 +4,50 @@
 
 package com.android.tools.r8.horizontalclassmerging.policies;
 
+
 import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexEncodedMember;
+import com.android.tools.r8.graph.DexMember;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.horizontalclassmerging.SingleClassPolicy;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.google.common.collect.Iterables;
+import com.google.common.collect.Sets;
+import java.util.Set;
 
 public class NoKeepRules extends SingleClassPolicy {
   private final AppView<AppInfoWithLiveness> appView;
+  private final Set<DexType> dontMergeTypes = Sets.newIdentityHashSet();
 
   public NoKeepRules(AppView<AppInfoWithLiveness> appView) {
     this.appView = appView;
+
+    appView.appInfo().classes().forEach(this::processClass);
+  }
+
+  private void processClass(DexProgramClass programClass) {
+    DexType type = programClass.getType();
+    boolean pinProgramClass = appView.appInfo().isPinned(type);
+
+    for (DexEncodedMember<?, ?> member : programClass.members()) {
+      DexMember<?, ?> reference = member.toReference();
+      if (appView.appInfo().isPinned(reference)) {
+        pinProgramClass = true;
+        Iterables.addAll(
+            dontMergeTypes,
+            Iterables.filter(
+                reference.getReferencedBaseTypes(appView.dexItemFactory()), DexType::isClassType));
+      }
+    }
+
+    if (pinProgramClass) {
+      dontMergeTypes.add(type);
+    }
   }
 
   @Override
   public boolean canMerge(DexProgramClass program) {
-    DexType type = program.getType();
-    boolean anyPinned =
-        appView.appInfo().isPinned(type)
-            || Iterables.any(
-                program.members(), member -> appView.appInfo().isPinned(member.toReference()));
-    return !anyPinned;
+    return !dontMergeTypes.contains(program.getType());
   }
 }
diff --git a/src/test/java/com/android/tools/r8/R8RunExamplesAndroidPTest.java b/src/test/java/com/android/tools/r8/R8RunExamplesAndroidPTest.java
index 8edbd69..9ac01a0 100644
--- a/src/test/java/com/android/tools/r8/R8RunExamplesAndroidPTest.java
+++ b/src/test/java/com/android/tools/r8/R8RunExamplesAndroidPTest.java
@@ -51,7 +51,6 @@
 
   @Test
   public void invokeCustomWithShrinking() throws Throwable {
-    expectThrowsWithHorizontalClassMerging();
     test("invokecustom-with-shrinking", "invokecustom", "InvokeCustom")
         .withMinApiLevel(AndroidApiLevel.P.getLevel())
         .withBuilderTransformation(builder ->
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/PinnedClassMemberReferenceTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/PinnedClassMemberReferenceTest.java
new file mode 100644
index 0000000..4a22fe3
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/PinnedClassMemberReferenceTest.java
@@ -0,0 +1,161 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.classmerging.horizontal;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.CoreMatchers.not;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assume.assumeTrue;
+
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.R8FullTestBuilder;
+import com.android.tools.r8.R8TestRunResult;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import org.junit.Test;
+
+public class PinnedClassMemberReferenceTest extends HorizontalClassMergingTestBase {
+  public PinnedClassMemberReferenceTest(
+      TestParameters parameters, boolean enableHorizontalClassMerging) {
+    super(parameters, enableHorizontalClassMerging);
+  }
+
+  private R8FullTestBuilder testCommon() throws Exception {
+    return testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .addOptionsModification(
+            options -> options.enableHorizontalClassMerging = enableHorizontalClassMerging)
+        .noMinification()
+        .enableInliningAnnotations()
+        .enableNeverClassInliningAnnotations()
+        .setMinApi(parameters.getApiLevel());
+  }
+
+  private R8TestRunResult runAndAssertOutput(R8FullTestBuilder builder) throws Exception {
+    return builder
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines(
+            "a", "b", "foo a: bar", "foo b: baz", "fields a: bar", "fields b: baz");
+  }
+
+  @Test
+  public void testWithoutKeepRules() throws Exception {
+    // This is just a small check ensure that without the keep rules the classes are merged.
+    assumeTrue(enableHorizontalClassMerging);
+    assumeTrue(parameters.isCfRuntime());
+
+    runAndAssertOutput(testCommon())
+        .inspect(
+            codeInspector -> {
+              ClassSubject aClassSubject = codeInspector.clazz(A.class);
+              assertThat(aClassSubject, isPresent());
+
+              assertThat(codeInspector.clazz(B.class), not(isPresent()));
+
+              ClassSubject cClassSubject = codeInspector.clazz(C.class);
+              assertThat(cClassSubject, isPresent());
+              assertThat(cClassSubject.field(aClassSubject.getFinalName(), "a"), isPresent());
+              assertThat(cClassSubject.field(aClassSubject.getFinalName(), "b"), isPresent());
+
+              assertThat(
+                  cClassSubject.method("void", "foo", aClassSubject.getFinalName()), isPresent());
+            });
+  }
+
+  @Test
+  public void testWithKeepRules() throws Exception {
+    runAndAssertOutput(
+            testCommon()
+                .addKeepRules(
+                    "-keepclassmembers class " + C.class.getTypeName() + " { ",
+                    "  " + A.class.getTypeName() + " a;",
+                    "  " + C.class.getTypeName() + " c;",
+                    "  void foo(" + A.class.getTypeName() + ");",
+                    "  void foo(" + B.class.getTypeName() + ");",
+                    "}"))
+        .inspect(
+            codeInspector -> {
+              ClassSubject aClassSubject = codeInspector.clazz(A.class);
+              assertThat(aClassSubject, isPresent());
+
+              ClassSubject bClassSubject = codeInspector.clazz(B.class);
+              assertThat(bClassSubject, isPresent());
+
+              ClassSubject cClassSubject = codeInspector.clazz(C.class);
+              assertThat(cClassSubject, isPresent());
+              assertThat(cClassSubject.field(aClassSubject.getFinalName(), "a"), isPresent());
+              assertThat(cClassSubject.field(bClassSubject.getFinalName(), "b"), isPresent());
+
+              assertThat(
+                  cClassSubject.method("void", "foo", aClassSubject.getFinalName()), isPresent());
+              assertThat(
+                  cClassSubject.method("void", "foo", bClassSubject.getFinalName()), isPresent());
+            });
+  }
+
+  @NeverClassInline
+  public static class A {
+    public A() {
+      System.out.println("a");
+    }
+
+    @NeverInline
+    public String bar() {
+      return "bar";
+    }
+  }
+
+  @NeverClassInline
+  public static class B {
+    public B() {
+      System.out.println("b");
+    }
+
+    @NeverInline
+    public String baz() {
+      return "baz";
+    }
+  }
+
+  @NeverClassInline
+  public static class C {
+    A a;
+    B b;
+
+    public C(A a, B b) {
+      this.a = a;
+      this.b = b;
+    }
+
+    @NeverInline
+    public void foo(A a2) {
+      System.out.println("foo a: " + a2.bar());
+    }
+
+    @NeverInline
+    public void foo(B b) {
+      System.out.println("foo b: " + b.baz());
+    }
+
+    @NeverInline
+    public void fields() {
+      System.out.println("fields a: " + a.bar());
+      System.out.println("fields b: " + b.baz());
+    }
+  }
+
+  public static class Main {
+    public static void main(String[] args) {
+      A a = new A();
+      B b = new B();
+      C c = new C(a, b);
+      c.foo(a);
+      c.foo(b);
+      c.fields();
+    }
+  }
+}