Fix main dex assertion error from class merging

Change-Id: Ia97fa3359b05c7e4934509e7a02ef2bb56639329
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/PreventMergeIntoMainDex.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/PreventMergeIntoMainDex.java
index f8f0a3f..783ede5 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/PreventMergeIntoMainDex.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/PreventMergeIntoMainDex.java
@@ -6,48 +6,40 @@
 
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexProgramClass;
-import com.android.tools.r8.horizontalclassmerging.MultiClassPolicy;
+import com.android.tools.r8.horizontalclassmerging.MultiClassSameReferencePolicy;
+import com.android.tools.r8.horizontalclassmerging.policies.PreventMergeIntoMainDex.MainDexClassification;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.shaking.MainDexClasses;
 import com.android.tools.r8.shaking.MainDexTracingResult;
-import java.util.Collection;
-import java.util.Iterator;
-import java.util.LinkedList;
-import java.util.List;
 
-public class PreventMergeIntoMainDex extends MultiClassPolicy {
+public class PreventMergeIntoMainDex extends MultiClassSameReferencePolicy<MainDexClassification> {
   private final MainDexClasses mainDexClasses;
   private final MainDexTracingResult mainDexTracingResult;
 
+  enum MainDexClassification {
+    MAIN_DEX_LIST,
+    MAIN_DEX_ROOT,
+    MAIN_DEX_DEPENDENCY,
+    NOT_IN_MAIN_DEX
+  }
+
   public PreventMergeIntoMainDex(
       AppView<AppInfoWithLiveness> appView, MainDexTracingResult mainDexTracingResult) {
     this.mainDexClasses = appView.appInfo().getMainDexClasses();
     this.mainDexTracingResult = mainDexTracingResult;
   }
 
-  public boolean isMainDexClass(DexProgramClass clazz) {
-    return mainDexClasses.contains(clazz) || mainDexTracingResult.contains(clazz);
-  }
-
   @Override
-  public Collection<List<DexProgramClass>> apply(List<DexProgramClass> group) {
-    List<DexProgramClass> mainDexMembers = new LinkedList<>();
-    Iterator<DexProgramClass> iterator = group.iterator();
-    while (iterator.hasNext()) {
-      DexProgramClass clazz = iterator.next();
-      if (isMainDexClass(clazz)) {
-        iterator.remove();
-        mainDexMembers.add(clazz);
-      }
+  public MainDexClassification getMergeKey(DexProgramClass clazz) {
+    if (mainDexClasses.contains(clazz)) {
+      return MainDexClassification.MAIN_DEX_LIST;
     }
-
-    Collection<List<DexProgramClass>> newGroups = new LinkedList<>();
-    if (!isTrivial(mainDexMembers)) {
-      newGroups.add(mainDexMembers);
+    if (mainDexTracingResult.isRoot(clazz)) {
+      return MainDexClassification.MAIN_DEX_ROOT;
     }
-    if (!isTrivial(group)) {
-      newGroups.add(group);
+    if (mainDexTracingResult.isDependency(clazz)) {
+      return MainDexClassification.MAIN_DEX_DEPENDENCY;
     }
-    return newGroups;
+    return MainDexClassification.NOT_IN_MAIN_DEX;
   }
 }