Fix use of incorrect method reference in horizontal class merger

Fixes: b/328582778
Change-Id: I272540f4c557c9ddd8315aa27ea58953d77008b8
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/ConstructorEntryPoint.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/ConstructorEntryPoint.java
index 49532c6..24efc9a 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/ConstructorEntryPoint.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/ConstructorEntryPoint.java
@@ -6,6 +6,7 @@
 
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.ir.code.ConstNumber;
 import com.android.tools.r8.ir.code.InvokeType;
 import com.android.tools.r8.ir.code.Position;
@@ -36,21 +37,23 @@
  * </code>
  */
 public class ConstructorEntryPoint extends SyntheticSourceCode {
+
   private final DexField classIdField;
   private final int extraNulls;
+  private final ProgramMethod method;
   private final Int2ReferenceSortedMap<DexMethod> typeConstructors;
 
   public ConstructorEntryPoint(
       Int2ReferenceSortedMap<DexMethod> typeConstructors,
-      DexMethod newConstructor,
+      ProgramMethod method,
       DexField classIdField,
       int extraNulls,
       Position position) {
-    super(newConstructor.holder, newConstructor, position);
-
-    this.typeConstructors = typeConstructors;
+    super(method, position);
     this.classIdField = classIdField;
     this.extraNulls = extraNulls;
+    this.method = method;
+    this.typeConstructors = typeConstructors;
   }
 
   private boolean hasClassIdField() {
@@ -95,25 +98,17 @@
   }
 
   /** Assign the given register to the class id field. */
-  void addRegisterClassIdAssignment(int idRegister) {
+  void addRegisterClassIdAssignment(int classIdRegister) {
     assert hasClassIdField();
-    add(builder -> builder.addInstancePut(idRegister, getReceiverRegister(), classIdField));
-  }
-
-  /** Assign the given constant integer value to the class id field. */
-  void addConstantRegisterClassIdAssignment(int classId) {
-    assert hasClassIdField();
-    int idRegister = nextRegister(ValueType.INT);
-    add(builder -> builder.addIntConst(idRegister, classId));
-    addRegisterClassIdAssignment(idRegister);
+    add(builder -> builder.addInstancePut(classIdRegister, getReceiverRegister(), classIdField));
   }
 
   protected void prepareMultiConstructorInstructions() {
     int typeConstructorCount = typeConstructors.size();
     // The class id register is always the first synthetic argument.
-    int idRegister = getParamRegister(method.getArity() - 1 - extraNulls);
+    int classIdRegister = getParamRegister(method.getArity() - 1 - extraNulls);
     if (hasClassIdField()) {
-      addRegisterClassIdAssignment(idRegister);
+      addRegisterClassIdAssignment(classIdRegister);
     }
 
     int[] keys = new int[typeConstructorCount - 1];
@@ -121,7 +116,7 @@
     IntBox fallthrough = new IntBox();
     int switchIndex = lastInstructionIndex();
     add(
-        builder -> builder.addSwitch(idRegister, keys, fallthrough.get(), offsets),
+        builder -> builder.addSwitch(classIdRegister, keys, fallthrough.get(), offsets),
         builder -> endsSwitch(builder, switchIndex, fallthrough.get(), offsets));
 
     int index = 0;
@@ -148,7 +143,10 @@
   protected void prepareSingleConstructorInstructions() {
     Entry<DexMethod> entry = typeConstructors.int2ReferenceEntrySet().first();
     if (hasClassIdField()) {
-      addConstantRegisterClassIdAssignment(entry.getIntKey());
+      int classIdRegister = nextRegister(ValueType.INT);
+      int classIdValue = entry.getIntKey();
+      add(builder -> builder.addIntConst(classIdRegister, classIdValue));
+      addRegisterClassIdAssignment(classIdRegister);
     }
     addConstructorInvoke(entry.getValue());
     add(IRBuilder::addReturn, endsBlock);
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/IncompleteHorizontalClassMergerCode.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/IncompleteHorizontalClassMergerCode.java
index 031fa97..9aada42 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/IncompleteHorizontalClassMergerCode.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/IncompleteHorizontalClassMergerCode.java
@@ -24,7 +24,7 @@
   public abstract void addExtraUnusedArguments(int numberOfUnusedArguments);
 
   @Override
-  public boolean isHorizontalClassMergerCode() {
+  public final boolean isHorizontalClassMergerCode() {
     return true;
   }
 
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/InstanceInitializerMerger.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/InstanceInitializerMerger.java
index 5f1815b..559f14f 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/InstanceInitializerMerger.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/InstanceInitializerMerger.java
@@ -252,7 +252,6 @@
   }
 
   private Code getNewCode(
-      DexMethod newMethodReference,
       boolean needsClassId,
       int extraNulls) {
     if (hasInstanceInitializerDescription()) {
@@ -261,7 +260,6 @@
     assert useSyntheticMethod();
     return new ConstructorEntryPointSynthesizedCode(
         createClassIdToInstanceInitializerMap(),
-        newMethodReference,
         group.hasClassIdField() ? group.getClassIdField() : null,
         extraNulls);
   }
@@ -349,7 +347,7 @@
           DexEncodedMethod.syntheticBuilder()
               .setMethod(newMethodReference)
               .setAccessFlags(getNewAccessFlags())
-              .setCode(getNewCode(newMethodReference, needsClassId, extraUnusedParameters.size()))
+              .setCode(getNewCode(needsClassId, extraUnusedParameters.size()))
               .setClassFileVersion(getNewClassFileVersion())
               .setApiLevelForDefinition(representativeMethod.getApiLevelForDefinition())
               .setApiLevelForCode(representativeMethod.getApiLevelForCode())
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/code/ConstructorEntryPointSynthesizedCode.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/code/ConstructorEntryPointSynthesizedCode.java
index 706a4fe..b0cfa19 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/code/ConstructorEntryPointSynthesizedCode.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/code/ConstructorEntryPointSynthesizedCode.java
@@ -36,18 +36,15 @@
 
 public class ConstructorEntryPointSynthesizedCode extends IncompleteHorizontalClassMergerCode {
 
-  private final DexMethod newConstructor;
   private final DexField classIdField;
   private int extraNulls;
   private final Int2ReferenceSortedMap<DexMethod> typeConstructors;
 
   public ConstructorEntryPointSynthesizedCode(
       Int2ReferenceSortedMap<DexMethod> typeConstructors,
-      DexMethod newConstructor,
       DexField classIdField,
       int extraNulls) {
     this.typeConstructors = typeConstructors;
-    this.newConstructor = newConstructor;
     this.classIdField = classIdField;
     this.extraNulls = extraNulls;
   }
@@ -81,11 +78,6 @@
   }
 
   @Override
-  public boolean isHorizontalClassMergerCode() {
-    return true;
-  }
-
-  @Override
   public LirCode<Integer> toLirCode(
       AppView<? extends AppInfoWithClassHierarchy> appView,
       ProgramMethod method,
@@ -128,8 +120,7 @@
             .setIsD8R8Synthesized(true)
             .build();
     SourceCode sourceCode =
-        new ConstructorEntryPoint(
-            typeConstructors, newConstructor, classIdField, extraNulls, position);
+        new ConstructorEntryPoint(typeConstructors, method, classIdField, extraNulls, position);
     return IRBuilder.create(method, appView, sourceCode).build(method, conversionOptions);
   }
 
@@ -144,7 +135,7 @@
       RewrittenPrototypeDescription protoChanges) {
     SourceCode sourceCode =
         new ConstructorEntryPoint(
-            typeConstructors, newConstructor, classIdField, extraNulls, callerPosition);
+            typeConstructors, method, classIdField, extraNulls, callerPosition);
     return IRBuilder.createForInlining(
             method, appView, codeLens, sourceCode, valueNumberGenerator, protoChanges)
         .build(context, MethodConversionOptions.nonConverting());
diff --git a/src/main/java/com/android/tools/r8/ir/synthetic/SyntheticSourceCode.java b/src/main/java/com/android/tools/r8/ir/synthetic/SyntheticSourceCode.java
index 4e3bf9c..de4b6c8 100644
--- a/src/main/java/com/android/tools/r8/ir/synthetic/SyntheticSourceCode.java
+++ b/src/main/java/com/android/tools/r8/ir/synthetic/SyntheticSourceCode.java
@@ -6,29 +6,25 @@
 
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.DebugLocalInfo;
-import com.android.tools.r8.graph.DexMethod;
-import com.android.tools.r8.graph.DexProto;
-import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.ir.code.CatchHandlers;
 import com.android.tools.r8.ir.code.Position;
 import com.android.tools.r8.ir.code.ValueType;
 import com.android.tools.r8.ir.conversion.DexSourceCode;
 import com.android.tools.r8.ir.conversion.IRBuilder;
 import com.android.tools.r8.ir.conversion.SourceCode;
+import com.android.tools.r8.utils.ArrayUtils;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.function.Consumer;
 import java.util.function.Predicate;
 
+@Deprecated
 public abstract class SyntheticSourceCode implements SourceCode {
+
   protected final static Predicate<IRBuilder> doesNotEndBlock = x -> false;
   protected final static Predicate<IRBuilder> endsBlock = x -> true;
 
-  // TODO(b/146124603): Remove these fields as optimizations (e.g., merging) could invalidate them.
-  protected final DexType receiver;
-  protected final DexMethod method;
-  protected final DexProto proto;
-
   // The next free register, note that we always
   // assign each value a new (next available) register.
   private int nextRegister = 0;
@@ -43,22 +39,15 @@
 
   private final Position position;
 
-  protected SyntheticSourceCode(DexType receiver, DexMethod method, Position position) {
-    assert method != null;
-    this.receiver = receiver;
-    this.method = method;
-    this.proto = method.proto;
+  protected SyntheticSourceCode(ProgramMethod method, Position position) {
     this.position = position;
 
     // Initialize register values for receiver and arguments
-    this.receiverRegister = receiver != null ? nextRegister(ValueType.OBJECT) : -1;
-
-    DexType[] params = proto.parameters.values;
-    int paramCount = params.length;
-    this.paramRegisters = new int[paramCount];
-    for (int i = 0; i < paramCount; i++) {
-      this.paramRegisters[i] = nextRegister(ValueType.fromDexType(params[i]));
-    }
+    this.receiverRegister = nextRegister(ValueType.OBJECT);
+    this.paramRegisters =
+        ArrayUtils.initialize(
+            new int[method.getParameters().size()],
+            i -> nextRegister(ValueType.fromDexType(method.getParameter(i))));
   }
 
   protected final void add(Consumer<IRBuilder> constructor) {
@@ -77,7 +66,6 @@
   }
 
   protected final int getReceiverRegister() {
-    assert receiver != null;
     assert receiverRegister >= 0;
     return receiverRegister;
   }
diff --git a/src/main/java/com/android/tools/r8/utils/ArrayUtils.java b/src/main/java/com/android/tools/r8/utils/ArrayUtils.java
index 0e57243..8dbaecf 100644
--- a/src/main/java/com/android/tools/r8/utils/ArrayUtils.java
+++ b/src/main/java/com/android/tools/r8/utils/ArrayUtils.java
@@ -14,6 +14,7 @@
 import java.util.function.Function;
 import java.util.function.IntFunction;
 import java.util.function.IntPredicate;
+import java.util.function.IntUnaryOperator;
 import java.util.function.Predicate;
 
 public class ArrayUtils {
@@ -70,6 +71,13 @@
     return array;
   }
 
+  public static int[] initialize(int[] array, IntUnaryOperator fn) {
+    for (int i = 0; i < array.length; i++) {
+      array[i] = fn.applyAsInt(i);
+    }
+    return array;
+  }
+
   public static <T> T[] initialize(T[] array, IntFunction<T> fn) {
     for (int i = 0; i < array.length; i++) {
       array[i] = fn.apply(i);