Refactor class id field into merge group

Change-Id: I28f5a132a3af9c7030c7cfb251c45ea29ab74fbf
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/ClassMerger.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/ClassMerger.java
index 3c7b587..358ec53 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/ClassMerger.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/ClassMerger.java
@@ -11,7 +11,6 @@
 import com.android.tools.r8.graph.DexAnnotationSet;
 import com.android.tools.r8.graph.DexEncodedField;
 import com.android.tools.r8.graph.DexEncodedMethod;
-import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
@@ -44,6 +43,7 @@
  * ClassMerger#lensBuilder}.
  */
 public class ClassMerger {
+
   public static final String CLASS_ID_FIELD_NAME = "$r8$classId";
 
   private final MergeGroup group;
@@ -58,7 +58,6 @@
   private final ClassInstanceFieldsMerger classInstanceFieldsMerger;
   private final Collection<VirtualMethodMerger> virtualMethodMergers;
   private final Collection<ConstructorMerger> constructorMergers;
-  private final DexField classIdField;
 
   private ClassMerger(
       AppView<AppInfoWithLiveness> appView,
@@ -66,14 +65,12 @@
       HorizontallyMergedClasses.Builder mergedClassesBuilder,
       FieldAccessInfoCollectionModifier.Builder fieldAccessChangesBuilder,
       MergeGroup group,
-      DexField classIdField,
       Collection<VirtualMethodMerger> virtualMethodMergers,
       Collection<ConstructorMerger> constructorMergers) {
     this.lensBuilder = lensBuilder;
     this.mergedClassesBuilder = mergedClassesBuilder;
     this.fieldAccessChangesBuilder = fieldAccessChangesBuilder;
     this.group = group;
-    this.classIdField = classIdField;
     this.virtualMethodMergers = virtualMethodMergers;
     this.constructorMergers = constructorMergers;
 
@@ -159,7 +156,7 @@
   void appendClassIdField() {
     DexEncodedField encodedField =
         new DexEncodedField(
-            classIdField,
+            group.getClassIdField(),
             FieldAccessFlags.fromSharedAccessFlags(
                 Constants.ACC_PUBLIC + Constants.ACC_FINAL + Constants.ACC_SYNTHETIC),
             FieldTypeSignature.noSignature(),
@@ -228,11 +225,22 @@
     public Builder(AppView<AppInfoWithLiveness> appView, MergeGroup group) {
       this.appView = appView;
       this.group = group;
-      setupForMethodMerging(group.getTarget());
-      group.forEachSource(this::setupForMethodMerging);
     }
 
-    void setupForMethodMerging(DexProgramClass toMerge) {
+    private Builder setup() {
+      DexItemFactory dexItemFactory = appView.dexItemFactory();
+      DexProgramClass target = group.iterator().next();
+      // TODO(b/165498187): ensure the name for the field is fresh
+      group.setClassIdField(
+          dexItemFactory.createField(
+              target.getType(), dexItemFactory.intType, CLASS_ID_FIELD_NAME));
+      group.setTarget(target);
+      setupForMethodMerging(target);
+      group.forEachSource(this::setupForMethodMerging);
+      return this;
+    }
+
+    private void setupForMethodMerging(DexProgramClass toMerge) {
       toMerge.forEachProgramDirectMethod(
           method -> {
             DexEncodedMethod definition = method.getDefinition();
@@ -244,7 +252,7 @@
       toMerge.forEachProgramVirtualMethod(this::addVirtualMethod);
     }
 
-    void addConstructor(ProgramMethod method) {
+    private void addConstructor(ProgramMethod method) {
       assert method.getDefinition().isInstanceInitializer();
       constructorMergerBuilders
           .computeIfAbsent(
@@ -252,7 +260,7 @@
           .add(method.getDefinition());
     }
 
-    void addVirtualMethod(ProgramMethod method) {
+    private void addVirtualMethod(ProgramMethod method) {
       assert method.getDefinition().isNonPrivateVirtualMethod();
       virtualMethodMergerBuilders
           .computeIfAbsent(
@@ -265,16 +273,11 @@
         HorizontallyMergedClasses.Builder mergedClassesBuilder,
         HorizontalClassMergerGraphLens.Builder lensBuilder,
         FieldAccessInfoCollectionModifier.Builder fieldAccessChangesBuilder) {
-      DexItemFactory dexItemFactory = appView.dexItemFactory();
-      // TODO(b/165498187): ensure the name for the field is fresh
-      DexField classIdField =
-          dexItemFactory.createField(
-              group.getTarget().getType(), dexItemFactory.intType, CLASS_ID_FIELD_NAME);
-
+      setup();
       List<VirtualMethodMerger> virtualMethodMergers =
           new ArrayList<>(virtualMethodMergerBuilders.size());
       for (VirtualMethodMerger.Builder builder : virtualMethodMergerBuilders.values()) {
-        virtualMethodMergers.add(builder.build(appView, group, classIdField));
+        virtualMethodMergers.add(builder.build(appView, group));
       }
       // Try and merge the functions with the most arguments first, to avoid using synthetic
       // arguments if possible.
@@ -283,7 +286,7 @@
       List<ConstructorMerger> constructorMergers =
           new ArrayList<>(constructorMergerBuilders.size());
       for (ConstructorMerger.Builder builder : constructorMergerBuilders.values()) {
-        constructorMergers.addAll(builder.build(appView, group, classIdField));
+        constructorMergers.addAll(builder.build(appView, group));
       }
 
       // Try and merge the functions with the most arguments first, to avoid using synthetic
@@ -297,7 +300,6 @@
           mergedClassesBuilder,
           fieldAccessChangesBuilder,
           group,
-          classIdField,
           virtualMethodMergers,
           constructorMergers);
     }
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/ConstructorMerger.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/ConstructorMerger.java
index afc5805..ad4307d 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/ConstructorMerger.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/ConstructorMerger.java
@@ -9,7 +9,6 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexAnnotationSet;
 import com.android.tools.r8.graph.DexEncodedMethod;
-import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexType;
@@ -37,17 +36,12 @@
   private final MergeGroup group;
   private final Collection<DexEncodedMethod> constructors;
   private final DexItemFactory dexItemFactory;
-  private final DexField classIdField;
 
   ConstructorMerger(
-      AppView<?> appView,
-      MergeGroup group,
-      Collection<DexEncodedMethod> constructors,
-      DexField classIdField) {
+      AppView<?> appView, MergeGroup group, Collection<DexEncodedMethod> constructors) {
     this.appView = appView;
     this.group = group;
     this.constructors = constructors;
-    this.classIdField = classIdField;
 
     // Constructors should not be empty and all constructors should have the same prototype.
     assert !constructors.isEmpty();
@@ -102,12 +96,10 @@
       return this;
     }
 
-    public List<ConstructorMerger> build(
-        AppView<?> appView, MergeGroup group, DexField classIdField) {
+    public List<ConstructorMerger> build(AppView<?> appView, MergeGroup group) {
       assert constructorGroups.stream().noneMatch(List::isEmpty);
       return ListUtils.map(
-          constructorGroups,
-          constructors -> new ConstructorMerger(appView, group, constructors, classIdField));
+          constructorGroups, constructors -> new ConstructorMerger(appView, group, constructors));
     }
   }
 
@@ -178,7 +170,7 @@
         new ConstructorEntryPointSynthesizedCode(
             typeConstructorClassMap,
             newConstructorReference,
-            classIdField,
+            group.getClassIdField(),
             appView.graphLens().getOriginalMethodSignature(representativeConstructorReference));
     DexEncodedMethod newConstructor =
         new DexEncodedMethod(
@@ -222,6 +214,7 @@
 
     classMethodsBuilder.addDirectMethod(newConstructor);
 
-    fieldAccessChangesBuilder.fieldWrittenByMethod(classIdField, newConstructorReference);
+    fieldAccessChangesBuilder.fieldWrittenByMethod(
+        group.getClassIdField(), newConstructorReference);
   }
 }
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
index 4720e94..0821a4c 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
@@ -157,7 +157,6 @@
     // TODO(b/166577694): Replace Collection<DexProgramClass> with MergeGroup
     for (MergeGroup group : groups) {
       assert !group.isEmpty();
-      group.setTarget(group.iterator().next());
       ClassMerger merger =
           new ClassMerger.Builder(appView, group)
               .build(mergedClassesBuilder, lensBuilder, fieldAccessChangesBuilder);
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/MergeGroup.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/MergeGroup.java
index d467d8a..02e4684 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/MergeGroup.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/MergeGroup.java
@@ -6,6 +6,7 @@
 
 package com.android.tools.r8.horizontalclassmerging;
 
+import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.utils.IteratorUtils;
 import com.google.common.collect.Iterables;
@@ -19,8 +20,10 @@
 
   public static class Metadata {}
 
-  private DexProgramClass target = null;
   private final LinkedList<DexProgramClass> classes;
+
+  private DexField classIdField;
+  private DexProgramClass target = null;
   private Metadata metadata = null;
 
   public MergeGroup() {
@@ -63,6 +66,19 @@
     return classes;
   }
 
+  public boolean hasClassIdField() {
+    return classIdField != null;
+  }
+
+  public DexField getClassIdField() {
+    assert hasClassIdField();
+    return classIdField;
+  }
+
+  public void setClassIdField(DexField classIdField) {
+    this.classIdField = classIdField;
+  }
+
   public Iterable<DexProgramClass> getSources() {
     assert hasTarget();
     return Iterables.filter(classes, clazz -> clazz != target);
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/VirtualMethodMerger.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/VirtualMethodMerger.java
index fb77e22..fe3a898 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/VirtualMethodMerger.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/VirtualMethodMerger.java
@@ -8,7 +8,6 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexAnnotationSet;
 import com.android.tools.r8.graph.DexEncodedMethod;
-import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
@@ -35,7 +34,6 @@
   private final DexItemFactory dexItemFactory;
   private final MergeGroup group;
   private final List<ProgramMethod> methods;
-  private final DexField classIdField;
   private final AppView<AppInfoWithLiveness> appView;
   private final DexMethod superMethod;
 
@@ -43,11 +41,9 @@
       AppView<AppInfoWithLiveness> appView,
       MergeGroup group,
       List<ProgramMethod> methods,
-      DexField classIdField,
       DexMethod superMethod) {
     this.dexItemFactory = appView.dexItemFactory();
     this.group = group;
-    this.classIdField = classIdField;
     this.methods = methods;
     this.appView = appView;
     this.superMethod = superMethod;
@@ -84,12 +80,11 @@
       return resolutionResult.getResolvedMethod().getReference();
     }
 
-    public VirtualMethodMerger build(
-        AppView<AppInfoWithLiveness> appView, MergeGroup group, DexField classIdField) {
+    public VirtualMethodMerger build(AppView<AppInfoWithLiveness> appView, MergeGroup group) {
       // If not all the classes are in the merge group, find the fallback super method to call.
       DexMethod superMethod =
           methods.size() < group.size() ? superMethod(appView, group.getTarget()) : null;
-      return new VirtualMethodMerger(appView, group, methods, classIdField, superMethod);
+      return new VirtualMethodMerger(appView, group, methods, superMethod);
     }
   }
 
@@ -259,7 +254,7 @@
     AbstractSynthesizedCode synthesizedCode =
         new VirtualMethodEntryPointSynthesizedCode(
             classIdToMethodMap,
-            classIdField,
+            group.getClassIdField(),
             superMethod,
             newMethodReference,
             bridgeMethodReference);
@@ -289,6 +284,6 @@
 
     classMethodsBuilder.addVirtualMethod(newMethod);
 
-    fieldAccessChangesBuilder.fieldReadByMethod(classIdField, newMethod.method);
+    fieldAccessChangesBuilder.fieldReadByMethod(group.getClassIdField(), newMethod.method);
   }
 }