Perform unused argument removal before horizontal class merging

This also prevents adding a synthetic argument when not necessary (Checked in the existing test `ConstructorMergingWithArgumentsTest` as well as in `ConstructorMergingOverlapTest``).

Change-Id: Ib4808c0f61abd455a870521aa5325da11a900288
diff --git a/src/main/java/com/android/tools/r8/R8.java b/src/main/java/com/android/tools/r8/R8.java
index 38904af..f61edb9 100644
--- a/src/main/java/com/android/tools/r8/R8.java
+++ b/src/main/java/com/android/tools/r8/R8.java
@@ -532,23 +532,6 @@
           }
           timing.end();
         }
-        if (options.enableHorizontalClassMerging) {
-          timing.begin("HorizontalClassMerger");
-          HorizontalClassMerger merger =
-              new HorizontalClassMerger(
-                  appViewWithLiveness, mainDexTracingResult, classMergingEnqueuerExtension);
-          DirectMappedDexApplication.Builder appBuilder =
-              appView.appInfo().app().asDirect().builder();
-          HorizontalClassMergerGraphLens lens = merger.run(appBuilder);
-          if (lens != null) {
-            appView.setHorizontallyMergedClasses(lens.getHorizontallyMergedClasses());
-            appView.rewriteWithLensAndApplication(lens, appBuilder.build());
-          }
-          timing.end();
-        }
-
-        // Only required for class merging, clear instance to save memory.
-        classMergingEnqueuerExtension = null;
 
         if (options.enableArgumentRemoval) {
           SubtypingInfo subtypingInfo = appViewWithLiveness.appInfo().computeSubtypingInfo();
@@ -577,6 +560,23 @@
             timing.end();
           }
         }
+        if (options.enableHorizontalClassMerging) {
+          timing.begin("HorizontalClassMerger");
+          HorizontalClassMerger merger =
+              new HorizontalClassMerger(
+                  appViewWithLiveness, mainDexTracingResult, classMergingEnqueuerExtension);
+          DirectMappedDexApplication.Builder appBuilder =
+              appView.appInfo().app().asDirect().builder();
+          HorizontalClassMergerGraphLens lens = merger.run(appBuilder);
+          if (lens != null) {
+            appView.setHorizontallyMergedClasses(lens.getHorizontallyMergedClasses());
+            appView.rewriteWithLensAndApplication(lens, appBuilder.build());
+          }
+          timing.end();
+        }
+
+        // Only required for class merging, clear instance to save memory.
+        classMergingEnqueuerExtension = null;
       }
 
       // None of the optimizations above should lead to the creation of type lattice elements.
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/ConstructorEntryPoint.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/ConstructorEntryPoint.java
index ff2a898..732a5d2 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/ConstructorEntryPoint.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/ConstructorEntryPoint.java
@@ -78,7 +78,9 @@
 
   protected void prepareMultiConstructorInstructions() {
     int typeConstructorCount = typeConstructors.size();
-    int idRegister = getParamRegister(method.getArity() - 1);
+    DexMethod exampleTargetConstructor = typeConstructors.values().iterator().next();
+    // The class id register is always the first synthetic argument.
+    int idRegister = getParamRegister(exampleTargetConstructor.getArity());
 
     addRegisterClassIdAssignment(idRegister);
 
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/ConstructorMerger.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/ConstructorMerger.java
index b411667..2f3c068 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/ConstructorMerger.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/ConstructorMerger.java
@@ -16,6 +16,9 @@
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.MethodAccessFlags;
 import com.android.tools.r8.graph.ParameterAnnotationsList;
+import com.android.tools.r8.ir.conversion.ExtraConstantIntParameter;
+import com.android.tools.r8.ir.conversion.ExtraParameter;
+import com.android.tools.r8.ir.conversion.ExtraUnusedNullParameter;
 import com.android.tools.r8.shaking.FieldAccessInfoCollectionModifier;
 import it.unimi.dsi.fastutil.ints.Int2ReferenceAVLTreeMap;
 import it.unimi.dsi.fastutil.ints.Int2ReferenceSortedMap;
@@ -23,6 +26,7 @@
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.LinkedList;
 import java.util.List;
 
 public class ConstructorMerger {
@@ -91,7 +95,7 @@
   }
 
   private DexProto getNewConstructorProto(SyntheticArgumentClass syntheticArgumentClass) {
-    DexEncodedMethod firstConstructor = constructors.stream().findFirst().get();
+    DexEncodedMethod firstConstructor = constructors.iterator().next();
     DexProto oldProto = firstConstructor.getProto();
 
     if (isTrivialMerge()) {
@@ -100,12 +104,19 @@
 
     List<DexType> parameters = new ArrayList<>();
     Collections.addAll(parameters, oldProto.parameters.values);
-    parameters.add(syntheticArgumentClass.getArgumentClass());
     parameters.add(dexItemFactory.intType);
+    if (syntheticArgumentClass != null) {
+      parameters.add(syntheticArgumentClass.getArgumentClass());
+    }
     // TODO(b/165783587): add synthesised class to prevent constructor merge conflict
     return dexItemFactory.createProto(oldProto.returnType, parameters);
   }
 
+  private DexMethod getNewConstructorReference(SyntheticArgumentClass syntheticArgumentClass) {
+    DexProto proto = getNewConstructorProto(syntheticArgumentClass);
+    return appView.dexItemFactory().createMethod(target.type, proto, dexItemFactory.initMethodName);
+  }
+
   private MethodAccessFlags getAccessFlags() {
     // TODO(b/164998929): ensure this behaviour is correct, should probably calculate upper bound
     return MethodAccessFlags.fromSharedAccessFlags(
@@ -132,12 +143,15 @@
           classIdentifiers.getInt(constructor.getHolderType()), movedConstructor);
     }
 
-    DexProto newProto = getNewConstructorProto(syntheticArgumentClass);
+    DexMethod newConstructorReference = getNewConstructorReference(null);
+    boolean addExtraNull = target.lookupMethod(newConstructorReference) != null;
+    if (addExtraNull) {
+      newConstructorReference = getNewConstructorReference(syntheticArgumentClass);
+      assert target.lookupMethod(newConstructorReference) == null;
+    }
 
     DexMethod originalConstructorReference =
         appView.graphLens().getOriginalMethodSignature(constructors.iterator().next().method);
-    DexMethod newConstructorReference =
-        appView.dexItemFactory().createMethod(target.type, newProto, dexItemFactory.initMethodName);
     ConstructorEntryPointSynthesizedCode synthesizedCode =
         new ConstructorEntryPointSynthesizedCode(
             typeConstructorClassMap,
@@ -161,10 +175,16 @@
     } else {
       // Map each old constructor to the newly synthesized constructor in the graph lens.
       for (DexEncodedMethod oldConstructor : constructors) {
+        int classIdentifier = classIdentifiers.getInt(oldConstructor.getHolderType());
+
+        List<ExtraParameter> extraParameters = new LinkedList<>();
+        extraParameters.add(new ExtraConstantIntParameter(classIdentifier));
+        if (addExtraNull) {
+          extraParameters.add(new ExtraUnusedNullParameter());
+        }
+
         lensBuilder.mapMergedConstructor(
-            oldConstructor.method,
-            newConstructorReference,
-            classIdentifiers.getInt(oldConstructor.getHolderType()));
+            oldConstructor.method, newConstructorReference, extraParameters);
       }
     }
     // Map the first constructor to the newly synthesized constructor.
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMergerGraphLens.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMergerGraphLens.java
index 0179237..df0ba26 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMergerGraphLens.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMergerGraphLens.java
@@ -10,25 +10,26 @@
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.GraphLens;
 import com.android.tools.r8.graph.GraphLens.NestedGraphLens;
+import com.android.tools.r8.graph.RewrittenPrototypeDescription;
 import com.android.tools.r8.ir.code.Invoke.Type;
-import com.android.tools.r8.ir.conversion.ExtraConstantIntParameter;
-import com.android.tools.r8.ir.conversion.ExtraUnusedNullParameter;
+import com.android.tools.r8.ir.conversion.ExtraParameter;
 import com.google.common.collect.BiMap;
 import com.google.common.collect.HashBiMap;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.IdentityHashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
 public class HorizontalClassMergerGraphLens extends NestedGraphLens {
   private final AppView<?> appView;
-  private final Map<DexMethod, Integer> constructorIds;
+  private final Map<DexMethod, List<ExtraParameter>> methodExtraParameters;
   private final Map<DexMethod, DexMethod> originalConstructorSignatures;
 
   private HorizontalClassMergerGraphLens(
       AppView<?> appView,
-      Map<DexMethod, Integer> constructorIds,
+      Map<DexMethod, List<ExtraParameter>> methodExtraParameters,
       Map<DexType, DexType> typeMap,
       Map<DexField, DexField> fieldMap,
       Map<DexMethod, DexMethod> methodMap,
@@ -45,7 +46,7 @@
         previousLens,
         appView.dexItemFactory());
     this.appView = appView;
-    this.constructorIds = constructorIds;
+    this.methodExtraParameters = methodExtraParameters;
     this.originalConstructorSignatures = originalConstructorSignatures;
   }
 
@@ -69,19 +70,19 @@
   @Override
   public GraphLensLookupResult lookupMethod(DexMethod method, DexMethod context, Type type) {
     DexMethod previousContext = internalGetPreviousMethodSignature(context);
-    GraphLensLookupResult previousLookup = previousLens.lookupMethod(method, previousContext, type);
-    Integer constructorId = constructorIds.get(previousLookup.getMethod());
+    GraphLensLookupResult previousLookup =
+        getPrevious().lookupMethod(method, previousContext, type);
+    List<ExtraParameter> extraParameters = methodExtraParameters.get(previousLookup.getMethod());
 
     GraphLensLookupResult lookup = super.lookupMethod(method, previousLookup);
-    if (constructorId != null) {
+    if (extraParameters != null) {
       DexMethod newMethod = lookup.getMethod();
+
+      RewrittenPrototypeDescription prototypeChanges =
+          lookup.getPrototypeChanges().withExtraParameters(extraParameters);
+
       return new GraphLensLookupResult(
-          newMethod,
-          mapInvocationType(newMethod, method, lookup.getType()),
-          lookup
-              .getPrototypeChanges()
-              .withExtraParameter(new ExtraUnusedNullParameter())
-              .withExtraParameter(new ExtraConstantIntParameter(constructorId)));
+          newMethod, mapInvocationType(newMethod, method, lookup.getType()), prototypeChanges);
     } else {
       return lookup;
     }
@@ -96,7 +97,8 @@
     private final BiMap<DexMethod, DexMethod> originalMethodSignatures = HashBiMap.create();
     private final Map<DexMethod, DexMethod> extraOriginalMethodSignatures = new IdentityHashMap<>();
 
-    private final Map<DexMethod, Integer> constructorIds = new IdentityHashMap<>();
+    private final Map<DexMethod, List<ExtraParameter>> methodExtraParameters =
+        new IdentityHashMap<>();
 
     Builder() {}
 
@@ -106,7 +108,7 @@
       BiMap<DexField, DexField> originalFieldSignatures = fieldMap.inverse();
       return new HorizontalClassMergerGraphLens(
           appView,
-          constructorIds,
+          methodExtraParameters,
           typeMap,
           fieldMap,
           methodMap,
@@ -189,13 +191,11 @@
      * One way mapping from one constructor to another. This is used for synthesized constructors,
      * where many constructors are merged into a single constructor. The synthesized constructor
      * therefore does not have a unique reverse constructor.
-     *
-     * @param constructorId The id that must be appended to the constructor call to ensure the
-     *     correct constructor is called.
      */
-    public Builder mapMergedConstructor(DexMethod from, DexMethod to, int constructorId) {
+    public Builder mapMergedConstructor(
+        DexMethod from, DexMethod to, List<ExtraParameter> extraParameters) {
       mapMethod(from, to);
-      constructorIds.put(from, constructorId);
+      methodExtraParameters.put(from, extraParameters);
       return this;
     }
   }
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/ConstructorMergingOverlapTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/ConstructorMergingOverlapTest.java
index 8e41e73..316d547 100644
--- a/src/test/java/com/android/tools/r8/classmerging/horizontal/ConstructorMergingOverlapTest.java
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/ConstructorMergingOverlapTest.java
@@ -55,7 +55,7 @@
                 ClassSubject synthesizedClass = getSynthesizedArgumentClassSubject(codeInspector);
 
                 MethodSubject otherInitSubject =
-                    aClassSubject.init(synthesizedClass.getFinalName(), "int");
+                    aClassSubject.init("int", synthesizedClass.getFinalName());
                 assertThat(otherInitSubject, isPresent());
                 assertThat(
                     otherInitSubject, writesInstanceField(classIdFieldSubject.getFieldReference()));
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/ConstructorMergingPreoptimizedTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/ConstructorMergingPreoptimizedTest.java
index df8f653..b7d3a2d 100644
--- a/src/test/java/com/android/tools/r8/classmerging/horizontal/ConstructorMergingPreoptimizedTest.java
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/ConstructorMergingPreoptimizedTest.java
@@ -12,6 +12,7 @@
 
 import com.android.tools.r8.NeverClassInline;
 import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NoHorizontalClassMerging;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.horizontalclassmerging.ClassMerger;
 import com.android.tools.r8.utils.codeinspector.ClassSubject;
@@ -35,12 +36,17 @@
             options -> options.enableHorizontalClassMerging = enableHorizontalClassMerging)
         .enableInliningAnnotations()
         .enableNeverClassInliningAnnotations()
+        .enableNoHorizontalClassMergingAnnotations()
         .setMinApi(parameters.getApiLevel())
         .run(parameters.getRuntime(), Main.class)
-        .assertSuccessWithOutputLines("changed", "13", "42", "foo", "7", "print a", "print b")
+        .assertSuccessWithOutputLines(
+            "changed", "13", "42", "foo", "7", "foo", "print a", "print b")
         .inspect(
             codeInspector -> {
               if (enableHorizontalClassMerging) {
+                ClassSubject changedClassSubject = codeInspector.clazz(Changed.class);
+                assertThat(changedClassSubject, isPresent());
+
                 ClassSubject aClassSubject = codeInspector.clazz(A.class);
                 assertThat(aClassSubject, isPresent());
                 FieldSubject classIdFieldSubject =
@@ -52,14 +58,8 @@
                 assertThat(
                     firstInitSubject, writesInstanceField(classIdFieldSubject.getFieldReference()));
 
-                ClassSubject syntheticArgumentSubject =
-                    getSynthesizedArgumentClassSubject(codeInspector);
-
                 MethodSubject otherInitSubject =
-                    aClassSubject.init(
-                        aClassSubject.getFinalName(),
-                        syntheticArgumentSubject.getFinalName(),
-                        "int");
+                    aClassSubject.init(changedClassSubject.getFinalName(), "int");
                 assertThat(otherInitSubject, isPresent());
                 assertThat(
                     otherInitSubject, writesInstanceField(classIdFieldSubject.getFieldReference()));
@@ -80,6 +80,7 @@
   }
 
   @NeverClassInline
+  @NoHorizontalClassMerging
   public static class Parent {
     @NeverInline
     public void foo() {
@@ -88,6 +89,7 @@
   }
 
   @NeverClassInline
+  @NoHorizontalClassMerging
   public static class Changed extends Parent {
     public Changed() {
       System.out.println("changed");
@@ -115,6 +117,7 @@
   public static class B {
     public B(Parent p) {
       System.out.println(7);
+      p.foo();
     }
 
     @NeverInline
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/ConstructorMergingWithArgumentsTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/ConstructorMergingWithArgumentsTest.java
index a1b2d1a..a4a0fd9 100644
--- a/src/test/java/com/android/tools/r8/classmerging/horizontal/ConstructorMergingWithArgumentsTest.java
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/ConstructorMergingWithArgumentsTest.java
@@ -10,6 +10,8 @@
 
 import com.android.tools.r8.NeverClassInline;
 import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
 import org.junit.Test;
 
 public class ConstructorMergingWithArgumentsTest extends HorizontalClassMergingTestBase {
@@ -32,8 +34,13 @@
         .inspect(
             codeInspector -> {
               if (enableHorizontalClassMerging) {
-                assertThat(codeInspector.clazz(A.class), isPresent());
+                ClassSubject aClassSubject = codeInspector.clazz(A.class);
+
+                assertThat(aClassSubject, isPresent());
                 assertThat(codeInspector.clazz(B.class), not(isPresent()));
+
+                MethodSubject initSubject = aClassSubject.init(String.class.getName(), "int");
+                assertThat(initSubject, isPresent());
                 // TODO(b/165517236): Explicitly check classes have been merged.
               } else {
                 assertThat(codeInspector.clazz(A.class), isPresent());
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/MergedConstructorForwardingTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/MergedConstructorForwardingTest.java
index c6df40f..b2e1d59 100644
--- a/src/test/java/com/android/tools/r8/classmerging/horizontal/MergedConstructorForwardingTest.java
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/MergedConstructorForwardingTest.java
@@ -50,16 +50,12 @@
                     aClassSubject.uniqueFieldWithName(ClassMerger.CLASS_ID_FIELD_NAME);
                 assertThat(classIdFieldSubject, isPresent());
 
-                ClassSubject synthesizedClass = getSynthesizedArgumentClassSubject(codeInspector);
-
-                MethodSubject firstInitSubject =
-                    aClassSubject.init(synthesizedClass.getFinalName(), "int");
+                MethodSubject firstInitSubject = aClassSubject.init("int");
                 assertThat(firstInitSubject, isPresent());
                 assertThat(
                     firstInitSubject, writesInstanceField(classIdFieldSubject.getFieldReference()));
 
-                MethodSubject otherInitSubject =
-                    aClassSubject.init("long", synthesizedClass.getFinalName(), "int");
+                MethodSubject otherInitSubject = aClassSubject.init("long", "int");
                 assertThat(otherInitSubject, isPresent());
                 assertThat(
                     otherInitSubject, writesInstanceField(classIdFieldSubject.getFieldReference()));