Enable class merging in presence of unused argument removal

Change-Id: I0163b41a3346b1663fe0fd05fdfb41f8caf6dcb0
diff --git a/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java b/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java
index 8a44fd0..c8fc0a8 100644
--- a/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java
+++ b/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java
@@ -396,6 +396,12 @@
     return rewrittenReturnInfo != null;
   }
 
+  public boolean requiresRewritingAtCallSite() {
+    return hasRewrittenReturnInfo()
+        || numberOfExtraParameters() > 0
+        || argumentInfoCollection.numberOfRemovedArguments() > 0;
+  }
+
   public RewrittenTypeInfo getRewrittenReturnInfo() {
     return rewrittenReturnInfo;
   }
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
index 6db4dc6..0821a4c 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
@@ -27,7 +27,6 @@
 import com.android.tools.r8.horizontalclassmerging.policies.NoNativeMethods;
 import com.android.tools.r8.horizontalclassmerging.policies.NoServiceLoaders;
 import com.android.tools.r8.horizontalclassmerging.policies.NoStaticClassInitializer;
-import com.android.tools.r8.horizontalclassmerging.policies.NoUnusedArguments;
 import com.android.tools.r8.horizontalclassmerging.policies.NotMatchedByNoHorizontalClassMerging;
 import com.android.tools.r8.horizontalclassmerging.policies.NotVerticallyMergedIntoSubtype;
 import com.android.tools.r8.horizontalclassmerging.policies.PreserveMethodCharacteristics;
@@ -122,7 +121,6 @@
         new NoKeepRules(appView),
         new NoKotlinMetadata(),
         new NoKotlinLambdas(appView),
-        new NoUnusedArguments(appView),
         new NoServiceLoaders(appView),
         new NotVerticallyMergedIntoSubtype(appView),
         new NoDirectRuntimeTypeChecks(runtimeTypeCheckInfo),
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMergerGraphLens.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMergerGraphLens.java
index d348b61..16f853e 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMergerGraphLens.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMergerGraphLens.java
@@ -10,6 +10,7 @@
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.GraphLens;
 import com.android.tools.r8.graph.GraphLens.NestedGraphLens;
+import com.android.tools.r8.graph.RewrittenPrototypeDescription;
 import com.android.tools.r8.ir.conversion.ExtraParameter;
 import com.android.tools.r8.utils.IterableUtils;
 import com.android.tools.r8.utils.collections.BidirectionalOneToOneHashMap;
@@ -54,12 +55,25 @@
     this.mergedClasses = mergedClasses;
   }
 
+  private boolean isSynthesizedByHorizontalClassMerging(DexMethod method) {
+    return methodExtraParameters.containsKey(method);
+  }
+
   @Override
   protected Iterable<DexType> internalGetOriginalTypes(DexType previous) {
     return IterableUtils.prependSingleton(previous, mergedClasses.getSourcesFor(previous));
   }
 
   @Override
