Fix incorrect rewriting of invoke-super instruction in class merger

This CL rewrites the two graph lenses produced by the vertical class merger into a single graph lens, which indirectly prevents an invoke-super rewriting issue.

Bug: 133501933
Change-Id: I0b180be1a24d698d36576d307205a1beceeda27c
diff --git a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
index cf89715..bded74d 100644
--- a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
@@ -27,7 +27,6 @@
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.DexTypeList;
 import com.android.tools.r8.graph.GraphLense;
-import com.android.tools.r8.graph.GraphLense.Builder;
 import com.android.tools.r8.graph.GraphLense.GraphLenseLookupResult;
 import com.android.tools.r8.graph.KeyedDexItem;
 import com.android.tools.r8.graph.MethodAccessFlags;
@@ -250,7 +249,7 @@
     this.appView = appView;
     this.executorService = executorService;
     this.methodPoolCollection = new MethodPoolCollection(appView);
-    this.renamedMembersLense = new VerticalClassMergerGraphLense.Builder();
+    this.renamedMembersLense = new VerticalClassMergerGraphLense.Builder(appView.dexItemFactory());
     this.timing = timing;
     this.mainDexClasses = mainDexClasses;
 
@@ -637,10 +636,15 @@
 
   public GraphLense run() {
     timing.begin("merge");
-    GraphLense mergingGraphLense = mergeClasses();
+    // Visit the program classes in a top-down order according to the class hierarchy.
+    TopDownClassHierarchyTraversal.forProgramClasses(appView)
+        .visit(mergeCandidates, this::mergeClassIfPossible);
+    if (Log.ENABLED) {
+      Log.debug(getClass(), "Merged %d classes.", mergedClasses.size());
+    }
     timing.end();
     timing.begin("fixup");
-    GraphLense result = new TreeFixer().fixupTypeReferences(mergingGraphLense);
+    GraphLense result = new TreeFixer().fixupTypeReferences();
     timing.end();
     assert result.assertDefinitionsNotModified(
         appInfo.alwaysInline.stream()
@@ -714,16 +718,6 @@
     return true;
   }
 
-  private GraphLense mergeClasses() {
-    // Visit the program classes in a top-down order according to the class hierarchy.
-    TopDownClassHierarchyTraversal.forProgramClasses(appView)
-        .visit(mergeCandidates, this::mergeClassIfPossible);
-    if (Log.ENABLED) {
-      Log.debug(getClass(), "Merged %d classes.", mergedClasses.size());
-    }
-    return renamedMembersLense.build(appView.graphLense(), mergedClasses, appView);
-  }
-
   private boolean methodResolutionMayChange(DexClass source, DexClass target) {
     for (DexEncodedMethod virtualSourceMethod : source.virtualMethods()) {
       DexEncodedMethod directTargetMethod = target.lookupDirectMethod(virtualSourceMethod.method);
@@ -891,7 +885,7 @@
     private final DexClass source;
     private final DexClass target;
     private final VerticalClassMergerGraphLense.Builder deferredRenamings =
-        new VerticalClassMergerGraphLense.Builder();
+        new VerticalClassMergerGraphLense.Builder(appView.dexItemFactory());
     private final List<SynthesizedBridgeCode> synthesizedBridges = new ArrayList<>();
 
     private boolean abortMerge = false;
@@ -1442,10 +1436,12 @@
 
   private class TreeFixer {
 
-    private final Builder lense = GraphLense.builder();
+    private final VerticalClassMergerGraphLense.Builder lensBuilder =
+        VerticalClassMergerGraphLense.Builder.createBuilderForFixup(
+            renamedMembersLense, mergedClasses);
     private final Map<DexProto, DexProto> protoFixupCache = new IdentityHashMap<>();
 
-    private GraphLense fixupTypeReferences(GraphLense graphLense) {
+    private GraphLense fixupTypeReferences() {
       // Globally substitute merged class types in protos and holders.
       for (DexProgramClass clazz : appInfo.classes()) {
         fixupMethods(clazz.directMethods(), clazz::setDirectMethod);
@@ -1456,11 +1452,7 @@
       for (SynthesizedBridgeCode synthesizedBridge : synthesizedBridges) {
         synthesizedBridge.updateMethodSignatures(this::fixupMethod);
       }
-      // Record type renamings so check-cast and instance-of checks are also fixed.
-      for (DexType type : mergedClasses.keySet()) {
-        lense.map(type, fixupType(type));
-      }
-      return lense.build(application.dexItemFactory, graphLense);
+      return lensBuilder.build(appView, mergedClasses);
     }
 
     private void fixupMethods(List<DexEncodedMethod> methods, MethodSetter setter) {
@@ -1472,7 +1464,9 @@
         DexMethod method = encodedMethod.method;
         DexMethod newMethod = fixupMethod(method);
         if (newMethod != method) {
-          lense.move(method, newMethod);
+          if (!lensBuilder.hasOriginalSignatureMappingFor(newMethod)) {
+            lensBuilder.map(method, newMethod).recordMove(method, newMethod);
+          }
           setter.setMethod(i, encodedMethod.toTypeSubstitutedMethod(newMethod));
         }
       }
@@ -1489,7 +1483,9 @@
         DexType newHolder = fixupType(field.holder);
         DexField newField = application.dexItemFactory.createField(newHolder, newType, field.name);
         if (newField != encodedField.field) {
-          lense.move(encodedField.field, newField);
+          if (!lensBuilder.hasOriginalSignatureMappingFor(newField)) {
+            lensBuilder.map(field, newField);
+          }
           setter.setField(i, encodedField.toTypeSubstitutedField(newField));
         }
       }
diff --git a/src/main/java/com/android/tools/r8/shaking/VerticalClassMergerGraphLense.java b/src/main/java/com/android/tools/r8/shaking/VerticalClassMergerGraphLense.java
index d980242..689eea1 100644
--- a/src/main/java/com/android/tools/r8/shaking/VerticalClassMergerGraphLense.java
+++ b/src/main/java/com/android/tools/r8/shaking/VerticalClassMergerGraphLense.java
@@ -15,9 +15,7 @@
 import com.android.tools.r8.ir.code.Invoke.Type;
 import com.google.common.collect.BiMap;
 import com.google.common.collect.HashBiMap;
-import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
-import java.util.HashMap;
 import java.util.IdentityHashMap;
 import java.util.Map;
 import java.util.Set;
@@ -55,8 +53,9 @@
   private Set<DexMethod> mergedMethods;
   private final Map<DexMethod, DexMethod> originalMethodSignaturesForBridges;
 
-  public VerticalClassMergerGraphLense(
+  private VerticalClassMergerGraphLense(
       AppView<?> appView,
+      Map<DexType, DexType> typeMap,
       Map<DexField, DexField> fieldMap,
       Map<DexMethod, DexMethod> methodMap,
       Set<DexMethod> mergedMethods,
@@ -66,7 +65,7 @@
       Map<DexMethod, DexMethod> originalMethodSignaturesForBridges,
       GraphLense previousLense) {
     super(
-        ImmutableMap.of(),
+        typeMap,
         methodMap,
         fieldMap,
         originalFieldSignatures,
@@ -96,9 +95,7 @@
     DexMethod previousContext =
         originalMethodSignaturesForBridges.containsKey(context)
             ? originalMethodSignaturesForBridges.get(context)
-            : originalMethodSignatures != null
-                ? originalMethodSignatures.getOrDefault(context, context)
-                : context;
+            : originalMethodSignatures.getOrDefault(context, context);
     GraphLenseLookupResult previous = previousLense.lookupMethod(method, previousContext, type);
     if (previous.getType() == Type.SUPER && !mergedMethods.contains(context)) {
       Map<DexMethod, GraphLenseLookupResult> virtualToDirectMethodMap =
@@ -162,75 +159,122 @@
 
   public static class Builder {
 
+    private final DexItemFactory dexItemFactory;
+
     protected final BiMap<DexField, DexField> fieldMap = HashBiMap.create();
-    protected final Map<DexMethod, DexMethod> methodMap = new HashMap<>();
+    protected final Map<DexMethod, DexMethod> methodMap = new IdentityHashMap<>();
     private final ImmutableSet.Builder<DexMethod> mergedMethodsBuilder = ImmutableSet.builder();
     private final Map<DexType, Map<DexMethod, GraphLenseLookupResult>>
-        contextualVirtualToDirectMethodMaps = new HashMap<>();
+        contextualVirtualToDirectMethodMaps = new IdentityHashMap<>();
 
     private final BiMap<DexMethod, DexMethod> originalMethodSignatures = HashBiMap.create();
     private final Map<DexMethod, DexMethod> originalMethodSignaturesForBridges =
         new IdentityHashMap<>();
 
-    public GraphLense build(
-        GraphLense previousLense, Map<DexType, DexType> mergedClasses, AppView<?> appView) {
-      if (fieldMap.isEmpty()
-          && methodMap.isEmpty()
-          && contextualVirtualToDirectMethodMaps.isEmpty()) {
-        return previousLense;
+    private final Map<DexProto, DexProto> cache = new IdentityHashMap<>();
+
+    Builder(DexItemFactory dexItemFactory) {
+      this.dexItemFactory = dexItemFactory;
+    }
+
+    static Builder createBuilderForFixup(Builder builder, Map<DexType, DexType> mergedClasses) {
+      Builder newBuilder = new Builder(builder.dexItemFactory);
+      for (Map.Entry<DexField, DexField> entry : builder.fieldMap.entrySet()) {
+        newBuilder.map(
+            entry.getKey(),
+            builder.getFieldSignatureAfterClassMerging(entry.getValue(), mergedClasses));
       }
-      Map<DexProto, DexProto> cache = new HashMap<>();
+      for (Map.Entry<DexMethod, DexMethod> entry : builder.methodMap.entrySet()) {
+        newBuilder.map(
+            entry.getKey(),
+            builder.getMethodSignatureAfterClassMerging(entry.getValue(), mergedClasses));
+      }
+      for (DexMethod method : builder.mergedMethodsBuilder.build()) {
+        newBuilder.markMethodAsMerged(
+            builder.getMethodSignatureAfterClassMerging(method, mergedClasses));
+      }
+      for (Map.Entry<DexType, Map<DexMethod, GraphLenseLookupResult>> entry :
+          builder.contextualVirtualToDirectMethodMaps.entrySet()) {
+        DexType context = entry.getKey();
+        assert context == builder.getTypeAfterClassMerging(context, mergedClasses);
+        for (Map.Entry<DexMethod, GraphLenseLookupResult> innerEntry :
+            entry.getValue().entrySet()) {
+          DexMethod from = innerEntry.getKey();
+          GraphLenseLookupResult rewriting = innerEntry.getValue();
+          DexMethod to =
+              builder.getMethodSignatureAfterClassMerging(rewriting.getMethod(), mergedClasses);
+          newBuilder.mapVirtualMethodToDirectInType(
+              from, new GraphLenseLookupResult(to, rewriting.getType()), context);
+        }
+      }
+      for (Map.Entry<DexMethod, DexMethod> entry : builder.originalMethodSignatures.entrySet()) {
+        newBuilder.recordMove(
+            entry.getValue(),
+            builder.getMethodSignatureAfterClassMerging(entry.getKey(), mergedClasses));
+      }
+      for (Map.Entry<DexMethod, DexMethod> entry :
+          builder.originalMethodSignaturesForBridges.entrySet()) {
+        newBuilder.recordCreationOfBridgeMethod(
+            entry.getValue(),
+            builder.getMethodSignatureAfterClassMerging(entry.getKey(), mergedClasses));
+      }
+      return newBuilder;
+    }
+
+    public GraphLense build(AppView<?> appView, Map<DexType, DexType> mergedClasses) {
+      if (mergedClasses.isEmpty()) {
+        return appView.graphLense();
+      }
       BiMap<DexField, DexField> originalFieldSignatures = fieldMap.inverse();
       // Build new graph lense.
       return new VerticalClassMergerGraphLense(
           appView,
+          mergedClasses,
           fieldMap,
           methodMap,
-          getMergedMethodSignaturesAfterClassMerging(
-              mergedMethodsBuilder.build(), mergedClasses, appView.dexItemFactory(), cache),
+          mergedMethodsBuilder.build(),
           contextualVirtualToDirectMethodMaps,
           originalFieldSignatures,
           originalMethodSignatures,
           originalMethodSignaturesForBridges,
-          previousLense);
+          appView.graphLense());
     }
 
-    // After we have recorded that a method "a.b.c.Foo;->m(A, B, C)V" was merged into another class,
-    // it could be that the class B was merged into its subclass B'. In that case we update the
-    // signature to "a.b.c.Foo;->m(A, B', C)V".
-    private static Set<DexMethod> getMergedMethodSignaturesAfterClassMerging(
-        Set<DexMethod> mergedMethods,
-        Map<DexType, DexType> mergedClasses,
-        DexItemFactory dexItemFactory,
-        Map<DexProto, DexProto> cache) {
-      ImmutableSet.Builder<DexMethod> result = ImmutableSet.builder();
-      for (DexMethod signature : mergedMethods) {
-        result.add(
-            getMethodSignatureAfterClassMerging(signature, mergedClasses, dexItemFactory, cache));
+    private DexField getFieldSignatureAfterClassMerging(
+        DexField field, Map<DexType, DexType> mergedClasses) {
+      assert !field.holder.isArrayType();
+
+      DexType holder = field.holder;
+      DexType newHolder = mergedClasses.getOrDefault(holder, holder);
+
+      DexType type = field.type;
+      DexType newType = mergedClasses.getOrDefault(type, type);
+
+      if (holder == newHolder && type == newType) {
+        return field;
       }
-      return result.build();
+      return dexItemFactory.createField(newHolder, newType, field.name);
     }
 
-    private static DexMethod getMethodSignatureAfterClassMerging(
-        DexMethod signature,
-        Map<DexType, DexType> mergedClasses,
-        DexItemFactory dexItemFactory,
-        Map<DexProto, DexProto> cache) {
+    private DexMethod getMethodSignatureAfterClassMerging(
+        DexMethod signature, Map<DexType, DexType> mergedClasses) {
       assert !signature.holder.isArrayType();
-      DexType newHolder = mergedClasses.getOrDefault(signature.holder, signature.holder);
+
+      DexType holder = signature.holder;
+      DexType newHolder = mergedClasses.getOrDefault(holder, holder);
+
+      DexProto proto = signature.proto;
       DexProto newProto =
           dexItemFactory.applyClassMappingToProto(
-              signature.proto,
-              type -> getTypeAfterClassMerging(type, mergedClasses, dexItemFactory),
-              cache);
-      if (signature.holder.equals(newHolder) && signature.proto.equals(newProto)) {
+              proto, type -> getTypeAfterClassMerging(type, mergedClasses), cache);
+
+      if (holder == newHolder && proto == newProto) {
         return signature;
       }
       return dexItemFactory.createMethod(newHolder, newProto, signature.name);
     }
 
-    private static DexType getTypeAfterClassMerging(
-        DexType type, Map<DexType, DexType> mergedClasses, DexItemFactory dexItemFactory) {
+    private DexType getTypeAfterClassMerging(DexType type, Map<DexType, DexType> mergedClasses) {
       if (type.isArrayType()) {
         DexType baseType = type.toBaseType(dexItemFactory);
         DexType newBaseType = mergedClasses.getOrDefault(baseType, baseType);
@@ -251,6 +295,15 @@
       return false;
     }
 
+    public boolean hasOriginalSignatureMappingFor(DexField field) {
+      return fieldMap.inverse().containsKey(field);
+    }
+
+    public boolean hasOriginalSignatureMappingFor(DexMethod method) {
+      return originalMethodSignatures.containsKey(method)
+          || originalMethodSignaturesForBridges.containsKey(method);
+    }
+
     public void markMethodAsMerged(DexMethod method) {
       mergedMethodsBuilder.add(method);
     }
@@ -259,8 +312,9 @@
       fieldMap.put(from, to);
     }
 
-    public void map(DexMethod from, DexMethod to) {
+    public Builder map(DexMethod from, DexMethod to) {
       methodMap.put(from, to);
+      return this;
     }
 
     public void recordMove(DexMethod from, DexMethod to) {
@@ -274,7 +328,7 @@
     public void mapVirtualMethodToDirectInType(
         DexMethod from, GraphLenseLookupResult to, DexType type) {
       Map<DexMethod, GraphLenseLookupResult> virtualToDirectMethodMap =
-          contextualVirtualToDirectMethodMaps.computeIfAbsent(type, key -> new HashMap<>());
+          contextualVirtualToDirectMethodMaps.computeIfAbsent(type, key -> new IdentityHashMap<>());
       virtualToDirectMethodMap.put(from, to);
     }
 
diff --git a/src/test/java/com/android/tools/r8/classmerging/IncorrectRewritingOfInvokeSuperTest.java b/src/test/java/com/android/tools/r8/classmerging/IncorrectRewritingOfInvokeSuperTest.java
index d4b045d..25ddfcd 100644
--- a/src/test/java/com/android/tools/r8/classmerging/IncorrectRewritingOfInvokeSuperTest.java
+++ b/src/test/java/com/android/tools/r8/classmerging/IncorrectRewritingOfInvokeSuperTest.java
@@ -4,8 +4,6 @@
 
 package com.android.tools.r8.classmerging;
 
-import static org.hamcrest.core.StringContains.containsString;
-
 import com.android.tools.r8.NeverInline;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
@@ -40,8 +38,7 @@
         .noMinification()
         .setMinApi(parameters.getRuntime())
         .run(parameters.getRuntime(), TestClass.class)
-        // TODO(b/133501933): Should succeed.
-        .assertFailureWithErrorThatMatches(containsString(StackOverflowError.class.getTypeName()));
+        .assertSuccess();
   }
 
   static class TestClass {