Resolve instance initializer collisions in proto normalization

Bug: 195112263
Change-Id: Ib3e8172db63123acd0ce77f23a04654f59b2b3c9
diff --git a/src/main/java/com/android/tools/r8/optimize/proto/ProtoNormalizer.java b/src/main/java/com/android/tools/r8/optimize/proto/ProtoNormalizer.java
index 29ed273..e28c76b 100644
--- a/src/main/java/com/android/tools/r8/optimize/proto/ProtoNormalizer.java
+++ b/src/main/java/com/android/tools/r8/optimize/proto/ProtoNormalizer.java
@@ -6,6 +6,7 @@
 
 import static com.android.tools.r8.utils.MapUtils.ignoreKey;
 
+import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
@@ -16,11 +17,15 @@
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.InternalOptions;
+import com.android.tools.r8.utils.IterableUtils;
+import com.android.tools.r8.utils.MapUtils;
 import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.Timing;
+import com.android.tools.r8.utils.WorkList;
 import com.android.tools.r8.utils.collections.BidirectionalOneToOneHashMap;
 import com.android.tools.r8.utils.collections.DexMethodSignatureSet;
 import com.android.tools.r8.utils.collections.MutableBidirectionalOneToOneMap;
+import com.google.common.collect.Iterables;
 import com.google.common.collect.Sets;
 import java.util.Collections;
 import java.util.HashMap;
@@ -59,14 +64,19 @@
     LocalReservationState localReservationState = new LocalReservationState();
     ProtoNormalizerGraphLens.Builder lensBuilder = ProtoNormalizerGraphLens.builder(appView);
     for (DexProgramClass clazz : appView.appInfo().classesWithDeterministicOrder()) {
+      Map<DexMethodSignature, DexMethodSignature> newInstanceInitializerSignatures =
+          computeNewInstanceInitializerSignatures(
+              clazz, localReservationState, globalReservationState);
       clazz
           .getMethodCollection()
           .replaceMethods(
               method -> {
                 DexMethodSignature methodSignature = method.getSignature();
                 DexMethodSignature newMethodSignature =
-                    localReservationState.getNewMethodSignature(
-                        methodSignature, dexItemFactory, globalReservationState);
+                    method.isInstanceInitializer()
+                        ? newInstanceInitializerSignatures.get(methodSignature)
+                        : localReservationState.getAndReserveNewMethodSignature(
+                            methodSignature, dexItemFactory, globalReservationState);
                 if (methodSignature.equals(newMethodSignature)) {
                   return method;
                 }
@@ -170,6 +180,90 @@
     }
   }
 
+  Map<DexMethodSignature, DexMethodSignature> computeNewInstanceInitializerSignatures(
+      DexProgramClass clazz,
+      LocalReservationState localReservationState,
+      GlobalReservationState globalReservationState) {
+    // Create a map from new method signatures to old method signatures. This produces a one-to-many
+    // mapping since multiple instance initializers may normalize to the same signature.
+    Map<DexMethodSignature, DexMethodSignatureSet> instanceInitializerCollisions =
+        computeInstanceInitializerCollisions(clazz, localReservationState, globalReservationState);
+
+    // Resolve each collision to ensure that the mapping is one-to-one.
+    resolveInstanceInitializerCollisions(instanceInitializerCollisions);
+
+    // Inverse the one-to-one map to produce a mapping from old method signatures to new method
+    // signatures.
+    return MapUtils.transform(
+        instanceInitializerCollisions,
+        HashMap::new,
+        (newMethodSignature, methodSignatures) -> Iterables.getFirst(methodSignatures, null),
+        (newMethodSignature, methodSignatures) -> newMethodSignature,
+        (newMethodSignature, methodSignature, otherMethodSignature) -> {
+          throw new Unreachable();
+        });
+  }
+
+  private Map<DexMethodSignature, DexMethodSignatureSet> computeInstanceInitializerCollisions(
+      DexProgramClass clazz,
+      LocalReservationState localReservationState,
+      GlobalReservationState globalReservationState) {
+    Map<DexMethodSignature, DexMethodSignatureSet> instanceInitializerCollisions = new HashMap<>();
+    clazz.forEachProgramInstanceInitializer(
+        method -> {
+          DexMethodSignature methodSignature = method.getMethodSignature();
+          DexMethodSignature newMethodSignature =
+              localReservationState.getNewMethodSignature(
+                  methodSignature, dexItemFactory, globalReservationState);
+          instanceInitializerCollisions
+              .computeIfAbsent(newMethodSignature, ignoreKey(DexMethodSignatureSet::create))
+              .add(methodSignature);
+        });
+    return instanceInitializerCollisions;
+  }
+
+  private void resolveInstanceInitializerCollisions(
+      Map<DexMethodSignature, DexMethodSignatureSet> instanceInitializerCollisions) {
+    WorkList<DexMethodSignature> worklist = WorkList.newEqualityWorkList();
+    instanceInitializerCollisions.forEach(
+        (newMethodSignature, methodSignatures) -> {
+          if (methodSignatures.size() > 1) {
+            worklist.addIfNotSeen(newMethodSignature);
+          }
+        });
+
+    while (worklist.hasNext()) {
+      DexMethodSignature newMethodSignature = worklist.removeSeen();
+      DexMethodSignatureSet methodSignatures =
+          instanceInitializerCollisions.get(newMethodSignature);
+      assert methodSignatures.size() > 1;
+
+      // Resolve this conflict in a deterministic way.
+      DexMethodSignature survivor =
+          methodSignatures.contains(newMethodSignature)
+              ? newMethodSignature
+              : IterableUtils.min(methodSignatures, DexMethodSignature::compareTo);
+
+      // Disallow optimizations of all other methods than the `survivor`.
+      for (DexMethodSignature methodSignature : methodSignatures) {
+        if (!methodSignature.equals(survivor)) {
+          DexMethodSignatureSet originalMethodSignaturesForMethodSignature =
+              instanceInitializerCollisions.computeIfAbsent(
+                  methodSignature, ignoreKey(DexMethodSignatureSet::create));
+          originalMethodSignaturesForMethodSignature.add(methodSignature);
+          if (originalMethodSignaturesForMethodSignature.size() > 1) {
+            worklist.addIfNotSeen(methodSignature);
+          }
+        }
+      }
+
+      // Remove all pinned methods from the set of original method signatures stored at
+      // instanceInitializerCollisions.get(newMethodSignature).
+      methodSignatures.clear();
+      methodSignatures.add(survivor);
+    }
+  }
+
   private boolean isUnoptimizable(ProgramMethod method) {
     // TODO(b/195112263): This is incomplete.
     return appView.getKeepInfo(method).isPinned(options)
@@ -226,11 +320,27 @@
     MutableBidirectionalOneToOneMap<DexMethodSignature, DexMethodSignature> newMethodSignatures =
         new BidirectionalOneToOneHashMap<>();
 
-    // TODO: avoid sorting multiple times.
     DexMethodSignature getNewMethodSignature(
         DexMethodSignature methodSignature,
         DexItemFactory dexItemFactory,
         GlobalReservationState globalReservationState) {
+      return internalGetAndReserveNewMethodSignature(
+          methodSignature, dexItemFactory, globalReservationState, false);
+    }
+
+    DexMethodSignature getAndReserveNewMethodSignature(
+        DexMethodSignature methodSignature,
+        DexItemFactory dexItemFactory,
+        GlobalReservationState globalReservationState) {
+      return internalGetAndReserveNewMethodSignature(
+          methodSignature, dexItemFactory, globalReservationState, true);
+    }
+
+    private DexMethodSignature internalGetAndReserveNewMethodSignature(
+        DexMethodSignature methodSignature,
+        DexItemFactory dexItemFactory,
+        GlobalReservationState globalReservationState,
+        boolean reserve) {
       if (globalReservationState.isUnoptimizable(methodSignature)) {
         assert !newMethodSignatures.containsKey(methodSignature);
         return methodSignature;
@@ -254,7 +364,9 @@
           newMethodSignature = newMethodSignature.withName(newMethodName);
         } while (newMethodSignatures.containsValue(newMethodSignature));
       }
-      newMethodSignatures.put(methodSignature, newMethodSignature);
+      if (reserve) {
+        newMethodSignatures.put(methodSignature, newMethodSignature);
+      }
       return newMethodSignature;
     }
   }
