Derive the vertical-merging tree fixer from the shared base.

Bug: 166071494
Change-Id: Ic8348df0b63c6f9833a25b56dc3ebc1ec8989548
diff --git a/src/main/java/com/android/tools/r8/graph/TreeFixer.java b/src/main/java/com/android/tools/r8/graph/TreeFixer.java
index aa3597e..ca953d3 100644
--- a/src/main/java/com/android/tools/r8/graph/TreeFixer.java
+++ b/src/main/java/com/android/tools/r8/graph/TreeFixer.java
@@ -10,33 +10,59 @@
 import java.util.List;
 import java.util.Map;
 
-public class TreeFixer {
+public abstract class TreeFixer {
 
   private final AppView<?> appView;
   private final DexItemFactory dexItemFactory;
-  private final Map<DexType, DexType> repackagedClasses;
-  private final TreeFixerCallbacks callbacks;
 
-  private final Map<DexType, DexProgramClass> newProgramClasses = new IdentityHashMap<>();
+  private Map<DexType, DexProgramClass> newProgramClasses = null;
   private final Map<DexType, DexProgramClass> synthesizedFromClasses = new IdentityHashMap<>();
   private final Map<DexProto, DexProto> protoFixupCache = new IdentityHashMap<>();
 
-  public TreeFixer(
-      AppView<?> appView, Map<DexType, DexType> repackagedClasses, TreeFixerCallbacks callbacks) {
+  public TreeFixer(AppView<?> appView) {
     this.appView = appView;
     this.dexItemFactory = appView.dexItemFactory();
-    this.repackagedClasses = repackagedClasses;
-    this.callbacks = callbacks;
   }
 
-  public Collection<DexProgramClass> run() {
-    assert newProgramClasses.isEmpty();
-    for (DexProgramClass clazz : appView.appInfo().classesWithDeterministicOrder()) {
+  /** Mapping of a class type to a potentially new class type. */
+  public abstract DexType mapClassType(DexType type);
+
+  /** Callback invoked each time an encoded field changes field reference. */
+  public abstract void recordFieldChange(DexField from, DexField to);
+
+  /** Callback invoked each time an encoded method changes method reference. */
+  public abstract void recordMethodChange(DexMethod from, DexMethod to);
+
+  /** Callback invoked each time a program class definition changes type reference. */
+  public abstract void recordClassChange(DexType from, DexType to);
+
+  private DexProgramClass recordClassChange(DexProgramClass from, DexProgramClass to) {
+    recordClassChange(from.getType(), to.getType());
+    return to;
+  }
+
+  private DexEncodedField recordFieldChange(DexEncodedField from, DexEncodedField to) {
+    recordFieldChange(from.field, to.field);
+    return to;
+  }
+
+  /** Callback to allow custom handling when an encoded method changes. */
+  public DexEncodedMethod recordMethodChange(DexEncodedMethod from, DexEncodedMethod to) {
+    recordMethodChange(from.method, to.method);
+    return to;
+  }
+
+  /** Fixup a collection of classes. */
+  public Collection<DexProgramClass> fixupClasses(Collection<DexProgramClass> classes) {
+    assert newProgramClasses == null;
+    newProgramClasses = new IdentityHashMap<>();
+    for (DexProgramClass clazz : classes) {
       newProgramClasses.computeIfAbsent(clazz.getType(), ignore -> fixupClass(clazz));
     }
     return newProgramClasses.values();
   }
 
+  // Should remain private as the correctness of the fixup requires the lazy 'newProgramClasses'.
   private DexProgramClass fixupClass(DexProgramClass clazz) {
     DexProgramClass newClass =
         new DexProgramClass(
@@ -82,7 +108,7 @@
     }
     // If the class type changed, record the move in the lens.
     if (newClass.getType() != clazz.getType()) {
-      callbacks.recordMove(clazz.getType(), newClass.getType());
+      return recordClassChange(clazz, newClass);
     }
     return newClass;
   }
@@ -108,7 +134,8 @@
         : enclosingMethodAttribute;
   }
 
-  private DexEncodedField[] fixupFields(List<DexEncodedField> fields) {
+  /** Fixup a list of fields. */
+  public DexEncodedField[] fixupFields(List<DexEncodedField> fields) {
     if (fields == null) {
       return DexEncodedField.EMPTY_ARRAY;
     }
@@ -123,8 +150,7 @@
     DexField fieldReference = field.getReference();
     DexField newFieldReference = fixupFieldReference(fieldReference);
     if (newFieldReference != fieldReference) {
-      callbacks.recordMove(fieldReference, newFieldReference);
-      return field.toTypeSubstitutedField(newFieldReference);
+      return recordFieldChange(field, field.toTypeSubstitutedField(newFieldReference));
     }
     return field;
   }
@@ -169,21 +195,20 @@
     }
     return newMethods;
   }
-
-  private DexEncodedMethod fixupMethod(DexEncodedMethod method) {
+  /** Fixup a method definition. */
+  public DexEncodedMethod fixupMethod(DexEncodedMethod method) {
     DexMethod methodReference = method.getReference();
     DexMethod newMethodReference = fixupMethodReference(methodReference);
     if (newMethodReference != methodReference) {
-      callbacks.recordMove(methodReference, newMethodReference);
-      return method.toTypeSubstitutedMethod(newMethodReference);
+      return recordMethodChange(method, method.toTypeSubstitutedMethod(newMethodReference));
     }
     return method;
   }
 
-  private DexMethod fixupMethodReference(DexMethod method) {
-    return appView
-        .dexItemFactory()
-        .createMethod(fixupType(method.holder), fixupProto(method.proto), method.name);
+  /** Fixup a method reference. */
+  public DexMethod fixupMethodReference(DexMethod method) {
+    return dexItemFactory.createMethod(
+        fixupType(method.holder), fixupProto(method.proto), method.name);
   }
 
   private NestHostClassAttribute fixupNestHost(NestHostClassAttribute nestHostClassAttribute) {
@@ -220,8 +245,10 @@
     return result;
   }
 
+  // Should remain private as its correctness relies on the setup of 'newProgramClasses'.
   private Collection<DexProgramClass> fixupSynthesizedFrom(
       Collection<DexProgramClass> synthesizedFrom) {
+    assert newProgramClasses != null;
     if (synthesizedFrom.isEmpty()) {
       return synthesizedFrom;
     }
@@ -256,7 +283,7 @@
       return type.replaceBaseType(fixed, dexItemFactory);
     }
     if (type.isClassType()) {
-      return repackagedClasses.getOrDefault(type, type);
+      return mapClassType(type);
     }
     return type;
   }
diff --git a/src/main/java/com/android/tools/r8/graph/TreeFixerCallbacks.java b/src/main/java/com/android/tools/r8/graph/TreeFixerCallbacks.java
deleted file mode 100644
index bb7f3cf..0000000
--- a/src/main/java/com/android/tools/r8/graph/TreeFixerCallbacks.java
+++ /dev/null
@@ -1,13 +0,0 @@
-// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
-// for details. All rights reserved. Use of this source code is governed by a
-// BSD-style license that can be found in the LICENSE file.
-package com.android.tools.r8.graph;
-
-public interface TreeFixerCallbacks {
-
-  void recordMove(DexField from, DexField to);
-
-  void recordMove(DexMethod from, DexMethod to);
-
-  void recordMove(DexType from, DexType to);
-}
diff --git a/src/main/java/com/android/tools/r8/repackaging/Repackaging.java b/src/main/java/com/android/tools/r8/repackaging/Repackaging.java
index 9510772..c84f49d 100644
--- a/src/main/java/com/android/tools/r8/repackaging/Repackaging.java
+++ b/src/main/java/com/android/tools/r8/repackaging/Repackaging.java
@@ -21,7 +21,7 @@
 import com.android.tools.r8.graph.ProgramPackageCollection;
 import com.android.tools.r8.graph.SortedProgramPackageCollection;
 import com.android.tools.r8.graph.TreeFixer;
-import com.android.tools.r8.graph.TreeFixerCallbacks;
+import com.android.tools.r8.repackaging.RepackagingLens.Builder;
 import com.android.tools.r8.shaking.AnnotationFixer;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.shaking.ProguardConfiguration;
@@ -35,6 +35,7 @@
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.Iterator;
+import java.util.List;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
@@ -78,25 +79,28 @@
     // Running the tree fixer with an identity mapping helps ensure that the fixup of of items is
     // complete as the rewrite replaces all items regardless of repackaging.
     // The identity mapping should result in no move callbacks being called.
-    TreeFixerCallbacks callbacks =
-        new TreeFixerCallbacks() {
-          @Override
-          public void recordMove(DexField from, DexField to) {
-            assert false;
-          }
-
-          @Override
-          public void recordMove(DexMethod from, DexMethod to) {
-            assert false;
-          }
-
-          @Override
-          public void recordMove(DexType from, DexType to) {
-            assert false;
-          }
-        };
     Collection<DexProgramClass> newProgramClasses =
-        new TreeFixer(appView, Collections.emptyMap(), callbacks).run();
+        new TreeFixer(appView) {
+          @Override
+          public DexType mapClassType(DexType type) {
+            return type;
+          }
+
+          @Override
+          public void recordFieldChange(DexField from, DexField to) {
+            assert false;
+          }
+
+          @Override
+          public void recordMethodChange(DexMethod from, DexMethod to) {
+            assert false;
+          }
+
+          @Override
+          public void recordClassChange(DexType from, DexType to) {
+            assert false;
+          }
+        }.fixupClasses(appView.appInfo().classesWithDeterministicOrder());
     CommittedItems committedItems =
         appView
             .getSyntheticItems()
@@ -141,14 +145,49 @@
       return null;
     }
     RepackagingLens.Builder lensBuilder = new RepackagingLens.Builder();
-    Collection<DexProgramClass> newProgramClasses =
-        new TreeFixer(appView, mappings, lensBuilder).run();
-    appBuilder.replaceProgramClasses(new ArrayList<>(newProgramClasses));
+    List<DexProgramClass> newProgramClasses =
+        new ArrayList<>(
+            new RepackagingTreeFixer(appView, mappings, lensBuilder)
+                .fixupClasses(appView.appInfo().classesWithDeterministicOrder()));
+    appBuilder.replaceProgramClasses(newProgramClasses);
     RepackagingLens lens = lensBuilder.build(appView);
     new AnnotationFixer(lens).run(appBuilder.getProgramClasses());
     return lens;
   }
 
+  private static class RepackagingTreeFixer extends TreeFixer {
+
+    private final BiMap<DexType, DexType> mappings;
+    private final Builder lensBuilder;
+
+    public RepackagingTreeFixer(
+        AppView<?> appView, BiMap<DexType, DexType> mappings, Builder lensBuilder) {
+      super(appView);
+      this.mappings = mappings;
+      this.lensBuilder = lensBuilder;
+    }
+
+    @Override
+    public DexType mapClassType(DexType type) {
+      return mappings.getOrDefault(type, type);
+    }
+
+    @Override
+    public void recordFieldChange(DexField from, DexField to) {
+      lensBuilder.recordMove(from, to);
+    }
+
+    @Override
+    public void recordMethodChange(DexMethod from, DexMethod to) {
+      lensBuilder.recordMove(from, to);
+    }
+
+    @Override
+    public void recordClassChange(DexType from, DexType to) {
+      lensBuilder.recordMove(from, to);
+    }
+  }
+
   private void processPackagesInDesiredLocation(
       ProgramPackageCollection packages,
       BiMap<DexType, DexType> mappings,
diff --git a/src/main/java/com/android/tools/r8/repackaging/RepackagingLens.java b/src/main/java/com/android/tools/r8/repackaging/RepackagingLens.java
index 04ae4a3..0fe63b8 100644
--- a/src/main/java/com/android/tools/r8/repackaging/RepackagingLens.java
+++ b/src/main/java/com/android/tools/r8/repackaging/RepackagingLens.java
@@ -9,7 +9,6 @@
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.GraphLens.NestedGraphLens;
-import com.android.tools.r8.graph.TreeFixerCallbacks;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.collections.BidirectionalOneToOneHashMap;
 import com.android.tools.r8.utils.collections.BidirectionalOneToOneMap;
@@ -47,7 +46,7 @@
     return originalTypes.get(to) == from || super.isSimpleRenaming(from, to);
   }
 
-  public static class Builder implements TreeFixerCallbacks {
+  public static class Builder {
 
     protected final BiMap<DexType, DexType> originalTypes = HashBiMap.create();
     protected final MutableBidirectionalOneToOneMap<DexField, DexField> newFieldSignatures =
@@ -55,17 +54,14 @@
     protected final MutableBidirectionalOneToOneMap<DexMethod, DexMethod> originalMethodSignatures =
         new BidirectionalOneToOneHashMap<>();
 
-    @Override
     public void recordMove(DexField from, DexField to) {
       newFieldSignatures.put(from, to);
     }
 
-    @Override
     public void recordMove(DexMethod from, DexMethod to) {
       originalMethodSignatures.put(to, from);
     }
 
-    @Override
     public void recordMove(DexType from, DexType to) {
       originalTypes.put(to, from);
     }
diff --git a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
index 6f7d998..324668c 100644
--- a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
@@ -15,12 +15,10 @@
 import com.android.tools.r8.graph.DexAnnotationSet;
 import com.android.tools.r8.graph.DexApplication;
 import com.android.tools.r8.graph.DexClass;
-import com.android.tools.r8.graph.DexClass.FieldSetter;
 import com.android.tools.r8.graph.DexEncodedField;
 import com.android.tools.r8.graph.DexEncodedMember;
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexField;
-import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMember;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
@@ -42,6 +40,7 @@
 import com.android.tools.r8.graph.RewrittenPrototypeDescription;
 import com.android.tools.r8.graph.SubtypingInfo;
 import com.android.tools.r8.graph.TopDownClassHierarchyTraversal;
+import com.android.tools.r8.graph.TreeFixer;
 import com.android.tools.r8.graph.UseRegistry;
 import com.android.tools.r8.graph.classmerging.VerticallyMergedClasses;
 import com.android.tools.r8.ir.code.Invoke.Type;
@@ -654,7 +653,8 @@
 
     timing.begin("fixup");
     VerticalClassMergerGraphLens lens =
-        new TreeFixer(appView, lensBuilder, verticallyMergedClasses, synthesizedBridges)
+        new VerticalClassMergerTreeFixer(
+                appView, lensBuilder, verticallyMergedClasses, synthesizedBridges)
             .fixupTypeReferences();
     KeepInfoCollection keepInfo = appView.appInfo().getKeepInfo();
     keepInfo.mutate(mutator -> mutator.removeKeepInfoForPrunedItems(mergedClasses.keySet()));
@@ -1465,23 +1465,20 @@
     method.accessFlags.setPrivate();
   }
 
-  private static class TreeFixer {
+  private static class VerticalClassMergerTreeFixer extends TreeFixer {
 
     private final AppView<AppInfoWithLiveness> appView;
-    private final DexItemFactory dexItemFactory;
     private final VerticalClassMergerGraphLens.Builder lensBuilder;
     private final VerticallyMergedClasses mergedClasses;
     private final List<SynthesizedBridgeCode> synthesizedBridges;
 
-    private final Map<DexProto, DexProto> protoFixupCache = new IdentityHashMap<>();
-
-    TreeFixer(
+    VerticalClassMergerTreeFixer(
         AppView<AppInfoWithLiveness> appView,
         VerticalClassMergerGraphLens.Builder lensBuilder,
         VerticallyMergedClasses mergedClasses,
         List<SynthesizedBridgeCode> synthesizedBridges) {
+      super(appView);
       this.appView = appView;
-      this.dexItemFactory = appView.dexItemFactory();
       this.lensBuilder =
           VerticalClassMergerGraphLens.Builder.createBuilderForFixup(lensBuilder, mergedClasses);
       this.mergedClasses = mergedClasses;
@@ -1492,11 +1489,11 @@
       // Globally substitute merged class types in protos and holders.
       for (DexProgramClass clazz : appView.appInfo().classes()) {
         clazz.getMethodCollection().replaceMethods(this::fixupMethod);
-        fixupFields(clazz.staticFields(), clazz::setStaticField);
-        fixupFields(clazz.instanceFields(), clazz::setInstanceField);
+        clazz.setStaticFields(fixupFields(clazz.staticFields()));
+        clazz.setInstanceFields(fixupFields(clazz.instanceFields()));
       }
       for (SynthesizedBridgeCode synthesizedBridge : synthesizedBridges) {
-        synthesizedBridge.updateMethodSignatures(this::fixupMethod);
+        synthesizedBridge.updateMethodSignatures(this::fixupMethodReference);
       }
       VerticalClassMergerGraphLens lens = lensBuilder.build(appView, mergedClasses);
       if (lens != null) {
@@ -1505,85 +1502,45 @@
       return lens;
     }
 
-    private DexEncodedMethod fixupMethod(DexEncodedMethod method) {
-      DexMethod methodReference = method.method;
-      DexMethod newMethodReference = fixupMethod(methodReference);
-      if (newMethodReference != methodReference) {
-        if (!lensBuilder.hasOriginalSignatureMappingFor(newMethodReference)) {
-          lensBuilder
-              .map(methodReference, newMethodReference)
-              .recordMove(methodReference, newMethodReference);
-        }
-        DexEncodedMethod newMethod = method.toTypeSubstitutedMethod(newMethodReference);
-        if (newMethod.isNonPrivateVirtualMethod()) {
-          // Since we changed the return type or one of the parameters, this method cannot be a
-          // classpath or library method override, since we only class merge program classes.
-          assert !method.isLibraryMethodOverride().isTrue();
-          newMethod.setLibraryMethodOverride(OptionalBool.FALSE);
-        }
-        return newMethod;
-      }
-      return method;
-    }
-
-    private void fixupFields(List<DexEncodedField> fields, FieldSetter setter) {
-      if (fields == null) {
-        return;
-      }
-      for (int i = 0; i < fields.size(); i++) {
-        DexEncodedField encodedField = fields.get(i);
-        DexField field = encodedField.field;
-        DexType newType = fixupType(field.type);
-        DexType newHolder = fixupType(field.holder);
-        DexField newField = dexItemFactory.createField(newHolder, newType, field.name);
-        if (newField != encodedField.field) {
-          if (!lensBuilder.hasOriginalSignatureMappingFor(newField)) {
-            lensBuilder.map(field, newField);
-          }
-          setter.setField(i, encodedField.toTypeSubstitutedField(newField));
-        }
-      }
-    }
-
-    private DexMethod fixupMethod(DexMethod method) {
-      return dexItemFactory.createMethod(
-          fixupType(method.holder), fixupProto(method.proto), method.name);
-    }
-
-    private DexProto fixupProto(DexProto proto) {
-      DexProto result = protoFixupCache.get(proto);
-      if (result == null) {
-        DexType returnType = fixupType(proto.returnType);
-        DexType[] arguments = fixupTypes(proto.parameters.values);
-        result = dexItemFactory.createProto(returnType, arguments);
-        protoFixupCache.put(proto, result);
-      }
-      return result;
-    }
-
-    private DexType fixupType(DexType type) {
-      if (type.isArrayType()) {
-        DexType base = type.toBaseType(dexItemFactory);
-        DexType fixed = fixupType(base);
-        if (base == fixed) {
-          return type;
-        }
-        return type.replaceBaseType(fixed, dexItemFactory);
-      }
-      if (type.isClassType()) {
-        while (mergedClasses.hasBeenMergedIntoSubtype(type)) {
-          type = mergedClasses.getTargetFor(type);
-        }
+    @Override
+    public DexType mapClassType(DexType type) {
+      while (mergedClasses.hasBeenMergedIntoSubtype(type)) {
+        type = mergedClasses.getTargetFor(type);
       }
       return type;
     }
 
-    private DexType[] fixupTypes(DexType[] types) {
-      DexType[] result = new DexType[types.length];
-      for (int i = 0; i < result.length; i++) {
-        result[i] = fixupType(types[i]);
+    @Override
+    public void recordClassChange(DexType from, DexType to) {
+      // Fixup of classes is not used so no class type should change.
+      throw new Unreachable();
+    }
+
+    @Override
+    public void recordFieldChange(DexField from, DexField to) {
+      if (!lensBuilder.hasOriginalSignatureMappingFor(to)) {
+        lensBuilder.map(from, to);
       }
-      return result;
+    }
+
+    @Override
+    public void recordMethodChange(DexMethod from, DexMethod to) {
+      if (!lensBuilder.hasOriginalSignatureMappingFor(to)) {
+        lensBuilder.map(from, to).recordMove(from, to);
+      }
+    }
+
+    @Override
+    public DexEncodedMethod recordMethodChange(
+        DexEncodedMethod method, DexEncodedMethod newMethod) {
+      recordMethodChange(method.method, newMethod.method);
+      if (newMethod.isNonPrivateVirtualMethod()) {
+        // Since we changed the return type or one of the parameters, this method cannot be a
+        // classpath or library method override, since we only class merge program classes.
+        assert !method.isLibraryMethodOverride().isTrue();
+        newMethod.setLibraryMethodOverride(OptionalBool.FALSE);
+      }
+      return newMethod;
     }
   }