Account for collisions in lir rewriting

Bug: b/354878031
Change-Id: I0f141936f2835497fa6553a0e260f4d7fb04d20b
diff --git a/src/main/java/com/android/tools/r8/lightir/LirLensCodeRewriter.java b/src/main/java/com/android/tools/r8/lightir/LirLensCodeRewriter.java
index b5ab4d1..bd3ff3e 100644
--- a/src/main/java/com/android/tools/r8/lightir/LirLensCodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/lightir/LirLensCodeRewriter.java
@@ -36,17 +36,23 @@
 import com.android.tools.r8.lightir.LirCode.TryCatchTable;
 import com.android.tools.r8.naming.dexitembasedstring.NameComputationInfo;
 import com.android.tools.r8.utils.ArrayUtils;
+import com.android.tools.r8.utils.BooleanUtils;
 import com.android.tools.r8.utils.Timing;
 import com.android.tools.r8.verticalclassmerging.VerticalClassMergerGraphLens;
+import com.google.common.collect.ImmutableSet;
 import it.unimi.dsi.fastutil.objects.Reference2IntMap;
 import it.unimi.dsi.fastutil.objects.Reference2IntOpenHashMap;
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.IdentityHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 public class LirLensCodeRewriter<EV> extends LirParsedInstructionCallback<EV> {
 
+  private static final Set<DexMethod> NO_INVOKES_TO_REWRITE = ImmutableSet.of();
+
   private final AppView<? extends AppInfoWithClassHierarchy> appView;
   private final ProgramMethod context;
   private final DexMethod contextReference;
@@ -55,6 +61,7 @@
   private final LensCodeRewriterUtils helper;
 
   private int numberOfInvokeOpcodeChanges = 0;
+  private Set<DexMethod> invokesToRewrite = NO_INVOKES_TO_REWRITE;
   private Map<LirConstant, LirConstant> constantPoolMapping = null;
 
   private boolean hasNonTrivialRewritings = false;
@@ -131,7 +138,7 @@
       numberOfInvokeOpcodeChanges++;
     } else {
       // All non-type dependent mappings are just rewritten in the content pool.
-      addRewrittenMapping(method, newMethod);
+      addRewrittenMethodMapping(method, newMethod);
     }
   }
 
@@ -161,17 +168,31 @@
     return false;
   }
 
