Reserve kept method signatures in vertical class merging

Bug: b/316862633
Change-Id: I9f9539a39efbfd89bc6a248194443b9ffc16b16e
diff --git a/src/main/java/com/android/tools/r8/classmerging/ClassMergerTreeFixer.java b/src/main/java/com/android/tools/r8/classmerging/ClassMergerTreeFixer.java
index 207c73b..8bf6828 100644
--- a/src/main/java/com/android/tools/r8/classmerging/ClassMergerTreeFixer.java
+++ b/src/main/java/com/android/tools/r8/classmerging/ClassMergerTreeFixer.java
@@ -29,6 +29,7 @@
 import com.android.tools.r8.utils.Timing;
 import com.android.tools.r8.utils.collections.BidirectionalOneToOneHashMap;
 import com.android.tools.r8.utils.collections.DexMethodSignatureBiMap;
+import com.android.tools.r8.utils.collections.DexMethodSignatureSet;
 import com.android.tools.r8.utils.collections.MutableBidirectionalOneToOneMap;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Sets;
@@ -52,6 +53,8 @@
   private final SyntheticArgumentClass syntheticArgumentClass;
 
   private final Map<DexProgramClass, DexType> originalSuperTypes = new IdentityHashMap<>();
+
+  protected final DexMethodSignatureSet keptSignatures = DexMethodSignatureSet.create();
   private final DexMethodSignatureBiMap<DexMethodSignature> reservedInterfaceSignatures =
       new DexMethodSignatureBiMap<>();
 
@@ -74,6 +77,7 @@
     }
     timing.begin("Fixup");
     AppView<AppInfoWithLiveness> appViewWithLiveness = appView.withLiveness();
+    preprocess();
     Collection<DexProgramClass> classes = appView.appInfo().classesWithDeterministicOrder();
     Iterables.filter(classes, DexProgramClass::isInterface).forEach(this::fixupInterfaceClass);
     classes.forEach(this::fixupAttributes);
@@ -94,6 +98,10 @@
 
   public abstract boolean isRunningBeforePrimaryOptimizationPass();
 
+  public void preprocess() {
+    // Intentionally empty.
+  }
+
   public void postprocess() {
     // Intentionally empty.
   }
