Fix nondeterminism from horizontal class merging

Change-Id: I4e67278a6dd08d9ccd0a45b6a88a452d1dccd8f7
diff --git a/src/main/java/com/android/tools/r8/graph/AppInfo.java b/src/main/java/com/android/tools/r8/graph/AppInfo.java
index 114bc8a..3c5f4b7 100644
--- a/src/main/java/com/android/tools/r8/graph/AppInfo.java
+++ b/src/main/java/com/android/tools/r8/graph/AppInfo.java
@@ -13,6 +13,7 @@
 import com.android.tools.r8.utils.BooleanBox;
 import com.android.tools.r8.utils.InternalOptions;
 import java.util.Collection;
+import java.util.List;
 import java.util.function.Consumer;
 
 public class AppInfo implements DexDefinitionSupplier {
@@ -120,7 +121,7 @@
     return app.classes();
   }
 
-  public Iterable<DexProgramClass> classesWithDeterministicOrder() {
+  public List<DexProgramClass> classesWithDeterministicOrder() {
     assert checkIfObsolete();
     return app.classesWithDeterministicOrder();
   }
diff --git a/src/main/java/com/android/tools/r8/graph/DexApplication.java b/src/main/java/com/android/tools/r8/graph/DexApplication.java
index 516edc0..abb91ed 100644
--- a/src/main/java/com/android/tools/r8/graph/DexApplication.java
+++ b/src/main/java/com/android/tools/r8/graph/DexApplication.java
@@ -103,7 +103,7 @@
     return box.getClasses();
   }
 
