Amend method access info collection in horizontal class merger

Change-Id: Ie5a8ae70aafe4311766eed17e6a9974a32c38272
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
index 849e67a..2721861 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
@@ -17,8 +17,10 @@
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
+import com.android.tools.r8.graph.MethodAccessInfoCollection;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.graph.PrunedItems;
+import com.android.tools.r8.horizontalclassmerging.VirtualMethodMerger.SuperMethodReference;
 import com.android.tools.r8.horizontalclassmerging.code.SyntheticInitializerConverter;
 import com.android.tools.r8.ir.conversion.LirConverter;
 import com.android.tools.r8.ir.conversion.MethodConversionOptions;
@@ -35,6 +37,7 @@
 import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.Timing;
 import com.android.tools.r8.utils.TraversalContinuation;
+import com.android.tools.r8.utils.collections.ProgramMethodMap;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.LinkedList;
@@ -196,13 +199,14 @@
     appView.setGraphLens(horizontalClassMergerGraphLens);
     codeProvider.setGraphLens(horizontalClassMergerGraphLens);
 
-    // Finalize synthetic code.
-    transformIncompleteCode(groups, horizontalClassMergerGraphLens, executorService);
-
     // Must rewrite AppInfoWithLiveness before pruning the merged classes, to ensure that allocation
     // sites, fields accesses, etc. are correctly transferred to the target classes.
     DexApplication newApplication = getNewApplication(mergedClasses);
     if (appView.enableWholeProgramOptimizations()) {
+      ProgramMethodMap<DexMethod> newNonReboundMethodReferences =
+          extractNonReboundMethodReferences(groups, horizontalClassMergerGraphLens);
+      // Finalize synthetic code.
+      transformIncompleteCode(groups, horizontalClassMergerGraphLens, executorService);
       // Prune keep info.
       AppView<AppInfoWithClassHierarchy> appViewWithClassHierarchy = appView.withClassHierarchy();
       KeepInfoCollection keepInfo = appView.getKeepInfo();
@@ -225,6 +229,8 @@
         }
         appView.clearCodeRewritings(executorService, timing);
       }
+      amendMethodAccessInfoCollection(
+          horizontalClassMergerGraphLens, newNonReboundMethodReferences);
     } else {
       assert mode.isFinal();
       SyntheticItems syntheticItems = appView.appInfo().getSyntheticItems();
@@ -288,6 +294,40 @@
             });
   }
 
+  private ProgramMethodMap<DexMethod> extractNonReboundMethodReferences(
+      Collection<HorizontalMergeGroup> groups, HorizontalClassMergerGraphLens lens) {
+    ProgramMethodMap<DexMethod> newNonReboundMethodReferences = ProgramMethodMap.create();
+    for (HorizontalMergeGroup group : groups) {
+      group
+          .getTarget()
+          .forEachProgramVirtualMethodMatching(
+              method ->
+                  method.hasCode()
+                      && method.getCode() instanceof IncompleteVirtuallyMergedMethodCode
+                      && ((IncompleteVirtuallyMergedMethodCode) method.getCode()).hasSuperMethod(),
+              method -> {
+                SuperMethodReference superMethodReference =
+                    ((IncompleteVirtuallyMergedMethodCode) method.getDefinition().getCode())
+                        .getSuperMethod();
+                newNonReboundMethodReferences.put(
+                    method, superMethodReference.getRewrittenReference(lens, method));
+              });
+    }
+    return newNonReboundMethodReferences;
+  }
+
+  private void amendMethodAccessInfoCollection(
+      HorizontalClassMergerGraphLens lens,
+      ProgramMethodMap<DexMethod> newNonReboundMethodReferences) {
+    MethodAccessInfoCollection.Modifier methodAccessInfoCollectionModifier =
+        appView.appInfoWithLiveness().getMethodAccessInfoCollection().modifier();
+    newNonReboundMethodReferences.forEach(
+        (context, reference) ->
+            // Reference is already lens rewritten.
+            methodAccessInfoCollectionModifier.registerInvokeSuperInContext(
+                reference, context.rewrittenWithLens(lens, lens.getPrevious(), appView)));
+  }
+
   private FieldAccessInfoCollectionModifier createFieldAccessInfoCollectionModifier(
       Collection<HorizontalMergeGroup> groups) {
     FieldAccessInfoCollectionModifier.Builder builder =
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/IncompleteVirtuallyMergedMethodCode.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/IncompleteVirtuallyMergedMethodCode.java
index e39290c..4b44559 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/IncompleteVirtuallyMergedMethodCode.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/IncompleteVirtuallyMergedMethodCode.java
@@ -146,12 +146,7 @@
       fallthroughTarget =
           lens.getNextMethodSignature(mappedMethods.get(mappedMethods.lastIntKey()));
     } else {
-      DexMethod reboundFallthroughTarget =
-          lens.lookupInvokeSuper(superMethod.getReboundReference(), method).getReference();
-      fallthroughTarget =
-          reboundFallthroughTarget.withHolder(
-              lens.getNextClassType(superMethod.getReference().getHolderType()),
-              appView.dexItemFactory());
+      fallthroughTarget = superMethod.getRewrittenReference(lens, method);
     }
     instructions.add(
         new CfInvoke(Opcodes.INVOKESPECIAL, fallthroughTarget, method.getHolder().isInterface()));