diff --git a/src/main/java/com/android/tools/r8/utils/MapUtils.java b/src/main/java/com/android/tools/r8/utils/MapUtils.java
index a29af0b..2f55680 100644
--- a/src/main/java/com/android/tools/r8/utils/MapUtils.java
+++ b/src/main/java/com/android/tools/r8/utils/MapUtils.java
@@ -9,6 +9,7 @@
 import it.unimi.dsi.fastutil.ints.Int2ReferenceMaps;
 import java.util.IdentityHashMap;
 import java.util.Map;
+import java.util.function.BiFunction;
 import java.util.function.BiPredicate;
 import java.util.function.Function;
 import java.util.function.IntFunction;
@@ -63,14 +64,28 @@
       Function<K1, K2> keyMapping,
       Function<V1, V2> valueMapping,
       TriFunction<K2, V2, V2, V2> valueMerger) {
+    return transform(
+        map,
+        factory,
+        (key, value) -> keyMapping.apply(key),
+        (key, value) -> valueMapping.apply(value),
+        valueMerger);
+  }
+
+  public static <K1, V1, K2, V2> Map<K2, V2> transform(
+      Map<K1, V1> map,
+      IntFunction<Map<K2, V2>> factory,
+      BiFunction<K1, V1, K2> keyMapping,
+      BiFunction<K1, V1, V2> valueMapping,
+      TriFunction<K2, V2, V2, V2> valueMerger) {
     Map<K2, V2> result = factory.apply(map.size());
     map.forEach(
         (key, value) -> {
-          K2 newKey = keyMapping.apply(key);
+          K2 newKey = keyMapping.apply(key, value);
           if (newKey == null) {
             return;
           }
-          V2 newValue = valueMapping.apply(value);
+          V2 newValue = valueMapping.apply(key, value);
           V2 existingValue = result.put(newKey, newValue);
           if (existingValue != null) {
             result.put(newKey, valueMerger.apply(newKey, existingValue, newValue));
diff --git a/src/test/java/com/android/tools/r8/graph/genericsignature/GenericSignaturePartialTypeArgumentApplierTest.java b/src/test/java/com/android/tools/r8/graph/genericsignature/GenericSignaturePartialTypeArgumentApplierTest.java
index da89df6..4b04321 100644
--- a/src/test/java/com/android/tools/r8/graph/genericsignature/GenericSignaturePartialTypeArgumentApplierTest.java
+++ b/src/test/java/com/android/tools/r8/graph/genericsignature/GenericSignaturePartialTypeArgumentApplierTest.java
@@ -33,6 +33,7 @@
 import java.util.Map;
 import java.util.Set;
 import java.util.function.BiPredicate;
+import java.util.function.Function;
 import java.util.function.Predicate;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -121,7 +122,7 @@
                     MapUtils.transform(
                         substitutions,
                         HashMap::new,
-                        s -> s,
+                        Function.identity(),
                         ClassTypeSignature::new,
                         (key, val1, val2) -> {
                           throw new Unreachable("No keys should be merged");
diff --git a/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationWithInstanceInitializerCollisionTest.java b/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationWithInstanceInitializerCollisionTest.java
new file mode 100644
index 0000000..66a4cdc
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationWithInstanceInitializerCollisionTest.java
@@ -0,0 +1,107 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.optimize.proto;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import com.android.tools.r8.NoHorizontalClassMerging;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import com.android.tools.r8.utils.codeinspector.TypeSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class ProtoNormalizationWithInstanceInitializerCollisionTest extends TestBase {
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .addOptionsModification(
+            options -> options.testing.enableExperimentalProtoNormalization = true)
+        .enableNoHorizontalClassMergingAnnotations()
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .inspect(
+            inspector -> {
+              ClassSubject mainClassSubject = inspector.clazz(Main.class);
+              assertThat(mainClassSubject, isPresent());
+
+              ClassSubject aClassSubject = inspector.clazz(A.class);
+              assertThat(aClassSubject, isPresent());
+
+              ClassSubject bClassSubject = inspector.clazz(B.class);
+              assertThat(bClassSubject, isPresent());
+
+              TypeSubject aTypeSubject = aClassSubject.asTypeSubject();
+              TypeSubject bTypeSubject = bClassSubject.asTypeSubject();
+
+              // Main.<init>(A, B) is unchanged.
+              MethodSubject initMethodSubject =
+                  mainClassSubject.initFromTypes(aTypeSubject, bTypeSubject);
+              assertThat(initMethodSubject, isPresent());
+
+              // Main.<init>(B, A) is unchanged.
+              MethodSubject otherInitMethodSubject =
+                  mainClassSubject.initFromTypes(bTypeSubject, aTypeSubject);
+              assertThat(otherInitMethodSubject, isPresent());
+            })
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("A", "B", "A", "B");
+  }
+
+  static class Main {
+
+    public static void main(String[] args) {
+      new Main(new A(), new B());
+      new Main(new B(), new A());
+    }
+
+    Main(A a, B b) {
+      System.out.println(a);
+      System.out.println(b);
+    }
+
+    Main(B b, A a) {
+      System.out.println(a);
+      System.out.println(b);
+    }
+  }
+
+  @NoHorizontalClassMerging
+  static class A {
+
+    @Override
+    public String toString() {
+      return "A";
+    }
+  }
+
+  @NoHorizontalClassMerging
+  static class B {
+
+    @Override
+    public String toString() {
+      return "B";
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/ClassSubject.java b/src/test/java/com/android/tools/r8/utils/codeinspector/ClassSubject.java
index 30b3fb0..d604156 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/ClassSubject.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/ClassSubject.java
@@ -26,6 +26,7 @@
 import java.util.List;
 import java.util.function.Consumer;
 import java.util.function.Predicate;
+import java.util.stream.Collectors;
 import kotlinx.metadata.jvm.KotlinClassMetadata;
 import org.junit.rules.TemporaryFolder;
 
@@ -128,6 +129,14 @@
     return init(Arrays.asList(parameters));
   }
 
+  public MethodSubject initFromTypes(List<TypeSubject> parameters) {
+    return init(parameters.stream().map(TypeSubject::getTypeName).collect(Collectors.toList()));
+  }
+
+  public MethodSubject initFromTypes(TypeSubject... parameters) {
+    return initFromTypes(Arrays.asList(parameters));
+  }
+
   public MethodSubject method(MethodSignature signature) {
     return method(signature.type, signature.name, ImmutableList.copyOf(signature.parameters));
   }