-  public Iterable<DexProgramClass> classesWithDeterministicOrder() {
+  public List<DexProgramClass> classesWithDeterministicOrder() {
     List<DexProgramClass> classes = new ArrayList<>(programClasses());
     // We never actually sort by anything but the DexType, this is just here in case we ever change
     // that.
diff --git a/src/main/java/com/android/tools/r8/graph/MethodArrayBacking.java b/src/main/java/com/android/tools/r8/graph/MethodArrayBacking.java
index 949fd47..bdcc10b 100644
--- a/src/main/java/com/android/tools/r8/graph/MethodArrayBacking.java
+++ b/src/main/java/com/android/tools/r8/graph/MethodArrayBacking.java
@@ -95,6 +95,11 @@
   }
 
   @Override
+  void clearDirectMethods() {
+    directMethods = DexEncodedMethod.EMPTY_ARRAY;
+  }
+
+  @Override
   DexEncodedMethod removeMethod(DexMethod method) {
     DexEncodedMethod removedDirectMethod =
         removeMethodHelper(
@@ -175,6 +180,11 @@
   }
 
   @Override
+  void clearVirtualMethods() {
+    virtualMethods = DexEncodedMethod.EMPTY_ARRAY;
+  }
+
+  @Override
   void setVirtualMethods(DexEncodedMethod[] methods) {
     virtualMethods = MoreObjects.firstNonNull(methods, DexEncodedMethod.EMPTY_ARRAY);
     assert verifyNoDuplicateMethods();
diff --git a/src/main/java/com/android/tools/r8/graph/MethodCollection.java b/src/main/java/com/android/tools/r8/graph/MethodCollection.java
index 312cec2..ef58e31 100644
--- a/src/main/java/com/android/tools/r8/graph/MethodCollection.java
+++ b/src/main/java/com/android/tools/r8/graph/MethodCollection.java
@@ -252,6 +252,11 @@
     backing.addDirectMethods(methods);
   }
 
+  public void clearDirectMethods() {
+    resetDirectMethodCaches();
+    backing.clearDirectMethods();
+  }
+
   public DexEncodedMethod removeMethod(DexMethod method) {
     DexEncodedMethod removed = backing.removeMethod(method);
     if (removed != null) {
@@ -283,6 +288,11 @@
     backing.addVirtualMethods(methods);
   }
 
+  public void clearVirtualMethods() {
+    resetVirtualMethodCaches();
+    backing.clearVirtualMethods();
+  }
+
   public void setVirtualMethods(DexEncodedMethod[] methods) {
     assert verifyCorrectnessOfMethodHolders(methods);
     resetVirtualMethodCaches();
diff --git a/src/main/java/com/android/tools/r8/graph/MethodCollectionBacking.java b/src/main/java/com/android/tools/r8/graph/MethodCollectionBacking.java
index 601306c..006e992 100644
--- a/src/main/java/com/android/tools/r8/graph/MethodCollectionBacking.java
+++ b/src/main/java/com/android/tools/r8/graph/MethodCollectionBacking.java
@@ -94,6 +94,10 @@
 
   // Removal methods.
 
+  abstract void clearDirectMethods();
+
+  abstract void clearVirtualMethods();
+
   abstract DexEncodedMethod removeMethod(DexMethod method);
 
   abstract void removeMethods(Set<DexEncodedMethod> method);
diff --git a/src/main/java/com/android/tools/r8/graph/MethodMapBacking.java b/src/main/java/com/android/tools/r8/graph/MethodMapBacking.java
index d877896..4a1fe5f 100644
--- a/src/main/java/com/android/tools/r8/graph/MethodMapBacking.java
+++ b/src/main/java/com/android/tools/r8/graph/MethodMapBacking.java
@@ -195,6 +195,16 @@
   }
 
   @Override
+  void clearDirectMethods() {
+    methodMap.values().removeIf(this::belongsToDirectPool);
+  }
+
+  @Override
+  void clearVirtualMethods() {
+    methodMap.values().removeIf(this::belongsToVirtualPool);
+  }
+
+  @Override
   DexEncodedMethod removeMethod(DexMethod method) {
     return methodMap.remove(wrap(method));
   }
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/SubtypingForrestForClasses.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/SubtypingForrestForClasses.java
index 402ddb1..b643225 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/SubtypingForrestForClasses.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/SubtypingForrestForClasses.java
@@ -11,6 +11,7 @@
 import java.util.Collection;
 import java.util.Collections;
 import java.util.IdentityHashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.function.BiFunction;
 
@@ -32,13 +33,12 @@
   private final AppView<AppInfoWithLiveness> appView;
 
   private final Collection<DexProgramClass> roots = new ArrayList<>();
-  private final Map<DexProgramClass, Collection<DexProgramClass>> subtypeMap =
-      new IdentityHashMap<>();
+  private final Map<DexProgramClass, List<DexProgramClass>> subtypeMap = new IdentityHashMap<>();
 
-  public SubtypingForrestForClasses(AppView<AppInfoWithLiveness> appView) {
+  public SubtypingForrestForClasses(
+      AppView<AppInfoWithLiveness> appView, List<DexProgramClass> classesWithDeterministicOrder) {
     this.appView = appView;
-
-    calculateSubtyping(appView.appInfo().classes());
+    calculateSubtyping(classesWithDeterministicOrder);
   }
 
   private DexProgramClass superClass(DexProgramClass clazz) {
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/TreeFixer.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/TreeFixer.java
index b102b69..961a7ac 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/TreeFixer.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/TreeFixer.java
@@ -26,6 +26,7 @@
 import com.google.common.collect.HashBiMap;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Sets;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.IdentityHashMap;
@@ -120,11 +121,11 @@
    * </ul>
    */
   public HorizontalClassMergerGraphLens fixupTypeReferences() {
-    Iterable<DexProgramClass> classes = appView.appInfo().classesWithDeterministicOrder();
+    List<DexProgramClass> classes = appView.appInfo().classesWithDeterministicOrder();
     Iterables.filter(classes, DexProgramClass::isInterface).forEach(this::fixupInterfaceClass);
 
     classes.forEach(this::fixupProgramClassSuperType);
-    SubtypingForrestForClasses subtypingForrest = new SubtypingForrestForClasses(appView);
+    SubtypingForrestForClasses subtypingForrest = new SubtypingForrestForClasses(appView, classes);
     // TODO(b/170078037): parallelize this code segment.
     for (DexProgramClass root : subtypingForrest.getProgramRoots()) {
       subtypingForrest.traverseNodeDepthFirst(
@@ -153,18 +154,22 @@
     Map<Wrapper<DexMethod>, DexString> remappedClassVirtualMethods =
         new HashMap<>(remappedVirtualMethods);
 
-    Set<DexMethod> newDirectMethodReferences = new LinkedHashSet<>();
-    Set<DexMethod> newVirtualMethodReferences = new LinkedHashSet<>();
+    Set<DexMethod> newVirtualMethodReferences = Sets.newIdentityHashSet();
+    List<DexEncodedMethod> newVirtualMethods = new ArrayList<>();
+    for (DexEncodedMethod method : clazz.virtualMethods()) {
+      newVirtualMethods.add(
+          fixupVirtualMethod(remappedClassVirtualMethods, newVirtualMethodReferences, method));
+    }
+    clazz.getMethodCollection().clearVirtualMethods();
+    clazz.getMethodCollection().addVirtualMethods(newVirtualMethods);
 
-    clazz
-        .getMethodCollection()
-        .replaceVirtualMethods(
-            method ->
-                fixupVirtualMethod(
-                    remappedClassVirtualMethods, newVirtualMethodReferences, method));
-    clazz
-        .getMethodCollection()
-        .replaceDirectMethods(method -> fixupDirectMethod(newDirectMethodReferences, method));
+    Set<DexMethod> newDirectMethodReferences = Sets.newIdentityHashSet();
+    List<DexEncodedMethod> newDirectMethods = new ArrayList<>();
+    for (DexEncodedMethod method : clazz.directMethods()) {
+      newDirectMethods.add(fixupDirectMethod(newDirectMethodReferences, method));
+    }
+    clazz.getMethodCollection().clearDirectMethods();
+    clazz.getMethodCollection().addDirectMethods(newDirectMethods);
 
     fixupFields(clazz.staticFields(), clazz::setStaticField);
     fixupFields(clazz.instanceFields(), clazz::setInstanceField);