Update set of merged methods after vertical class merging

Change-Id: I273d1726b27b556db471c626afd8cf7ee2c2280c
diff --git a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
index 4e4a78f..feef8f7 100644
--- a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
+++ b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
@@ -18,12 +18,15 @@
 import com.android.tools.r8.ir.code.Position;
 import com.android.tools.r8.kotlin.Kotlin;
 import com.android.tools.r8.naming.NamingLens;
+import com.android.tools.r8.utils.ArrayUtils;
 import com.google.common.base.Strings;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
+import it.unimi.dsi.fastutil.ints.Int2ObjectArrayMap;
 import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
 import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.IdentityHashMap;
@@ -32,6 +35,7 @@
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Consumer;
+import java.util.function.Function;
 
 public class DexItemFactory {
 
@@ -662,6 +666,41 @@
         parameters.length == 0 ? DexTypeList.empty() : new DexTypeList(parameters));
   }
 
+  public DexProto applyClassMappingToProto(
+      DexProto proto, Function<DexType, DexType> mapping, Map<DexProto, DexProto> cache) {
+    assert cache != null;
+    DexProto result = cache.get(proto);
+    if (result == null) {
+      DexType returnType = mapping.apply(proto.returnType);
+      DexType[] parameters = applyClassMappingToDexTypes(proto.parameters.values, mapping);
+      if (returnType == proto.returnType && parameters == proto.parameters.values) {
+        result = proto;
+      } else {
+        // Should be different if reference has changed.
+        assert returnType == proto.returnType || !returnType.equals(proto.returnType);
+        assert parameters == proto.parameters.values
+            || !Arrays.equals(parameters, proto.parameters.values);
+        result = createProto(returnType, parameters);
+      }
+      cache.put(proto, result);
+    }
+    return result;
+  }
+
+  private static DexType[] applyClassMappingToDexTypes(
+      DexType[] types, Function<DexType, DexType> mapping) {
+    Map<Integer, DexType> changed = new Int2ObjectArrayMap<>();
+    for (int i = 0; i < types.length; i++) {
+      DexType applied = mapping.apply(types[i]);
+      if (applied != types[i]) {
+        changed.put(i, applied);
+      }
+    }
+    return changed.isEmpty()
+        ? types
+        : ArrayUtils.copyWithSparseChanges(DexType[].class, types, changed);
+  }
+
   private DexString createShorty(DexType returnType, DexType[] argumentTypes) {
     StringBuilder shortyBuilder = new StringBuilder();
     shortyBuilder.append(returnType.toShorty());
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
index e653ff0..6b016b9 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
@@ -36,8 +36,10 @@
 import com.android.tools.r8.ir.code.StaticGet;
 import com.android.tools.r8.ir.code.StaticPut;
 import com.android.tools.r8.ir.code.Value;
+import java.util.HashMap;
 import java.util.List;
 import java.util.ListIterator;
+import java.util.Map;
 import java.util.stream.Collectors;
 
 public class LensCodeRewriter {
@@ -45,6 +47,8 @@
   private final GraphLense graphLense;
   private final AppInfoWithSubtyping appInfo;
 
+  private final Map<DexProto, DexProto> protoFixupCache = new HashMap<>();
+
   public LensCodeRewriter(GraphLense graphLense, AppInfoWithSubtyping appInfo) {
     this.graphLense = graphLense;
     this.appInfo = appInfo;
@@ -71,13 +75,9 @@
         if (current.isInvokeCustom()) {
           InvokeCustom invokeCustom = current.asInvokeCustom();
           DexCallSite callSite = invokeCustom.getCallSite();
-          DexType[] newParameters = new DexType[callSite.methodProto.parameters.size()];
-          for (int i = 0; i < callSite.methodProto.parameters.size(); i++) {
-            newParameters[i] = graphLense.lookupType(callSite.methodProto.parameters.values[i]);
-          }
           DexProto newMethodProto =
-              appInfo.dexItemFactory.createProto(
-                  graphLense.lookupType(callSite.methodProto.returnType), newParameters);
+              appInfo.dexItemFactory.applyClassMappingToProto(
+                  callSite.methodProto, graphLense::lookupType, protoFixupCache);
           DexMethodHandle newBootstrapMethod = rewriteDexMethodHandle(method,
               callSite.bootstrapMethod);
           List<DexValue> newArgs = callSite.bootstrapArgs.stream().map(
diff --git a/src/main/java/com/android/tools/r8/naming/ProguardMapApplier.java b/src/main/java/com/android/tools/r8/naming/ProguardMapApplier.java
index 7dfd840..c159b9d 100644
--- a/src/main/java/com/android/tools/r8/naming/ProguardMapApplier.java
+++ b/src/main/java/com/android/tools/r8/naming/ProguardMapApplier.java
@@ -209,32 +209,8 @@
     }
 
     private DexProto applyClassMappingOnTheFly(DexProto proto) {
-      DexProto result = protoFixupCache.get(proto);
-      if (result == null) {
-        DexType returnType = applyClassMappingOnTheFly(proto.returnType);
-        DexType[] arguments = applyClassMappingOnTheFly(proto.parameters.values);
-        if (arguments != null || returnType != proto.returnType) {
-          arguments = arguments == null ? proto.parameters.values : arguments;
-          result = appInfo.dexItemFactory.createProto(returnType, arguments);
-        } else {
-          result = proto;
-        }
-        protoFixupCache.put(proto, result);
-      }
-      return result;
-    }
-
-    private DexType[] applyClassMappingOnTheFly(DexType[] types) {
-      Map<Integer, DexType> changed = new Int2ObjectArrayMap<>();
-      for (int i = 0; i < types.length; i++) {
-        DexType applied = applyClassMappingOnTheFly(types[i]);
-        if (applied != types[i]) {
-          changed.put(i, applied);
-        }
-      }
-      return changed.isEmpty()
-          ? null
-          : ArrayUtils.copyWithSparseChanges(DexType[].class, types, changed);
+      return appInfo.dexItemFactory.applyClassMappingToProto(
+          proto, this::applyClassMappingOnTheFly, protoFixupCache);
     }
   }
 
diff --git a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
index d06c73a..ed36a70 100644
--- a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
@@ -413,7 +413,7 @@
     if (Log.ENABLED) {
       Log.debug(getClass(), "Merged %d classes.", numberOfMerges);
     }
-    return renamedMembersLense.build(graphLense);
+    return renamedMembersLense.build(graphLense, mergedClasses, application.dexItemFactory);
   }
 
   private boolean methodResolutionMayChange(DexClass source, DexClass target) {
diff --git a/src/main/java/com/android/tools/r8/shaking/VerticalClassMergerGraphLense.java b/src/main/java/com/android/tools/r8/shaking/VerticalClassMergerGraphLense.java
index 4cfe8db..4f614b4 100644
--- a/src/main/java/com/android/tools/r8/shaking/VerticalClassMergerGraphLense.java
+++ b/src/main/java/com/android/tools/r8/shaking/VerticalClassMergerGraphLense.java
@@ -7,7 +7,9 @@
 import com.android.tools.r8.graph.AppInfo;
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexProto;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.GraphLense;
 import com.android.tools.r8.graph.GraphLense.NestedGraphLense;
@@ -145,7 +147,10 @@
       this.appInfo = appInfo;
     }
 
-    public GraphLense build(GraphLense previousLense) {
+    public GraphLense build(
+        GraphLense previousLense,
+        Map<DexType, DexType> mergedClasses,
+        DexItemFactory dexItemFactory) {
       Map<DexField, DexField> fieldMap = fieldMapBuilder.build();
       Map<DexMethod, DexMethod> methodMap = methodMapBuilder.build();
       if (fieldMap.isEmpty()
@@ -157,11 +162,35 @@
           appInfo,
           fieldMap,
           methodMap,
-          mergedMethodsBuilder.build(),
+          getMergedMethodSignaturesAfterClassMerging(
+              mergedMethodsBuilder.build(), mergedClasses, dexItemFactory),
           contextualVirtualToDirectMethodMaps,
           previousLense);
     }
 
+    // After we have recorded that a method "a.b.c.Foo;->m(A, B, C)V" was merged into another class,
+    // it could be that the class B was merged into its subclass B'. In that case we update the
+    // signature to "a.b.c.Foo;->m(A, B', C)V".
+    private Set<DexMethod> getMergedMethodSignaturesAfterClassMerging(
+        Set<DexMethod> mergedMethods,
+        Map<DexType, DexType> mergedClasses,
+        DexItemFactory dexItemFactory) {
+      ImmutableSet.Builder<DexMethod> result = ImmutableSet.builder();
+      Map<DexProto, DexProto> cache = new HashMap<>();
+      for (DexMethod signature : mergedMethods) {
+        DexType newHolder = mergedClasses.getOrDefault(signature.holder, signature.holder);
+        DexProto newProto =
+            dexItemFactory.applyClassMappingToProto(
+                signature.proto, type -> mergedClasses.getOrDefault(type, type), cache);
+        if (signature.holder.equals(newHolder) && signature.proto.equals(newProto)) {
+          result.add(signature);
+        } else {
+          result.add(dexItemFactory.createMethod(newHolder, newProto, signature.name));
+        }
+      }
+      return result.build();
+    }
+
     public boolean hasMappingForSignatureInContext(DexType context, DexMethod signature) {
       Map<DexMethod, DexMethod> virtualToDirectMethodMap =
           contextualVirtualToDirectMethodMaps.get(context);