+  public RewrittenPrototypeDescription lookupPrototypeChangesForMethodDefinition(DexMethod method) {
+    if (isSynthesizedByHorizontalClassMerging(method)) {
+      // If we are processing the call site, the arguments should be removed.
+      return RewrittenPrototypeDescription.none();
+    }
+    return super.lookupPrototypeChangesForMethodDefinition(method);
+  }
+
+  @Override
   public DexMethod getOriginalMethodSignature(DexMethod method) {
     DexMethod originalConstructor = extraOriginalMethodSignatures.get(method);
     if (originalConstructor == null) {
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoUnusedArguments.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoUnusedArguments.java
deleted file mode 100644
index 907c855..0000000
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoUnusedArguments.java
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
-// for details. All rights reserved. Use of this source code is governed by a
-// BSD-style license that can be found in the LICENSE file.
-
-package com.android.tools.r8.horizontalclassmerging.policies;
-
-import com.android.tools.r8.graph.AppView;
-import com.android.tools.r8.graph.DexEncodedMethod;
-import com.android.tools.r8.graph.DexProgramClass;
-import com.android.tools.r8.graph.RewrittenPrototypeDescription;
-import com.android.tools.r8.horizontalclassmerging.SingleClassPolicy;
-import com.android.tools.r8.shaking.AppInfoWithLiveness;
-
-public class NoUnusedArguments extends SingleClassPolicy {
-
-  private final AppView<AppInfoWithLiveness> appView;
-
-  public NoUnusedArguments(AppView<AppInfoWithLiveness> appView) {
-    this.appView = appView;
-  }
-
-  @Override
-  public boolean canMerge(DexProgramClass program) {
-    for (DexEncodedMethod method : program.methods()) {
-      RewrittenPrototypeDescription prototypeChanges =
-          appView.graphLens().lookupPrototypeChangesForMethodDefinition(method.getReference());
-      if (prototypeChanges.getArgumentInfoCollection().numberOfRemovedArguments() > 0) {
-        return false;
-      }
-    }
-    return true;
-  }
-}
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
index 20cc2c6..cf605e2 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
@@ -223,9 +223,11 @@
                   graphLens.lookupMethod(invokedMethod, method.getReference(), invoke.getType());
               DexMethod actualTarget = lensLookup.getReference();
               Invoke.Type actualInvokeType = lensLookup.getType();
-              if (actualTarget != invokedMethod || invoke.getType() != actualInvokeType) {
-                RewrittenPrototypeDescription prototypeChanges = lensLookup.getPrototypeChanges();
 
+              RewrittenPrototypeDescription prototypeChanges = lensLookup.getPrototypeChanges();
+              if (prototypeChanges.requiresRewritingAtCallSite()
+                  || invoke.getType() != actualInvokeType
+                  || actualTarget != invokedMethod) {
                 List<Value> newInValues;
                 ArgumentInfoCollection argumentInfoCollection =
                     prototypeChanges.getArgumentInfoCollection();
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/ConstructorMergingAfterUnusedArgumentRemovalTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/ConstructorMergingAfterUnusedArgumentRemovalTest.java
new file mode 100644
index 0000000..32f7454
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/ConstructorMergingAfterUnusedArgumentRemovalTest.java
@@ -0,0 +1,78 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.classmerging.horizontal;
+
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.TestParameters;
+import org.junit.Test;
+
+public class ConstructorMergingAfterUnusedArgumentRemovalTest
+    extends HorizontalClassMergingTestBase {
+
+  public ConstructorMergingAfterUnusedArgumentRemovalTest(
+      TestParameters parameters, boolean enableHorizontalClassMerging) {
+    super(parameters, enableHorizontalClassMerging);
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .addOptionsModification(
+            options -> options.enableHorizontalClassMerging = enableHorizontalClassMerging)
+        .enableNeverClassInliningAnnotations()
+        .setMinApi(parameters.getApiLevel())
+        .addHorizontallyMergedClassesInspectorIf(
+            enableHorizontalClassMerging,
+            inspector ->
+                inspector
+                    .assertMergedInto(B.class, A.class)
+                    .assertMergedInto(C.class, A.class)
+                    .assertMergedInto(D.class, A.class))
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines(
+            "A.<init>(?, 43)", "B.<init>(44)", "C.<init>()", "D.<init>()");
+  }
+
+  @NeverClassInline
+  public static class A {
+
+    public A(int unused, int x) {
+      System.out.println("A.<init>(?, " + x + ")");
+    }
+  }
+
+  @NeverClassInline
+  public static class B {
+
+    public B(int x) {
+      System.out.println("B.<init>(" + x + ")");
+    }
+  }
+
+  @NeverClassInline
+  public static class C {
+    public C() {
+      System.out.println("C.<init>()");
+    }
+  }
+
+  @NeverClassInline
+  public static class D {
+    public D() {
+      System.out.println("D.<init>()");
+    }
+  }
+
+  public static class Main {
+    public static void main(String[] args) {
+      new A(42, 43);
+      new B(44);
+      new C();
+      new D();
+    }
+  }
+}