+  private void addRewrittenMethodMapping(DexMethod method, DexMethod rewrittenMethod) {
+    getOrCreateConstantPoolMapping()
+        .compute(
+            method,
+            (unusedKey, otherRewrittenMethod) -> {
+              if (otherRewrittenMethod == null || otherRewrittenMethod == rewrittenMethod) {
+                return rewrittenMethod;
+              } else {
+                // Two invokes with the same symbolic method reference but different invoke types
+                // are rewritten to two different symbolic method references. Record that the
+                // invokes need to be processed.
+                if (invokesToRewrite == NO_INVOKES_TO_REWRITE) {
+                  invokesToRewrite = new HashSet<>();
+                }
+                invokesToRewrite.add(method);
+                return method;
+              }
+            });
+  }
+
   private void addRewrittenMapping(LirConstant item, LirConstant rewrittenItem) {
     if (item == rewrittenItem) {
       return;
     }
-    if (constantPoolMapping == null) {
-      constantPoolMapping =
-          new IdentityHashMap<>(
-              // Avoid using initial capacity larger than the number of actual constants.
-              Math.min(getCode().getConstantPool().length, 32));
-    }
-    LirConstant old = constantPoolMapping.put(item, rewrittenItem);
+    LirConstant old = getOrCreateConstantPoolMapping().put(item, rewrittenItem);
     if (old != null && old != rewrittenItem) {
       throw new Unreachable(
           "Unexpected rewriting of item: "
@@ -183,6 +204,16 @@
     }
   }
 
+  private Map<LirConstant, LirConstant> getOrCreateConstantPoolMapping() {
+    if (constantPoolMapping == null) {
+      constantPoolMapping =
+          new IdentityHashMap<>(
+              // Avoid using initial capacity larger than the number of actual constants.
+              Math.min(getCode().getConstantPool().length, 32));
+    }
+    return constantPoolMapping;
+  }
+
   @Override
   public void onDexItemBasedConstString(
       DexReference item, NameComputationInfo<?> nameComputationInfo) {
@@ -271,25 +302,29 @@
     onInvoke(method, InvokeType.INTERFACE, true);
   }
 
-  private InvokeType getInvokeTypeThatMayChange(int opcode) {
+  private boolean isInvokeThatMaybeRequiresRewriting(int opcode) {
+    assert LirOpcodeUtils.isInvokeMethod(opcode);
+    if (!invokesToRewrite.isEmpty()) {
+      return true;
+    }
     if (codeLens.isIdentityLens() && LirOpcodeUtils.isInvokeMethod(opcode)) {
-      return LirOpcodeUtils.getInvokeType(opcode);
+      return true;
     }
     if (opcode == LirOpcodes.INVOKEVIRTUAL) {
-      return InvokeType.VIRTUAL;
+      return true;
     }
     if (opcode == LirOpcodes.INVOKEINTERFACE) {
-      return InvokeType.INTERFACE;
+      return true;
     }
     if (graphLens.isVerticalClassMergerLens()) {
       if (opcode == LirOpcodes.INVOKESTATIC_ITF) {
-        return InvokeType.STATIC;
+        return true;
       }
       if (opcode == LirOpcodes.INVOKESUPER) {
-        return InvokeType.SUPER;
+        return true;
       }
     }
-    return null;
+    return false;
   }
 
   public LirCode<EV> rewrite() {
@@ -447,7 +482,7 @@
   }
 
   private LirCode<EV> rewriteInstructionsWithInvokeTypeChanges(LirCode<EV> code) {
-    if (numberOfInvokeOpcodeChanges == 0) {
+    if (numberOfInvokeOpcodeChanges == 0 && invokesToRewrite.isEmpty()) {
       return code;
     }
     // Build a small map from method refs to index in case the type-dependent methods are already
@@ -471,8 +506,7 @@
         lirWriter.writeOneByteInstruction(opcode);
         continue;
       }
-      InvokeType type = getInvokeTypeThatMayChange(opcode);
-      if (type == null) {
+      if (!LirOpcodeUtils.isInvokeMethod(opcode) || !isInvokeThatMaybeRequiresRewriting(opcode)) {
         int size = view.getRemainingOperandSizeInBytes();
         lirWriter.writeInstruction(opcode, size);
         while (size-- > 0) {
@@ -480,9 +514,12 @@
         }
         continue;
       }
-      // This is potentially an invoke with a type change, in such cases the method is mapped with
+      // If this is either (i) an invoke with a type change or (ii) an invoke to a method M where
+      // there exists another invoke in the current method to M, and the two invokes are mapped to
+      // two different methods (one-to-many constant pool mapping), then the method is mapped with
       // the instruction updated to the new type. The constant pool is amended with the mapped
       // method if needed.
+      InvokeType type = LirOpcodeUtils.getInvokeType(opcode);
       int constantIndex = view.getNextConstantOperand();
       DexMethod method = (DexMethod) code.getConstantItem(constantIndex);
       MethodLookupResult result =
@@ -490,8 +527,7 @@
       boolean newIsInterface = lookupIsInterface(method, opcode, result);
       InvokeType newType = result.getType();
       int newOpcode = newType.getLirOpcode(newIsInterface);
-      if (newOpcode != opcode) {
-        --numberOfInvokeOpcodeChanges;
+      if (newOpcode != opcode || invokesToRewrite.contains(method)) {
         constantIndex =
             methodIndices.computeIfAbsent(
                 result.getReference(),
@@ -499,6 +535,7 @@
                   methodsToAppend.add(ref);
                   return rewrittenConstants.length + methodsToAppend.size() - 1;
                 });
+        numberOfInvokeOpcodeChanges -= BooleanUtils.intValue(newOpcode != opcode);
       }
       int constantIndexSize = ByteUtils.intEncodingSize(constantIndex);
       int remainingSize = view.getRemainingOperandSizeInBytes();
@@ -512,11 +549,9 @@
     // Note that since we assume 'null' in the mapping is identity this may end up with a stale
     // reference to a no longer used method. That is not an issue as it will be pruned when
     // building IR again, it is just a small and size overhead.
-    LirCode<EV> newCode =
-        code.copyWithNewConstantsAndInstructions(
-            ArrayUtils.appendElements(code.getConstantPool(), methodsToAppend),
-            byteWriter.toByteArray());
-    return newCode;
+    return code.copyWithNewConstantsAndInstructions(
+        ArrayUtils.appendElements(code.getConstantPool(), methodsToAppend),
+        byteWriter.toByteArray());
   }
 
   // TODO(b/157111832): This should be part of the graph lens lookup result.
diff --git a/src/test/java/com/android/tools/r8/lightir/LirLensRewritingWithOneToManyMappingTest.java b/src/test/java/com/android/tools/r8/lightir/LirLensRewritingWithOneToManyMappingTest.java
index 28c6a88..d4f27cc 100644
--- a/src/test/java/com/android/tools/r8/lightir/LirLensRewritingWithOneToManyMappingTest.java
+++ b/src/test/java/com/android/tools/r8/lightir/LirLensRewritingWithOneToManyMappingTest.java
@@ -13,7 +13,6 @@
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.utils.AndroidApiLevel;
-import com.android.tools.r8.utils.codeinspector.AssertUtils;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -33,28 +32,22 @@
 
   @Test
   public void test() throws Exception {
-    // TODO(b/354878031): Should succeed.
-    AssertUtils.assertFailsCompilationIf(
-        parameters.isCfRuntime() || parameters.getApiLevel().isLessThan(AndroidApiLevel.N),
-        () ->
-            testForR8(parameters.getBackend())
-                .addProgramClasses(Main.class, Baz.class, Qux.class)
-                .addKeepMainRule(Main.class)
-                .addLibraryClasses(Foo.class, Bar.class)
-                .addDefaultRuntimeLibrary(parameters)
-                .apply(setMockApiLevelForClass(Foo.class, AndroidApiLevel.B))
-                .apply(
-                    setMockApiLevelForMethod(
-                        Foo.class.getDeclaredMethod("method"), AndroidApiLevel.B))
-                .apply(setMockApiLevelForClass(Bar.class, AndroidApiLevel.B))
-                .enableInliningAnnotations()
-                .enableNeverClassInliningAnnotations()
-                .enableNoVerticalClassMergingAnnotations()
-                .setMinApi(parameters)
-                .compile()
-                .addRunClasspathClasses(Foo.class, Bar.class)
-                .run(parameters.getRuntime(), Main.class)
-                .assertSuccessWithOutputLines("Foo", "Foo"));
+    testForR8(parameters.getBackend())
+        .addProgramClasses(Main.class, Baz.class, Qux.class)
+        .addKeepMainRule(Main.class)
+        .addLibraryClasses(Foo.class, Bar.class)
+        .addDefaultRuntimeLibrary(parameters)
+        .apply(setMockApiLevelForClass(Foo.class, AndroidApiLevel.B))
+        .apply(setMockApiLevelForMethod(Foo.class.getDeclaredMethod("method"), AndroidApiLevel.B))
+        .apply(setMockApiLevelForClass(Bar.class, AndroidApiLevel.B))
+        .enableInliningAnnotations()
+        .enableNeverClassInliningAnnotations()
+        .enableNoVerticalClassMergingAnnotations()
+        .setMinApi(parameters)
+        .compile()
+        .addRunClasspathClasses(Foo.class, Bar.class)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("Foo", "Foo");
   }
 
   static class Main {