Refactor instance fields merger for use in constructor merging

Change-Id: I4f9f9077c1c944aea88ccf2285e8c6f074854bc0
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/ClassInstanceFieldsMerger.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/ClassInstanceFieldsMerger.java
index a429d5c..e0eaf2c 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/ClassInstanceFieldsMerger.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/ClassInstanceFieldsMerger.java
@@ -8,10 +8,12 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexEncodedField;
 import com.android.tools.r8.graph.DexProgramClass;
-import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.ProgramField;
 import com.android.tools.r8.horizontalclassmerging.HorizontalClassMergerGraphLens.Builder;
 import com.android.tools.r8.horizontalclassmerging.policies.SameInstanceFields.InstanceFieldInfo;
 import com.android.tools.r8.utils.IterableUtils;
+import com.android.tools.r8.utils.collections.BidirectionalManyToOneHashMap;
+import com.android.tools.r8.utils.collections.MutableBidirectionalManyToOneMap;
 import com.google.common.collect.Iterables;
 import java.util.ArrayList;
 import java.util.Collection;
@@ -19,27 +21,28 @@
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 public class ClassInstanceFieldsMerger {
 
   private final AppView<? extends AppInfoWithClassHierarchy> appView;
+  private final MergeGroup group;
   private final Builder lensBuilder;
 
   private DexEncodedField classIdField;
 
   // Map from target class field to all fields which should be merged into that field.
-  private final Map<DexEncodedField, List<DexEncodedField>> fieldMappings = new LinkedHashMap<>();
+  private final MutableBidirectionalManyToOneMap<DexEncodedField, DexEncodedField> fieldMappings =
+      BidirectionalManyToOneHashMap.newLinkedHashMap();
 
   public ClassInstanceFieldsMerger(
       AppView<? extends AppInfoWithClassHierarchy> appView,
       HorizontalClassMergerGraphLens.Builder lensBuilder,
       MergeGroup group) {
     this.appView = appView;
+    this.group = group;
     this.lensBuilder = lensBuilder;
-    group
-        .getTarget()
-        .instanceFields()
-        .forEach(field -> fieldMappings.computeIfAbsent(field, ignore -> new ArrayList<>()));
+    group.forEachSource(this::addFields);
   }
 
   /**
@@ -52,7 +55,7 @@
    * Bar has fields 'A b' and 'B a'), we make a prepass that matches fields with the same reference
    * type.
    */
-  public void addFields(DexProgramClass clazz) {
+  private void addFields(DexProgramClass clazz) {
     Map<InstanceFieldInfo, LinkedList<DexEncodedField>> availableFieldsByExactInfo =
         getAvailableFieldsByExactInfo();
     List<DexEncodedField> needsMerge = new ArrayList<>();
@@ -66,7 +69,7 @@
         needsMerge.add(oldField);
       } else {
         DexEncodedField newField = availableFieldsWithExactSameInfo.removeFirst();
-        fieldMappings.get(newField).add(oldField);
+        fieldMappings.put(oldField, newField);
         if (availableFieldsWithExactSameInfo.isEmpty()) {
           availableFieldsByExactInfo.remove(info);
         }
@@ -84,14 +87,14 @@
               .removeFirst();
       assert newField != null;
       assert newField.getType().isReferenceType();
-      fieldMappings.get(newField).add(oldField);
+      fieldMappings.put(oldField, newField);
     }
   }
 
   private Map<InstanceFieldInfo, LinkedList<DexEncodedField>> getAvailableFieldsByExactInfo() {
     Map<InstanceFieldInfo, LinkedList<DexEncodedField>> availableFieldsByInfo =
         new LinkedHashMap<>();
-    for (DexEncodedField field : fieldMappings.keySet()) {
+    for (DexEncodedField field : group.getTarget().instanceFields()) {
       availableFieldsByInfo
           .computeIfAbsent(InstanceFieldInfo.createExact(field), ignore -> new LinkedList<>())
           .add(field);
@@ -122,6 +125,14 @@
     }
   }
 
+  public ProgramField getTargetField(ProgramField field) {
+    if (field.getHolder() == group.getTarget()) {
+      return field;
+    }
+    DexEncodedField targetField = fieldMappings.get(field.getDefinition());
+    return new ProgramField(group.getTarget(), targetField);
+  }
+
   public void setClassIdField(DexEncodedField classIdField) {
     this.classIdField = classIdField;
   }
@@ -131,19 +142,18 @@
     if (classIdField != null) {
       newFields.add(classIdField);
     }
-    fieldMappings.forEach(
-        (targetField, oldFields) ->
-            newFields.add(mergeSourceFieldsToTargetField(targetField, oldFields)));
+    fieldMappings.forEachManyToOneMapping(
+        (sourceFields, targetField) ->
+            newFields.add(mergeSourceFieldsToTargetField(targetField, sourceFields)));
     return newFields.toArray(DexEncodedField.EMPTY_ARRAY);
   }
 
   private DexEncodedField mergeSourceFieldsToTargetField(
-      DexEncodedField targetField, List<DexEncodedField> oldFields) {
-    fixAccessFlags(targetField, oldFields);
+      DexEncodedField targetField, Set<DexEncodedField> sourceFields) {
+    fixAccessFlags(targetField, sourceFields);
 
     DexEncodedField newField;
-    DexType targetFieldType = targetField.type();
-    if (!Iterables.all(oldFields, oldField -> oldField.getType() == targetFieldType)) {
+    if (needsRelaxedType(targetField, sourceFields)) {
       newField =
           targetField.toTypeSubstitutedField(
               targetField
@@ -155,10 +165,16 @@
 
     lensBuilder.recordNewFieldSignature(
         Iterables.transform(
-            IterableUtils.append(oldFields, targetField), DexEncodedField::getReference),
+            IterableUtils.append(sourceFields, targetField), DexEncodedField::getReference),
         newField.getReference(),
         targetField.getReference());
 
     return newField;
   }
+
+  private boolean needsRelaxedType(
+      DexEncodedField targetField, Iterable<DexEncodedField> sourceFields) {
+    return Iterables.any(
+        sourceFields, sourceField -> sourceField.getType() != targetField.getType());
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/ClassMerger.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/ClassMerger.java
index 89e9f74..999869c 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/ClassMerger.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/ClassMerger.java
@@ -295,11 +295,7 @@
   }
 
   void mergeInstanceFields() {
-    group.forEachSource(
-        clazz -> {
-          classInstanceFieldsMerger.addFields(clazz);
-          clazz.clearInstanceFields();
-        });
+    group.forEachSource(DexClass::clearInstanceFields);
     group.getTarget().setInstanceFields(classInstanceFieldsMerger.merge());
   }
 
diff --git a/src/main/java/com/android/tools/r8/utils/collections/BidirectionalManyToOneHashMap.java b/src/main/java/com/android/tools/r8/utils/collections/BidirectionalManyToOneHashMap.java
index 1e53a14..10f520f5 100644
--- a/src/main/java/com/android/tools/r8/utils/collections/BidirectionalManyToOneHashMap.java
+++ b/src/main/java/com/android/tools/r8/utils/collections/BidirectionalManyToOneHashMap.java
@@ -6,6 +6,7 @@
 
 import java.util.Collections;
 import java.util.IdentityHashMap;
+import java.util.LinkedHashMap;
 import java.util.LinkedHashSet;
 import java.util.Map;
 import java.util.Set;
@@ -21,6 +22,10 @@
     return new BidirectionalManyToOneHashMap<>(new IdentityHashMap<>(), new IdentityHashMap<>());
   }
 
+  public static <K, V> BidirectionalManyToOneHashMap<K, V> newLinkedHashMap() {
+    return new BidirectionalManyToOneHashMap<>(new LinkedHashMap<>(), new LinkedHashMap<>());
+  }
+
   protected BidirectionalManyToOneHashMap(Map<K, V> backing, Map<V, Set<K>> inverse) {
     this.backing = backing;
     this.inverse = inverse;