@@ -166,6 +161,14 @@
         lens, originalMethod.getHolderType(), maxStack, maxLocals, instructions);
   }
 
+  public boolean hasSuperMethod() {
+    return superMethod != null;
+  }
+
+  public SuperMethodReference getSuperMethod() {
+    return superMethod;
+  }
+
   @Override
   public LirCode<Integer> toLirCode(
       AppView<? extends AppInfoWithClassHierarchy> appView,
@@ -213,7 +216,7 @@
 
     // Emit switch.
     IntBidirectionalIterator classIdIterator = mappedMethods.keySet().iterator();
-    int[] keys = new int[mappedMethods.size() - BooleanUtils.intValue(superMethod == null)];
+    int[] keys = new int[mappedMethods.size() - BooleanUtils.intValue(!hasSuperMethod())];
     int[] targets = new int[keys.length];
     int nextTarget = instructionIndex - argumentValues.size() + 3;
     for (int i = 0; i < keys.length; i++) {
@@ -225,7 +228,10 @@
     instructionIndex++;
 
     // Emit switch fallthrough.
-    if (superMethod == null) {
+    if (hasSuperMethod()) {
+      lirBuilder.addInvokeSuper(
+          superMethod.getRewrittenReference(lens, method), argumentValues, false);
+    } else {
       DexMethod fallthroughTarget =
           lens.getNextMethodSignature(mappedMethods.get(mappedMethods.lastIntKey()));
       if (method.getHolder().isInterface()) {
@@ -233,14 +239,6 @@
       } else {
         lirBuilder.addInvokeVirtual(fallthroughTarget, argumentValues);
       }
-    } else {
-      DexMethod reboundFallthroughTarget =
-          lens.lookupInvokeSuper(superMethod.getReboundReference(), method).getReference();
-      DexMethod fallthroughTarget =
-          reboundFallthroughTarget.withHolder(
-              lens.getNextClassType(superMethod.getReference().getHolderType()),
-              appView.dexItemFactory());
-      lirBuilder.addInvokeSuper(fallthroughTarget, argumentValues, false);
     }
     if (method.getReturnType().isVoidType()) {
       lirBuilder.addReturnVoid();
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/VirtualMethodMerger.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/VirtualMethodMerger.java
index 6175c5a..e166410 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/VirtualMethodMerger.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/VirtualMethodMerger.java
@@ -38,12 +38,12 @@
       this.reboundReference = reboundReference;
     }
 
-    public DexMethod getReference() {
-      return reference;
-    }
-
-    public DexMethod getReboundReference() {
-      return reboundReference;
+    public DexMethod getRewrittenReference(
+        HorizontalClassMergerGraphLens lens, ProgramMethod context) {
+      DexMethod reboundFallthroughTarget =
+          lens.lookupInvokeSuper(reboundReference, context).getReference();
+      return reboundFallthroughTarget.withHolder(
+          lens.getNextClassType(reference.getHolderType()), lens.dexItemFactory());
     }
   }