@@ -155,9 +163,12 @@
   }
 
   private DexEncodedMethod fixupVirtualInterfaceMethod(DexEncodedMethod method) {
-    DexMethod originalMethodReference = method.getReference();
+    if (keptSignatures.contains(method)) {
+      return method;
+    }
 
     // Don't process this method if it does not refer to a merge class type.
+    DexMethod originalMethodReference = method.getReference();
     boolean referencesMergeClass =
         Iterables.any(
             originalMethodReference.getReferencedBaseTypes(dexItemFactory),
@@ -264,64 +275,70 @@
     DexMethod originalMethodReference = method.getReference();
 
     // Fix all type references in the method prototype.
-    DexMethodSignature reservedMethodSignature = newMethodSignatures.get(method);
     DexMethod newMethodReference;
-    if (reservedMethodSignature != null) {
-      newMethodReference = reservedMethodSignature.withHolder(clazz, dexItemFactory);
+    if (keptSignatures.contains(method)) {
+      newMethodReference = method.getReference();
     } else {
-      newMethodReference = fixupMethodReference(originalMethodReference);
-      if (newMethodSignatures.containsValue(newMethodReference.getSignature())) {
-        // If the method collides with a direct method on the same class then rename it to a
-        // globally
-        // fresh name and record the signature.
-        if (method.isInstanceInitializer()) {
-          // If the method is an instance initializer, then add extra nulls.
-          Box<Set<DexType>> usedSyntheticArgumentClasses = new Box<>();
-          newMethodReference =
-              dexItemFactory.createInstanceInitializerWithFreshProto(
-                  newMethodReference,
-                  syntheticArgumentClass.getArgumentClasses(),
-                  tryMethod -> !newMethodSignatures.containsValue(tryMethod.getSignature()),
-                  usedSyntheticArgumentClasses::set);
-          lensBuilder.addExtraParameters(
-              originalMethodReference,
-              newMethodReference,
-              ExtraUnusedNullParameter.computeExtraUnusedNullParameters(
-                  originalMethodReference, newMethodReference));
+      DexMethodSignature reservedMethodSignature = newMethodSignatures.get(method);
+      if (reservedMethodSignature != null) {
+        newMethodReference = reservedMethodSignature.withHolder(clazz, dexItemFactory);
+      } else {
+        newMethodReference = fixupMethodReference(originalMethodReference);
+        if (keptSignatures.contains(newMethodReference)
+            || newMethodSignatures.containsValue(newMethodReference.getSignature())) {
+          // If the method collides with a direct method on the same class then rename it to a
+          // globally
+          // fresh name and record the signature.
+          if (method.isInstanceInitializer()) {
+            // If the method is an instance initializer, then add extra nulls.
+            Box<Set<DexType>> usedSyntheticArgumentClasses = new Box<>();
+            newMethodReference =
+                dexItemFactory.createInstanceInitializerWithFreshProto(
+                    newMethodReference,
+                    syntheticArgumentClass.getArgumentClasses(),
+                    tryMethod -> !newMethodSignatures.containsValue(tryMethod.getSignature()),
+                    usedSyntheticArgumentClasses::set);
+            lensBuilder.addExtraParameters(
+                originalMethodReference,
+                newMethodReference,
+                ExtraUnusedNullParameter.computeExtraUnusedNullParameters(
+                    originalMethodReference, newMethodReference));
 
-          // Amend the art profile collection.
-          if (usedSyntheticArgumentClasses.isSet()) {
-            Set<DexMethod> previousMethodReferences =
-                lensBuilder.getOriginalMethodReferences(originalMethodReference);
-            if (previousMethodReferences.isEmpty()) {
-              profileCollectionAdditions.applyIfContextIsInProfile(
-                  originalMethodReference,
-                  additionsBuilder ->
-                      usedSyntheticArgumentClasses.get().forEach(additionsBuilder::addRule));
-            } else {
-              for (DexMethod previousMethodReference : previousMethodReferences) {
+            // Amend the art profile collection.
+            if (usedSyntheticArgumentClasses.isSet()) {
+              Set<DexMethod> previousMethodReferences =
+                  lensBuilder.getOriginalMethodReferences(originalMethodReference);
+              if (previousMethodReferences.isEmpty()) {
                 profileCollectionAdditions.applyIfContextIsInProfile(
-                    previousMethodReference,
+                    originalMethodReference,
                     additionsBuilder ->
                         usedSyntheticArgumentClasses.get().forEach(additionsBuilder::addRule));
+              } else {
+                for (DexMethod previousMethodReference : previousMethodReferences) {
+                  profileCollectionAdditions.applyIfContextIsInProfile(
+                      previousMethodReference,
+                      additionsBuilder ->
+                          usedSyntheticArgumentClasses.get().forEach(additionsBuilder::addRule));
+                }
               }
             }
+          } else {
+            newMethodReference =
+                dexItemFactory.createFreshMethodNameWithoutHolder(
+                    newMethodReference.getName().toSourceString(),
+                    newMethodReference.getProto(),
+                    newMethodReference.getHolderType(),
+                    tryMethod ->
+                        !keptSignatures.contains(tryMethod)
+                            && !reservedInterfaceSignatures.containsValue(tryMethod.getSignature())
+                            && !remappedVirtualMethods.containsValue(tryMethod.getSignature())
+                            && !newMethodSignatures.containsValue(tryMethod.getSignature()));
           }
-        } else {
-          newMethodReference =
-              dexItemFactory.createFreshMethodNameWithoutHolder(
-                  newMethodReference.getName().toSourceString(),
-                  newMethodReference.getProto(),
-                  newMethodReference.getHolderType(),
-                  tryMethod ->
-                      !reservedInterfaceSignatures.containsValue(tryMethod.getSignature())
-                          && !remappedVirtualMethods.containsValue(tryMethod.getSignature())
-                          && !newMethodSignatures.containsValue(tryMethod.getSignature()));
         }
-      }
 
-      assert !newMethodSignatures.containsValue(newMethodReference.getSignature());
-      newMethodSignatures.put(method, newMethodReference.getSignature());
+        assert !newMethodSignatures.containsValue(newMethodReference.getSignature());
+        newMethodSignatures.put(method, newMethodReference.getSignature());
+      }
     }
 
     return fixupProgramMethod(clazz, method, newMethodReference);
@@ -378,24 +395,31 @@
       DexEncodedMethod method,
       DexMethodSignatureBiMap<DexMethodSignature> renamedClassVirtualMethods,
       MutableBidirectionalOneToOneMap<DexEncodedMethod, DexMethodSignature> newMethodSignatures) {
-    DexMethodSignature newSignature = newMethodSignatures.get(method);
-    if (newSignature == null) {
-      // Fix all type references in the method prototype.
-      newSignature =
-          dexItemFactory.createFreshMethodSignatureName(
-              method.getName().toSourceString(),
-              null,
-              fixupProto(method.getProto()),
-              trySignature ->
-                  !reservedInterfaceSignatures.containsValue(trySignature)
-                      && !newMethodSignatures.containsValue(trySignature)
-                      && !renamedClassVirtualMethods.containsValue(trySignature));
-      newMethodSignatures.put(method, newSignature);
+    DexMethodSignature newSignature;
+    if (keptSignatures.contains(method)) {
+      newSignature = method.getSignature();
+    } else {
+      newSignature = newMethodSignatures.get(method);
+      if (newSignature == null) {
+        // Fix all type references in the method prototype.
+        newSignature =
+            dexItemFactory.createFreshMethodSignatureName(
+                method.getName().toSourceString(),
+                null,
+                fixupProto(method.getProto()),
+                trySignature ->
+                    !keptSignatures.contains(trySignature)
+                        && !reservedInterfaceSignatures.containsValue(trySignature)
+                        && !newMethodSignatures.containsValue(trySignature)
+                        && !renamedClassVirtualMethods.containsValue(trySignature));
+        newMethodSignatures.put(method, newSignature);
+      }
     }
 
     // If any of the parameter types have been merged, record the signature mapping so that
     // subclasses perform the identical rename.
-    if (!reservedInterfaceSignatures.containsKey(method)
+    if (!keptSignatures.contains(method)
+        && !reservedInterfaceSignatures.containsKey(method)
         && Iterables.any(
             newSignature.getProto().getBaseTypes(dexItemFactory), mergedClasses::isMergeTarget)) {
       renamedClassVirtualMethods.put(method.getSignature(), newSignature);
diff --git a/src/main/java/com/android/tools/r8/verticalclassmerging/VerticalClassMergerTreeFixer.java b/src/main/java/com/android/tools/r8/verticalclassmerging/VerticalClassMergerTreeFixer.java
index 6b9955c..d814695 100644
--- a/src/main/java/com/android/tools/r8/verticalclassmerging/VerticalClassMergerTreeFixer.java
+++ b/src/main/java/com/android/tools/r8/verticalclassmerging/VerticalClassMergerTreeFixer.java
@@ -46,6 +46,19 @@
   }
 
   @Override
+  public void preprocess() {
+    appView
+        .getKeepInfo()
+        .forEachPinnedMethod(
+            method -> {
+              if (!method.isInstanceInitializer(dexItemFactory)) {
+                keptSignatures.add(method);
+              }
+            },
+            appView.options());
+  }
+
+  @Override
   public void postprocess() {
     lensBuilder.fixupContextualVirtualToDirectMethodMaps();
   }
diff --git a/src/test/java/com/android/tools/r8/classmerging/vertical/VerticalClassMergerPinnedMethodCollisionTest.java b/src/test/java/com/android/tools/r8/classmerging/vertical/VerticalClassMergerPinnedMethodCollisionTest.java
index c9fe0e3..1458079 100644
--- a/src/test/java/com/android/tools/r8/classmerging/vertical/VerticalClassMergerPinnedMethodCollisionTest.java
+++ b/src/test/java/com/android/tools/r8/classmerging/vertical/VerticalClassMergerPinnedMethodCollisionTest.java
@@ -3,7 +3,6 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.classmerging.vertical;
 
-import static com.android.tools.r8.utils.codeinspector.AssertUtils.assertFailsCompilation;
 
 import com.android.tools.r8.NeverInline;
 import com.android.tools.r8.TestBase;
@@ -28,20 +27,19 @@
 
   @Test
   public void test() throws Exception {
-    assertFailsCompilation(
-        () ->
-            testForR8(parameters.getBackend())
-                .addInnerClasses(getClass())
-                .addKeepMainRule(Main.class)
-                .addKeepRules(
-                    "-keep class " + UserSub.class.getTypeName() + " {",
-                    "  void f(" + B.class.getTypeName() + ");",
-                    "}")
-                .addVerticallyMergedClassesInspector(
-                    inspector -> inspector.assertMergedIntoSubtype(A.class))
-                .enableInliningAnnotations()
-                .setMinApi(parameters)
-                .compile());
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .addKeepRules(
+            "-keep class " + UserSub.class.getTypeName() + " {",
+            "  void f(" + B.class.getTypeName() + ");",
+            "}")
+        .addVerticallyMergedClassesInspector(
+            inspector -> inspector.assertMergedIntoSubtype(A.class))
+        .enableInliningAnnotations()
+        .setMinApi(parameters)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("B", "B");
   }
 
   static class Main {
diff --git a/src/test/java/com/android/tools/r8/classmerging/vertical/VerticalClassMergerPinnedMethodInterfaceCollisionTest.java b/src/test/java/com/android/tools/r8/classmerging/vertical/VerticalClassMergerPinnedMethodInterfaceCollisionTest.java
index 8b3129e..84fa122 100644
--- a/src/test/java/com/android/tools/r8/classmerging/vertical/VerticalClassMergerPinnedMethodInterfaceCollisionTest.java
+++ b/src/test/java/com/android/tools/r8/classmerging/vertical/VerticalClassMergerPinnedMethodInterfaceCollisionTest.java
@@ -3,7 +3,6 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.classmerging.vertical;
 
-import static com.android.tools.r8.utils.codeinspector.AssertUtils.assertFailsCompilation;
 
 import com.android.tools.r8.NoUnusedInterfaceRemoval;
 import com.android.tools.r8.NoVerticalClassMerging;
@@ -29,22 +28,20 @@
 
   @Test
   public void test() throws Exception {
-    assertFailsCompilation(
-        () ->
-            testForR8(parameters.getBackend())
-                .addInnerClasses(getClass())
-                .addKeepClassAndMembersRules(Main.class)
-                .addKeepRules(
-                    "-keep class " + OtherUser.class.getTypeName() + " {",
-                    "  void f(" + B.class.getTypeName() + ");",
-                    "}")
-                .addVerticallyMergedClassesInspector(
-                    inspector -> inspector.assertMergedIntoSubtype(A.class))
-                .enableInliningAnnotations()
-                .enableNoUnusedInterfaceRemovalAnnotations()
-                .enableNoVerticalClassMergingAnnotations()
-                .setMinApi(parameters)
-                .compile());
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepClassAndMembersRules(Main.class)
+        .addKeepRules(
+            "-keep class " + OtherUser.class.getTypeName() + " {",
+            "  void f(" + B.class.getTypeName() + ");",
+            "}")
+        .addVerticallyMergedClassesInspector(
+            inspector -> inspector.assertMergedIntoSubtype(A.class))
+        .enableNoUnusedInterfaceRemovalAnnotations()
+        .enableNoVerticalClassMergingAnnotations()
+        .setMinApi(parameters)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("B", "B");
   }
 
   static class Main {
@@ -80,7 +77,7 @@
 
     @Override
     public void f(A a) {
-      System.out.println("UserImpl");
+      System.out.println(a);
     }
   }