Build up a hierarchy of local reservation states in ProtoNormalizer

This will ensure we keep track of all existing methods in parents and therefore do not accidentally use the same descriptor.

Bug: b/258720808
Change-Id: I0642ab7a0e011a0789479680d664fea5fa609b3b
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/string/StringBuilderAppendOptimizer.java b/src/main/java/com/android/tools/r8/ir/optimize/string/StringBuilderAppendOptimizer.java
index 407717c..5f0bea9 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/string/StringBuilderAppendOptimizer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/string/StringBuilderAppendOptimizer.java
@@ -608,7 +608,7 @@
       }
 
       @Override
-      protected List<Void> getFinalStateForRoots(Collection<StringBuilderNode> roots) {
+      protected List<Void> getFinalStateForRoots(Collection<? extends StringBuilderNode> roots) {
         return null;
       }
 
diff --git a/src/main/java/com/android/tools/r8/optimize/proto/ProtoNormalizer.java b/src/main/java/com/android/tools/r8/optimize/proto/ProtoNormalizer.java
index a3f2d38..3c032b3 100644
--- a/src/main/java/com/android/tools/r8/optimize/proto/ProtoNormalizer.java
+++ b/src/main/java/com/android/tools/r8/optimize/proto/ProtoNormalizer.java
@@ -8,6 +8,7 @@
 
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexClass;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexMethodSignature;
@@ -15,31 +16,35 @@
 import com.android.tools.r8.graph.DexString;
 import com.android.tools.r8.graph.DexTypeList;
 import com.android.tools.r8.graph.GenericSignature.MethodTypeSignature;
-import com.android.tools.r8.graph.ObjectAllocationInfoCollection;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.graph.proto.RewrittenPrototypeDescription;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.shaking.KeepMethodInfo;
+import com.android.tools.r8.utils.DepthFirstSearchWorkListBase.StatefulDepthFirstSearchWorkList;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.IterableUtils;
 import com.android.tools.r8.utils.MapUtils;
 import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.Timing;
+import com.android.tools.r8.utils.TraversalContinuation;
 import com.android.tools.r8.utils.WorkList;
 import com.android.tools.r8.utils.collections.BidirectionalOneToOneHashMap;
 import com.android.tools.r8.utils.collections.DexMethodSignatureSet;
 import com.android.tools.r8.utils.collections.MutableBidirectionalOneToOneMap;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Sets;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
+import java.util.function.Function;
 
 public class ProtoNormalizer {
 
@@ -60,47 +65,93 @@
   private void run(ExecutorService executorService) throws ExecutionException {
     GlobalReservationState globalReservationState = computeGlobalReservationState(executorService);
 
-    // TODO(b/173398086): This uses a single LocalReservationState for the entire program. This
-    //  should process the strongly connected program components in parallel, each with their own
-    //  LocalReservationState.
-    LocalReservationState localReservationState = new LocalReservationState();
+    // To ensure we do not add collisions of method signatures when creating the new method
+    // signatures we keep a map of methods in the type hierarchy, similar to what we do for
+    // minification.
     ProtoNormalizerGraphLens.Builder lensBuilder = ProtoNormalizerGraphLens.builder(appView);
-    for (DexProgramClass clazz : appView.appInfo().classesWithDeterministicOrder()) {
-      Map<DexMethodSignature, DexMethodSignature> newInstanceInitializerSignatures =
-          computeNewInstanceInitializerSignatures(
-              clazz, localReservationState, globalReservationState);
-      clazz
-          .getMethodCollection()
-          .replaceMethods(
-              method -> {
-                DexMethodSignature methodSignature = method.getSignature();
-                DexMethodSignature newMethodSignature =
-                    method.isInstanceInitializer()
-                        ? newInstanceInitializerSignatures.get(methodSignature)
-                        : localReservationState.getAndReserveNewMethodSignature(
-                            methodSignature, dexItemFactory, globalReservationState);
-                if (methodSignature.equals(newMethodSignature)) {
-                  return method;
-                }
-                DexMethod newMethodReference = newMethodSignature.withHolder(clazz, dexItemFactory);
-                RewrittenPrototypeDescription prototypeChanges =
-                    lensBuilder.recordNewMethodSignature(method, newMethodReference);
-                // TODO(b/195112263): Assert that the method does not have any optimization info.
-                //  If/when enabling proto normalization after the final round of tree shaking, this
-                //  should simply clear the optimization info, or replace it by a
-                //  ThrowingMethodOptimizationInfo since we should never use the optimization info
-                //  after this point.
-                return method.toTypeSubstitutedMethod(
-                    newMethodReference,
-                    builder -> {
-                      if (!prototypeChanges.isEmpty()) {
-                        builder
-                            .apply(prototypeChanges.createParameterAnnotationsRemover(method))
-                            .setGenericSignature(MethodTypeSignature.noSignature());
-                      }
-                    });
-              });
-    }
+    new StatefulDepthFirstSearchWorkList<DexClass, LocalReservationState, Void>() {
+
+      @Override
+      @SuppressWarnings("ReturnValueIgnored")
+      protected TraversalContinuation<Void, LocalReservationState> process(
+          DFSNodeWithState<DexClass, LocalReservationState> node,
+          Function<DexClass, DFSNodeWithState<DexClass, LocalReservationState>> childNodeConsumer) {
+        DexClass clazz = node.getNode();
+        node.setState(new LocalReservationState());
+        if (clazz.getSuperType() != null) {
+          appView
+              .contextIndependentDefinitionForWithResolutionResult(clazz.getSuperType())
+              .forEachClassResolutionResult(childNodeConsumer::apply);
+        }
+        return TraversalContinuation.doContinue();
+      }
+
+      @Override
+      protected TraversalContinuation<Void, LocalReservationState> joiner(
+          DFSNodeWithState<DexClass, LocalReservationState> node,
+          List<DFSNodeWithState<DexClass, LocalReservationState>> childStates) {
+        DexClass clazz = node.getNode();
+        assert childStates.size() >= 1
+            || clazz.getType() == dexItemFactory.objectType
+            || clazz.hasMissingSuperType(appView.appInfo());
+        // We can have multiple child states if there are multiple definitions of a class.
+        assert childStates.size() <= 1 || options.loadAllClassDefinitions;
+        LocalReservationState localReservationState = node.getState();
+        if (!childStates.isEmpty()) {
+          localReservationState.linkParent(childStates.get(0).getState());
+        }
+        Map<DexMethodSignature, DexMethodSignature> newInstanceInitializerSignatures =
+            clazz.isProgramClass()
+                ? computeNewInstanceInitializerSignatures(
+                    clazz.asProgramClass(), localReservationState, globalReservationState)
+                : null;
+        clazz
+            .getMethodCollection()
+            .replaceMethods(
+                method -> {
+                  DexMethodSignature methodSignature = method.getSignature();
+                  if (!clazz.isProgramClass()) {
+                    // We cannot change the signature of the method. Record it to ensure we do not
+                    // change a program signature in a sub class to override this.
+                    localReservationState.recordNoSignatureChange(methodSignature, dexItemFactory);
+                    return method;
+                  } else {
+                    assert newInstanceInitializerSignatures != null;
+                    DexMethodSignature newMethodSignature =
+                        method.isInstanceInitializer()
+                            ? newInstanceInitializerSignatures.get(methodSignature)
+                            : localReservationState.getAndReserveNewMethodSignature(
+                                methodSignature, dexItemFactory, globalReservationState);
+                    if (methodSignature.equals(newMethodSignature)) {
+                      // This method could not be optimized, record it as identity mapping to ensure
+                      // we keep it going forward.
+                      localReservationState.recordNoSignatureChange(
+                          methodSignature, dexItemFactory);
+                      return method;
+                    }
+                    DexMethod newMethodReference =
+                        newMethodSignature.withHolder(clazz.asProgramClass(), dexItemFactory);
+                    RewrittenPrototypeDescription prototypeChanges =
+                        lensBuilder.recordNewMethodSignature(method, newMethodReference);
+                    // TODO(b/195112263): Assert that the method does not have optimization info.
+                    // If/when enabling proto normalization after the final round of tree shaking,
+                    // this should simply clear the optimization info, or replace it by a
+                    // ThrowingMethodOptimizationInfo since we should never use the optimization
+                    // info after this point.
+                    return method.toTypeSubstitutedMethod(
+                        newMethodReference,
+                        builder -> {
+                          if (!prototypeChanges.isEmpty()) {
+                            builder
+                                .apply(prototypeChanges.createParameterAnnotationsRemover(method))
+                                .setGenericSignature(MethodTypeSignature.noSignature());
+                          }
+                        });
+                  }
+                });
+        return TraversalContinuation.doContinue(localReservationState);
+      }
+    }.run(appView.appInfo().classesWithDeterministicOrder());
 
     if (!lensBuilder.isEmpty()) {
       appView.rewriteWithLens(lensBuilder.build());
@@ -331,15 +382,13 @@
     if (appInfo.isBootstrapMethod(method)) {
       return true;
     }
-    ObjectAllocationInfoCollection objectAllocationInfoCollection =
-        appInfo.getObjectAllocationInfoCollection();
-    if (method.getHolder().isInterface()
+    // As long as we do not consider interface method and overrides as optimizable we can change
+    // method signatures in a top-down traversal.
+    return method.getHolder().isInterface()
         && method.getDefinition().isAbstract()
-        && objectAllocationInfoCollection.isImmediateInterfaceOfInstantiatedLambda(
-            method.getHolder())) {
-      return true;
-    }
-    return false;
+        && appInfo
+            .getObjectAllocationInfoCollection()
+            .isImmediateInterfaceOfInstantiatedLambda(method.getHolder());
   }
 
   static class GlobalReservationState {
@@ -389,8 +438,10 @@
 
   static class LocalReservationState {
 
-    MutableBidirectionalOneToOneMap<DexMethodSignature, DexMethodSignature> newMethodSignatures =
-        new BidirectionalOneToOneHashMap<>();
+    private final List<LocalReservationState> parents = new ArrayList<>(2);
+
+    private final MutableBidirectionalOneToOneMap<DexMethodSignature, DexMethodSignature>
+        newMethodSignatures = new BidirectionalOneToOneHashMap<>();
 
     DexMethodSignature getNewMethodSignature(
         DexMethodSignature methodSignature,
@@ -414,33 +465,68 @@
         GlobalReservationState globalReservationState,
         boolean reserve) {
       if (globalReservationState.isUnoptimizable(methodSignature)) {
-        assert !newMethodSignatures.containsKey(methodSignature);
+        assert getReserved(methodSignature) == null
+            || methodSignature.equals(getReserved(methodSignature));
         return methodSignature;
       }
-      DexMethodSignature reservedSignature = newMethodSignatures.get(methodSignature);
+      DexMethodSignature reservedSignature = getReserved(methodSignature);
       if (reservedSignature != null) {
-        assert reservedSignature
-            .getParameters()
-            .equals(globalReservationState.getReservedParameters(methodSignature));
         return reservedSignature;
       }
       DexTypeList reservedParameters =
           globalReservationState.getReservedParameters(methodSignature);
       DexMethodSignature newMethodSignature =
           methodSignature.withParameters(reservedParameters, dexItemFactory);
-      if (newMethodSignatures.containsValue(newMethodSignature)) {
+      if (isDestinationTaken(newMethodSignature)) {
         int index = 1;
         String newMethodBaseName = methodSignature.getName().toString();
         do {
           DexString newMethodName = dexItemFactory.createString(newMethodBaseName + "$" + index);
           newMethodSignature = newMethodSignature.withName(newMethodName);
           index++;
-        } while (newMethodSignatures.containsValue(newMethodSignature));
+        } while (isDestinationTaken(newMethodSignature));
       }
       if (reserve) {
         newMethodSignatures.put(methodSignature, newMethodSignature);
       }
       return newMethodSignature;
     }
+
+    private void linkParent(LocalReservationState parent) {
+      this.parents.add(parent);
+    }
+
+    private DexMethodSignature getReserved(DexMethodSignature signature) {
+      WorkList<LocalReservationState> workList = WorkList.newIdentityWorkList(this);
+      while (workList.hasNext()) {
+        LocalReservationState localReservationState = workList.next();
+        DexMethodSignature reservedSignature =
+            localReservationState.newMethodSignatures.get(signature);
+        if (reservedSignature != null) {
+          return reservedSignature;
+        }
+        workList.addIfNotSeen(localReservationState.parents);
+      }
+      return null;
+    }
+
+    private boolean isDestinationTaken(DexMethodSignature signature) {
+      WorkList<LocalReservationState> workList = WorkList.newIdentityWorkList(this);
+      while (workList.hasNext()) {
+        LocalReservationState localReservationState = workList.next();
+        if (localReservationState.newMethodSignatures.containsValue(signature)) {
+          return true;
+        }
+        workList.addIfNotSeen(localReservationState.parents);
+      }
+      return false;
+    }
+
+    public void recordNoSignatureChange(
+        DexMethodSignature methodSignature, DexItemFactory factory) {
+      if (!methodSignature.getName().equals(factory.constructorMethodName)) {
+        newMethodSignatures.put(methodSignature, methodSignature);
+      }
+    }
   }
 }
diff --git a/src/main/java/com/android/tools/r8/utils/DepthFirstSearchWorkListBase.java b/src/main/java/com/android/tools/r8/utils/DepthFirstSearchWorkListBase.java
index 37a55bb..a4d1c8f 100644
--- a/src/main/java/com/android/tools/r8/utils/DepthFirstSearchWorkListBase.java
+++ b/src/main/java/com/android/tools/r8/utils/DepthFirstSearchWorkListBase.java
@@ -115,7 +115,7 @@
   /** The joining of state during backtracking of the algorithm. */
   abstract TraversalContinuation<TB, TC> internalOnJoin(T node);
 
-  protected abstract List<TC> getFinalStateForRoots(Collection<N> roots);
+  protected abstract List<TC> getFinalStateForRoots(Collection<? extends N> roots);
 
   final T internalEnqueueNode(N value) {
     T dfsNode = nodeToNodeWithStateMap.computeIfAbsent(value, this::createDfsNode);
@@ -138,7 +138,7 @@
     return run(Arrays.asList(roots));
   }
 
-  public final TraversalContinuation<TB, List<TC>> run(Collection<N> roots) {
+  public final TraversalContinuation<TB, List<TC>> run(Collection<? extends N> roots) {
     roots.forEach(this::internalEnqueueNode);
     while (!workList.isEmpty()) {
       T node = workList.removeLast();
@@ -199,7 +199,7 @@
     }
 
     @Override
-    protected List<TC> getFinalStateForRoots(Collection<N> roots) {
+    protected List<TC> getFinalStateForRoots(Collection<? extends N> roots) {
       return null;
     }
   }
@@ -264,7 +264,7 @@
     }
 
     @Override
-    public List<S> getFinalStateForRoots(Collection<N> roots) {
+    public List<S> getFinalStateForRoots(Collection<? extends N> roots) {
       return ListUtils.map(roots, root -> getNodeStateForNode(root).state);
     }
   }
diff --git a/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationDestinationOverrideLibraryTest.java b/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationDestinationOverrideLibraryTest.java
index 3b7a115..f58336a 100644
--- a/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationDestinationOverrideLibraryTest.java
+++ b/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationDestinationOverrideLibraryTest.java
@@ -4,10 +4,6 @@
 
 package com.android.tools.r8.optimize.proto;
 
-import static org.hamcrest.CoreMatchers.containsString;
-import static org.junit.Assert.assertThrows;
-
-import com.android.tools.r8.CompilationFailedException;
 import com.android.tools.r8.NeverInline;
 import com.android.tools.r8.NoMethodStaticizing;
 import com.android.tools.r8.TestBase;
@@ -44,26 +40,19 @@
 
   @Test
   public void testR8() throws Exception {
-    // TODO(b/258720808): We should not fail compilation.
-    assertThrows(
-        CompilationFailedException.class,
-        () ->
-            testForR8(parameters.getBackend())
-                .addProgramClasses(Main.class, ProgramClass.class, X.class)
-                .addDefaultRuntimeLibrary(parameters)
-                .addLibraryClasses(LibraryClass.class)
-                .setMinApi(parameters.getApiLevel())
-                .addKeepMainRule(Main.class)
-                .addDontObfuscate()
-                .enableInliningAnnotations()
-                .enableNoMethodStaticizingAnnotations()
-                .compileWithExpectedDiagnostics(
-                    diagnostics -> {
-                      diagnostics.assertErrorMessageThatMatches(
-                          containsString(
-                              "went from not overriding a library method to overriding a library"
-                                  + " method"));
-                    }));
+    testForR8(parameters.getBackend())
+        .addProgramClasses(Main.class, ProgramClass.class, X.class)
+        .addDefaultRuntimeLibrary(parameters)
+        .addLibraryClasses(LibraryClass.class)
+        .setMinApi(parameters.getApiLevel())
+        .addKeepMainRule(Main.class)
+        .addDontObfuscate()
+        .enableInliningAnnotations()
+        .enableNoMethodStaticizingAnnotations()
+        .compile()
+        .addBootClasspathClasses(LibraryClass.class)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines(EXPECTED);
   }
 
   public static class LibraryClass {
diff --git a/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationDuplicateMethodTest.java b/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationDuplicateMethodTest.java
index 899ca4d..c935224 100644
--- a/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationDuplicateMethodTest.java
+++ b/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationDuplicateMethodTest.java
@@ -4,10 +4,6 @@
 
 package com.android.tools.r8.optimize.proto;
 
-import static org.hamcrest.CoreMatchers.containsString;
-import static org.junit.Assert.assertThrows;
-
-import com.android.tools.r8.CompilationFailedException;
 import com.android.tools.r8.NeverInline;
 import com.android.tools.r8.NoHorizontalClassMerging;
 import com.android.tools.r8.TestBase;
@@ -42,24 +38,17 @@
 
   @Test
   public void testR8() throws Exception {
-    // TODO(b/258720808): We should not cause collision with an existing method.
-    assertThrows(
-        CompilationFailedException.class,
-        () -> {
-          testForR8(parameters.getBackend())
-              .addInnerClasses(getClass())
-              .setMinApi(parameters.getApiLevel())
-              .addKeepMainRule(Main.class)
-              .addKeepMethodRules(
-                  Reference.methodFromMethod(
-                      B.class.getDeclaredMethod("foo$1", int.class, int.class, String.class)))
-              .enableInliningAnnotations()
-              .enableNoHorizontalClassMergingAnnotations()
-              .compileWithExpectedDiagnostics(
-                  diagnostics ->
-                      diagnostics.assertErrorMessageThatMatches(
-                          containsString("Duplicate method")));
-        });
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .setMinApi(parameters.getApiLevel())
+        .addKeepMainRule(Main.class)
+        .addKeepMethodRules(
+            Reference.methodFromMethod(
+                B.class.getDeclaredMethod("foo$1", int.class, int.class, String.class)))
+        .enableInliningAnnotations()
+        .enableNoHorizontalClassMergingAnnotations()
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines(EXPECTED);
   }
 
   @NoHorizontalClassMerging
diff --git a/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationIntroduceCollisionTest.java b/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationIntroduceCollisionTest.java
index f4b8ccd..947e3e7 100644
--- a/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationIntroduceCollisionTest.java
+++ b/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationIntroduceCollisionTest.java
@@ -19,7 +19,6 @@
 public class ProtoNormalizationIntroduceCollisionTest extends TestBase {
 
   private final String[] EXPECTED = new String[] {"Base::foo-42Calling B::foo1337"};
-  private final String[] R8_EXPECTED = new String[] {"Sub::foo-Calling B::foo421337"};
 
   @Parameter() public TestParameters parameters;
 
@@ -46,8 +45,7 @@
         .enableInliningAnnotations()
         .enableNoVerticalClassMergingAnnotations()
         .run(parameters.getRuntime(), Main.class)
-        // TODO(b/258720808): We should produce the expected result.
-        .assertSuccessWithOutputLines(R8_EXPECTED);
+        .assertSuccessWithOutputLines(EXPECTED);
   }
 
   public static class Base {
diff --git a/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationWithVirtualMethodCollisionTest.java b/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationWithVirtualMethodCollisionTest.java
index c179493..c3997c9 100644
--- a/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationWithVirtualMethodCollisionTest.java
+++ b/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationWithVirtualMethodCollisionTest.java
@@ -36,9 +36,6 @@
   private final String[] EXPECTED =
       new String[] {"A::foo", "B", "A", "B::foo", "B", "A", "B::foo", "B", "A"};
 
-  private final String[] R8_EXPECTED =
-      new String[] {"A::foo", "B", "A", "A::foo", "B", "A", "B::foo", "B", "A"};
-
   @Test
   public void testRuntime() throws Exception {
     testForRuntime(parameters)
@@ -59,8 +56,7 @@
         .setMinApi(parameters.getApiLevel())
         .compile()
         .run(parameters.getRuntime(), Main.class)
-        // TODO(b/258720808): We should not produce incorrect results.
-        .assertSuccessWithOutputLines(R8_EXPECTED)
+        .assertSuccessWithOutputLines(EXPECTED)
         .inspect(
             inspector -> {
               ClassSubject bClassSubject = inspector.clazz(B.class);
@@ -72,7 +68,7 @@
               TypeSubject bTypeSubject = bClassSubject.asTypeSubject();
               TypeSubject aTypeSubject = aClassSubject.asTypeSubject();
 
-              MethodSubject fooMethodSubject = aClassSubject.uniqueMethodWithOriginalName("foo");
+              MethodSubject fooMethodSubject = aClassSubject.uniqueMethodWithFinalName("foo$1");
               assertThat(fooMethodSubject, isPresent());
               assertThat(fooMethodSubject, hasParameters(aTypeSubject, bTypeSubject));