Revert "Unify main-dex-classes and main-dex-tracing into one structure"

This reverts commit cf5c6320eba35d7d79ea70e67d70c035af884081.


Revert "Use Map for static field reference instead of ConcurrentHashMap"

This reverts commit c53865aa840fd5c25f5bbaa1232c47ad5ec8d7aa.

Reason for revert: Breaks YouTubeV1508ProtoRewritingTest

Change-Id: Idf10b2a5e252207387495294d6860088e806b0b3
diff --git a/src/main/java/com/android/tools/r8/D8.java b/src/main/java/com/android/tools/r8/D8.java
index 7335dbb..7310ecf 100644
--- a/src/main/java/com/android/tools/r8/D8.java
+++ b/src/main/java/com/android/tools/r8/D8.java
@@ -32,7 +32,6 @@
 import com.android.tools.r8.naming.signature.GenericSignatureRewriter;
 import com.android.tools.r8.origin.CommandLineOrigin;
 import com.android.tools.r8.origin.Origin;
-import com.android.tools.r8.shaking.MainDexInfo;
 import com.android.tools.r8.synthesis.SyntheticFinalization;
 import com.android.tools.r8.synthesis.SyntheticItems;
 import com.android.tools.r8.utils.AndroidApp;
@@ -185,9 +184,9 @@
       SyntheticItems.collectSyntheticInputs(appView);
 
       if (!options.mainDexKeepRules.isEmpty()) {
-        MainDexInfo mainDexInfo =
-            new GenerateMainDexList(options).traceMainDex(executor, appView.appInfo().app());
-        appView.setAppInfo(appView.appInfo().rebuildWithMainDexInfo(mainDexInfo));
+        new GenerateMainDexList(options)
+            .traceMainDex(
+                executor, appView.appInfo().app(), appView.appInfo().getMainDexClasses()::addAll);
       }
 
       final CfgPrinter printer = options.printCfg ? new CfgPrinter() : null;
@@ -301,7 +300,7 @@
           appView.setAppInfo(
               new AppInfo(
                   appView.appInfo().getSyntheticItems().commit(app),
-                  appView.appInfo().getMainDexInfo()));
+                  appView.appInfo().getMainDexClasses()));
           namingLens = NamingLens.getIdentityLens();
         }
 
@@ -359,7 +358,7 @@
     appView.setAppInfo(
         new AppInfo(
             appView.appInfo().getSyntheticItems().commit(cfApp),
-            appView.appInfo().getMainDexInfo()));
+            appView.appInfo().getMainDexClasses()));
     ConvertedCfFiles convertedCfFiles = new ConvertedCfFiles();
     NamingLens prefixRewritingNamingLens =
         PrefixRewritingNamingLens.createPrefixRewritingNamingLens(appView);
diff --git a/src/main/java/com/android/tools/r8/DexSplitterHelper.java b/src/main/java/com/android/tools/r8/DexSplitterHelper.java
index 7896d57..0230729 100644
--- a/src/main/java/com/android/tools/r8/DexSplitterHelper.java
+++ b/src/main/java/com/android/tools/r8/DexSplitterHelper.java
@@ -19,7 +19,7 @@
 import com.android.tools.r8.graph.LazyLoadedDexApplication;
 import com.android.tools.r8.naming.ClassNameMapper;
 import com.android.tools.r8.naming.NamingLens;
-import com.android.tools.r8.shaking.MainDexInfo;
+import com.android.tools.r8.shaking.MainDexClasses;
 import com.android.tools.r8.utils.ExceptionUtils;
 import com.android.tools.r8.utils.FeatureClassMapping;
 import com.android.tools.r8.utils.FeatureClassMapping.FeatureMappingException;
@@ -75,7 +75,7 @@
       ApplicationReader applicationReader =
           new ApplicationReader(command.getInputApp(), options, timing);
       DexApplication app = applicationReader.read(executor);
-      MainDexInfo mainDexInfo = applicationReader.readMainDexClasses(app);
+      MainDexClasses mainDexClasses = applicationReader.readMainDexClasses(app);
 
       List<Marker> markers = app.dexItemFactory.extractMarkers();
 
@@ -94,7 +94,7 @@
         // If this is the base, we add the main dex list.
         AppInfo appInfo =
             feature.equals(featureClassMapping.getBaseName())
-                ? AppInfo.createInitialAppInfo(featureApp, mainDexInfo)
+                ? AppInfo.createInitialAppInfo(featureApp, mainDexClasses)
                 : AppInfo.createInitialAppInfo(featureApp);
         AppView<AppInfo> appView = AppView.createForD8(appInfo);
 
diff --git a/src/main/java/com/android/tools/r8/GenerateMainDexList.java b/src/main/java/com/android/tools/r8/GenerateMainDexList.java
index b3e100a..1e3d725 100644
--- a/src/main/java/com/android/tools/r8/GenerateMainDexList.java
+++ b/src/main/java/com/android/tools/r8/GenerateMainDexList.java
@@ -16,8 +16,8 @@
 import com.android.tools.r8.graph.SubtypingInfo;
 import com.android.tools.r8.shaking.Enqueuer;
 import com.android.tools.r8.shaking.EnqueuerFactory;
-import com.android.tools.r8.shaking.MainDexInfo;
 import com.android.tools.r8.shaking.MainDexListBuilder;
+import com.android.tools.r8.shaking.MainDexTracingResult;
 import com.android.tools.r8.shaking.RootSetUtils.MainDexRootSet;
 import com.android.tools.r8.shaking.WhyAreYouKeepingConsumer;
 import com.android.tools.r8.utils.AndroidApp;
@@ -28,16 +28,20 @@
 import com.android.tools.r8.utils.Timing;
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.List;
+import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
 
 @Keep
 public class GenerateMainDexList {
   private final Timing timing = new Timing("maindex");
   private final InternalOptions options;
 
+  private List<String> result = null;
+
   public GenerateMainDexList(InternalOptions options) {
     this.options = options;
   }
@@ -45,24 +49,32 @@
   private List<String> run(AndroidApp app, ExecutorService executor)
       throws IOException {
     try {
-      // TODO(b/178231294): Clean up this such that we do not both return the result and call the
-      //  consumer.
       DexApplication application = new ApplicationReader(app, options, timing).read(executor);
-      List<String> result = new ArrayList<>();
-      traceMainDex(executor, application)
-          .forEach(type -> result.add(type.toBinaryName() + ".class"));
-      Collections.sort(result);
-      if (options.mainDexListConsumer != null) {
-        options.mainDexListConsumer.accept(String.join("\n", result), options.reporter);
-        options.mainDexListConsumer.finished(options.reporter);
-      }
+      traceMainDex(
+          executor,
+          application,
+          mainDexTracingResult -> {
+            result =
+                mainDexTracingResult.getClasses().stream()
+                    .map(c -> c.toSourceString().replace('.', '/') + ".class")
+                    .sorted()
+                    .collect(Collectors.toList());
+
+            if (options.mainDexListConsumer != null) {
+              options.mainDexListConsumer.accept(String.join("\n", result), options.reporter);
+              options.mainDexListConsumer.finished(options.reporter);
+            }
+          });
       return result;
     } catch (ExecutionException e) {
       throw unwrapExecutionException(e);
     }
   }
 
-  public MainDexInfo traceMainDex(ExecutorService executor, DexApplication application)
+  public void traceMainDex(
+      ExecutorService executor,
+      DexApplication application,
+      Consumer<MainDexTracingResult> resultConsumer)
       throws ExecutionException {
     AppView<? extends AppInfoWithClassHierarchy> appView =
         AppView.createForR8(application.toDirect());
@@ -85,19 +97,25 @@
 
     Enqueuer enqueuer =
         EnqueuerFactory.createForFinalMainDexTracing(
-            appView, executor, subtypingInfo, graphConsumer);
-    MainDexInfo mainDexInfo = enqueuer.traceMainDex(executor, timing);
+            appView, executor, subtypingInfo, graphConsumer, MainDexTracingResult.NONE);
+    Set<DexProgramClass> liveTypes = enqueuer.traceMainDex(executor, timing);
+    // LiveTypes is the result.
+    MainDexTracingResult mainDexTracingResult = new MainDexListBuilder(liveTypes, appView).run();
+    resultConsumer.accept(mainDexTracingResult);
+
     R8.processWhyAreYouKeepingAndCheckDiscarded(
         mainDexRootSet,
         () -> {
           ArrayList<DexProgramClass> classes = new ArrayList<>();
           // TODO(b/131668850): This is not a deterministic order!
-          mainDexInfo.forEach(
-              type -> {
-                DexClass clazz = appView.definitionFor(type);
-                assert clazz.isProgramClass();
-                classes.add(clazz.asProgramClass());
-              });
+          mainDexTracingResult
+              .getClasses()
+              .forEach(
+                  type -> {
+                    DexClass clazz = appView.definitionFor(type);
+                    assert clazz.isProgramClass();
+                    classes.add(clazz.asProgramClass());
+                  });
           return classes;
         },
         whyAreYouKeepingConsumer,
@@ -107,8 +125,6 @@
         options,
         timing,
         executor);
-
-    return mainDexInfo;
   }
 
   /**
diff --git a/src/main/java/com/android/tools/r8/PrintUses.java b/src/main/java/com/android/tools/r8/PrintUses.java
index 25f4d3b..782094c 100644
--- a/src/main/java/com/android/tools/r8/PrintUses.java
+++ b/src/main/java/com/android/tools/r8/PrintUses.java
@@ -27,7 +27,7 @@
 import com.android.tools.r8.graph.ResolutionResult;
 import com.android.tools.r8.graph.UseRegistry;
 import com.android.tools.r8.ir.desugar.LambdaDescriptor;
-import com.android.tools.r8.shaking.MainDexInfo;
+import com.android.tools.r8.shaking.MainDexClasses;
 import com.android.tools.r8.utils.AndroidApp;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.StringUtils;
@@ -364,7 +364,7 @@
         AppInfoWithClassHierarchy.createInitialAppInfoWithClassHierarchy(
             application,
             ClassToFeatureSplitMap.createEmptyClassToFeatureSplitMap(),
-            MainDexInfo.createEmptyMainDexClasses());
+            MainDexClasses.createEmptyMainDexClasses());
   }
 
   private void analyze() {
diff --git a/src/main/java/com/android/tools/r8/R8.java b/src/main/java/com/android/tools/r8/R8.java
index eef9f3d..a79b303 100644
--- a/src/main/java/com/android/tools/r8/R8.java
+++ b/src/main/java/com/android/tools/r8/R8.java
@@ -88,8 +88,9 @@
 import com.android.tools.r8.shaking.Enqueuer;
 import com.android.tools.r8.shaking.Enqueuer.Mode;
 import com.android.tools.r8.shaking.EnqueuerFactory;
-import com.android.tools.r8.shaking.MainDexInfo;
+import com.android.tools.r8.shaking.MainDexClasses;
 import com.android.tools.r8.shaking.MainDexListBuilder;
+import com.android.tools.r8.shaking.MainDexTracingResult;
 import com.android.tools.r8.shaking.MissingClasses;
 import com.android.tools.r8.shaking.ProguardConfigurationRule;
 import com.android.tools.r8.shaking.ProguardConfigurationUtils;
@@ -284,12 +285,12 @@
       {
         ApplicationReader applicationReader = new ApplicationReader(inputApp, options, timing);
         DirectMappedDexApplication application = applicationReader.read(executorService).toDirect();
-        MainDexInfo mainDexInfo = applicationReader.readMainDexClasses(application);
+        MainDexClasses mainDexClasses = applicationReader.readMainDexClasses(application);
 
         // Now that the dex-application is fully loaded, close any internal archive providers.
         inputApp.closeInternalArchiveProviders();
 
-        appView = AppView.createForR8(application, mainDexInfo);
+        appView = AppView.createForR8(application, mainDexClasses);
         appView.setAppServices(AppServices.builder(appView).build());
       }
 
@@ -422,7 +423,8 @@
       // Build conservative main dex content after first round of tree shaking. This is used
       // by certain optimizations to avoid introducing additional class references into main dex
       // classes, as that can cause the final number of main dex methods to grow.
-      performInitialMainDexTracing(appView, executorService);
+      MainDexTracingResult mainDexTracingResult =
+          performInitialMainDexTracing(appView, executorService);
 
       // The class type lattice elements include information about the interfaces that a class
       // implements. This information can change as a result of vertical class merging, so we need
@@ -464,7 +466,11 @@
           timing.begin("VerticalClassMerger");
           VerticalClassMerger verticalClassMerger =
               new VerticalClassMerger(
-                  getDirectApp(appViewWithLiveness), appViewWithLiveness, executorService, timing);
+                  getDirectApp(appViewWithLiveness),
+                  appViewWithLiveness,
+                  executorService,
+                  timing,
+                  mainDexTracingResult);
           VerticalClassMergerGraphLens lens = verticalClassMerger.run();
           if (lens != null) {
             appView.rewriteWithLens(lens);
@@ -511,7 +517,7 @@
           DirectMappedDexApplication.Builder appBuilder =
               appView.appInfo().app().asDirect().builder();
           HorizontalClassMergerResult horizontalClassMergerResult =
-              merger.run(appBuilder, runtimeTypeCheckInfo);
+              merger.run(appBuilder, mainDexTracingResult, runtimeTypeCheckInfo);
           if (horizontalClassMergerResult != null) {
             // Must rewrite AppInfoWithLiveness before pruning the merged classes, to ensure that
             // allocations sites, fields accesses, etc. are correctly transferred to the target
@@ -528,6 +534,8 @@
                     .addRemovedClasses(appView.horizontallyMergedClasses().getSources())
                     .addNoLongerSyntheticItems(appView.horizontallyMergedClasses().getTargets())
                     .build());
+
+            mainDexTracingResult = horizontalClassMergerResult.getMainDexTracingResult();
           }
           timing.end();
         } else {
@@ -550,7 +558,7 @@
       timing.begin("Create IR");
       CfgPrinter printer = options.printCfg ? new CfgPrinter() : null;
       try {
-        IRConverter converter = new IRConverter(appView, timing, printer);
+        IRConverter converter = new IRConverter(appView, timing, printer, mainDexTracingResult);
         DexApplication application =
             converter.optimize(appViewWithLiveness, executorService).asDirect();
         appView.setAppInfo(appView.appInfo().rebuildWithClassHierarchy(previous -> application));
@@ -661,6 +669,12 @@
                 .setClassesToRetainInnerClassAttributeFor(classesToRetainInnerClassAttributeFor)
                 .build(appView.withLiveness(), removedClasses)
                 .run();
+            if (!mainDexTracingResult.isEmpty()) {
+              // Remove types that no longer exists from the computed main dex list.
+              mainDexTracingResult =
+                  mainDexTracingResult.prunedCopy(appView.appInfo().withLiveness());
+            }
+
             // Synthesize fields for triggering class initializers.
             new ClassInitFieldSynthesizer(appViewWithLiveness).run(executorService);
           }
@@ -673,7 +687,7 @@
             appView.protoShrinker().enumLiteProtoShrinker.verifyDeadEnumLiteMapsAreDead();
           }
 
-          IRConverter converter = new IRConverter(appView, timing, null);
+          IRConverter converter = new IRConverter(appView, timing, null, mainDexTracingResult);
 
           // If proto shrinking is enabled, we need to reprocess every dynamicMethod(). This ensures
           // that proto fields that have been removed by the second round of tree shaking are also
@@ -691,7 +705,8 @@
         }
       }
 
-      performFinalMainDexTracing(appView, executorService);
+      mainDexTracingResult =
+          performFinalMainDexTracing(appView, executorService, mainDexTracingResult);
 
       // Remove unneeded visibility bridges that have been inserted for member rebinding.
       // This can only be done if we have AppInfoWithLiveness.
@@ -729,6 +744,11 @@
         assert Repackaging.verifyIdentityRepackaging(appView.withLiveness());
       }
 
+      // Add automatic main dex classes to an eventual manual list of classes.
+      if (!options.mainDexKeepRules.isEmpty()) {
+        appView.appInfo().getMainDexClasses().addAll(mainDexTracingResult);
+      }
+
       if (appView.appInfo().hasLiveness()) {
         SyntheticFinalization.finalizeWithLiveness(appView.withLiveness());
       } else {
@@ -834,11 +854,11 @@
     }
   }
 
-  private void performInitialMainDexTracing(
+  private MainDexTracingResult performInitialMainDexTracing(
       AppView<AppInfoWithClassHierarchy> appView, ExecutorService executorService)
       throws ExecutionException {
     if (options.mainDexKeepRules.isEmpty()) {
-      return;
+      return MainDexTracingResult.NONE;
     }
     assert appView.graphLens().isIdentityLens();
     // Find classes which may have code executed before secondary dex files installation.
@@ -847,19 +867,22 @@
         MainDexRootSet.builder(appView, subtypingInfo, options.mainDexKeepRules)
             .build(executorService);
     appView.setMainDexRootSet(mainDexRootSet);
-    appView.appInfo().unsetObsolete();
     // Live types is the tracing result.
-    MainDexInfo mainDexInfo =
+    Set<DexProgramClass> mainDexBaseClasses =
         EnqueuerFactory.createForInitialMainDexTracing(appView, executorService, subtypingInfo)
             .traceMainDex(executorService, timing);
-    appView.setAppInfo(appView.appInfo().rebuildWithMainDexInfo(mainDexInfo));
+    appView.appInfo().unsetObsolete();
+    // Calculate the automatic main dex list according to legacy multidex constraints.
+    return new MainDexListBuilder(mainDexBaseClasses, appView).run();
   }
 
-  private void performFinalMainDexTracing(
-      AppView<AppInfoWithClassHierarchy> appView, ExecutorService executorService)
+  private MainDexTracingResult performFinalMainDexTracing(
+      AppView<AppInfoWithClassHierarchy> appView,
+      ExecutorService executorService,
+      MainDexTracingResult previousTracingResult)
       throws ExecutionException {
     if (options.mainDexKeepRules.isEmpty()) {
-      return;
+      return MainDexTracingResult.NONE;
     }
     // No need to build a new main dex root set
     assert appView.getMainDexRootSet() != null;
@@ -872,22 +895,31 @@
 
     Enqueuer enqueuer =
         EnqueuerFactory.createForFinalMainDexTracing(
-            appView, executorService, new SubtypingInfo(appView), mainDexKeptGraphConsumer);
+            appView,
+            executorService,
+            new SubtypingInfo(appView),
+            mainDexKeptGraphConsumer,
+            previousTracingResult);
     // Find classes which may have code executed before secondary dex files installation.
-    MainDexInfo mainDexInfo = enqueuer.traceMainDex(executorService, timing);
-    appView.setAppInfo(appView.appInfo().rebuildWithMainDexInfo(mainDexInfo));
+    // Live types is the tracing result.
+    Set<DexProgramClass> mainDexBaseClasses = enqueuer.traceMainDex(executorService, timing);
+    // Calculate the automatic main dex list according to legacy multidex constraints.
+    MainDexTracingResult mainDexTracingResult =
+        new MainDexListBuilder(mainDexBaseClasses, appView).run();
 
     processWhyAreYouKeepingAndCheckDiscarded(
         appView.getMainDexRootSet(),
         () -> {
           ArrayList<DexProgramClass> classes = new ArrayList<>();
           // TODO(b/131668850): This is not a deterministic order!
-          mainDexInfo.forEach(
-              type -> {
-                DexClass clazz = appView.definitionFor(type);
-                assert clazz.isProgramClass();
-                classes.add(clazz.asProgramClass());
-              });
+          mainDexTracingResult
+              .getClasses()
+              .forEach(
+                  type -> {
+                    DexClass clazz = appView.definitionFor(type);
+                    assert clazz.isProgramClass();
+                    classes.add(clazz.asProgramClass());
+                  });
           return classes;
         },
         whyAreYouKeepingConsumer,
@@ -897,6 +929,8 @@
         options,
         timing,
         executorService);
+
+    return mainDexTracingResult;
   }
 
   private static boolean verifyMovedMethodsHaveOriginalMethodPosition(
@@ -1028,7 +1062,11 @@
       if (forMainDex) {
         enqueuer =
             EnqueuerFactory.createForFinalMainDexTracing(
-                appView, executorService, subtypingInfo, whyAreYouKeepingConsumer);
+                appView,
+                executorService,
+                subtypingInfo,
+                whyAreYouKeepingConsumer,
+                MainDexTracingResult.NONE);
         enqueuer.traceMainDex(executorService, timing);
       } else {
         enqueuer =
diff --git a/src/main/java/com/android/tools/r8/dex/ApplicationReader.java b/src/main/java/com/android/tools/r8/dex/ApplicationReader.java
index 23db0d9..4b035d4 100644
--- a/src/main/java/com/android/tools/r8/dex/ApplicationReader.java
+++ b/src/main/java/com/android/tools/r8/dex/ApplicationReader.java
@@ -29,7 +29,7 @@
 import com.android.tools.r8.graph.JarClassFileReader;
 import com.android.tools.r8.graph.LazyLoadedDexApplication;
 import com.android.tools.r8.naming.ClassNameMapper;
-import com.android.tools.r8.shaking.MainDexInfo;
+import com.android.tools.r8.shaking.MainDexClasses;
 import com.android.tools.r8.utils.AndroidApiLevel;
 import com.android.tools.r8.utils.AndroidApp;
 import com.android.tools.r8.utils.ClassProvider;
@@ -198,8 +198,8 @@
     }
   }
 
-  public MainDexInfo readMainDexClasses(DexApplication app) {
-    MainDexInfo.Builder builder = MainDexInfo.builder();
+  public MainDexClasses readMainDexClasses(DexApplication app) {
+    MainDexClasses.Builder builder = MainDexClasses.builder();
     if (inputApp.hasMainDexList()) {
       for (StringResource resource : inputApp.getMainDexListResources()) {
         addToMainDexClasses(app, builder, MainDexListParser.parseList(resource, itemFactory));
@@ -211,15 +211,15 @@
               .map(clazz -> itemFactory.createType(DescriptorUtils.javaTypeToDescriptor(clazz)))
               .collect(Collectors.toList()));
     }
-    return builder.buildList();
+    return builder.build();
   }
 
   private void addToMainDexClasses(
-      DexApplication app, MainDexInfo.Builder builder, Iterable<DexType> types) {
+      DexApplication app, MainDexClasses.Builder builder, Iterable<DexType> types) {
     for (DexType type : types) {
       DexProgramClass clazz = app.programDefinitionFor(type);
       if (clazz != null) {
-        builder.addList(clazz);
+        builder.add(clazz);
       } else if (!options.ignoreMainDexMissingClasses) {
         options.reporter.warning(
             new StringDiagnostic(
diff --git a/src/main/java/com/android/tools/r8/dex/ApplicationWriter.java b/src/main/java/com/android/tools/r8/dex/ApplicationWriter.java
index da3b04c..ab72b04 100644
--- a/src/main/java/com/android/tools/r8/dex/ApplicationWriter.java
+++ b/src/main/java/com/android/tools/r8/dex/ApplicationWriter.java
@@ -45,7 +45,7 @@
 import com.android.tools.r8.naming.ProguardMapSupplier;
 import com.android.tools.r8.naming.ProguardMapSupplier.ProguardMapId;
 import com.android.tools.r8.origin.Origin;
-import com.android.tools.r8.shaking.MainDexInfo;
+import com.android.tools.r8.shaking.MainDexClasses;
 import com.android.tools.r8.utils.ArrayUtils;
 import com.android.tools.r8.utils.DescriptorUtils;
 import com.android.tools.r8.utils.ExceptionUtils;
@@ -199,7 +199,7 @@
           options.getDexFilePerClassFileConsumer().combineSyntheticClassesWithPrimaryClass());
     } else if (!options.canUseMultidex()
         && options.mainDexKeepRules.isEmpty()
-        && appView.appInfo().getMainDexInfo().isEmpty()
+        && appView.appInfo().getMainDexClasses().isEmpty()
         && options.enableMainDexListCheck) {
       distributor = new VirtualFile.MonoDexDistributor(this, options);
     } else {
@@ -688,11 +688,10 @@
   }
 
   private static String writeMainDexList(AppView<?> appView, NamingLens namingLens) {
-    // TODO(b/178231294): Clean up by streaming directly to the consumer.
-    MainDexInfo mainDexInfo = appView.appInfo().getMainDexInfo();
+    MainDexClasses mainDexClasses = appView.appInfo().getMainDexClasses();
     StringBuilder builder = new StringBuilder();
-    List<DexType> list = new ArrayList<>(mainDexInfo.size());
-    mainDexInfo.forEach(list::add);
+    List<DexType> list = new ArrayList<>(mainDexClasses.size());
+    mainDexClasses.forEach(list::add);
     list.sort(DexType::compareTo);
     list.forEach(
         type -> builder.append(mapMainDexListName(type, namingLens)).append('\n'));
diff --git a/src/main/java/com/android/tools/r8/dex/VirtualFile.java b/src/main/java/com/android/tools/r8/dex/VirtualFile.java
index 8f56827..414031e 100644
--- a/src/main/java/com/android/tools/r8/dex/VirtualFile.java
+++ b/src/main/java/com/android/tools/r8/dex/VirtualFile.java
@@ -27,7 +27,7 @@
 import com.android.tools.r8.logging.Log;
 import com.android.tools.r8.naming.ClassNameMapper;
 import com.android.tools.r8.naming.NamingLens;
-import com.android.tools.r8.shaking.MainDexInfo;
+import com.android.tools.r8.shaking.MainDexClasses;
 import com.android.tools.r8.synthesis.SyntheticNaming;
 import com.android.tools.r8.utils.DescriptorUtils;
 import com.android.tools.r8.utils.FileUtils;
@@ -394,14 +394,13 @@
     }
 
     protected void fillForMainDexList(Set<DexProgramClass> classes) {
-      MainDexInfo mainDexInfo = appView.appInfo().getMainDexInfo();
-      if (mainDexInfo.isEmpty()) {
+      MainDexClasses mainDexClasses = appView.appInfo().getMainDexClasses();
+      if (mainDexClasses.isEmpty()) {
         return;
       }
       VirtualFile mainDexFile = virtualFiles.get(0);
-      mainDexInfo.forEach(
+      mainDexClasses.forEach(
           type -> {
-            // TODO(b/178577273): We should ensure only live types in main dex.
             DexProgramClass clazz =
                 asProgramClassOrNull(appView.appInfo().definitionForWithoutExistenceAssert(type));
             if (clazz != null) {
diff --git a/src/main/java/com/android/tools/r8/graph/AppInfo.java b/src/main/java/com/android/tools/r8/graph/AppInfo.java
index d8785d1..302dff4 100644
--- a/src/main/java/com/android/tools/r8/graph/AppInfo.java
+++ b/src/main/java/com/android/tools/r8/graph/AppInfo.java
@@ -7,7 +7,7 @@
 import com.android.tools.r8.ir.desugar.InterfaceMethodRewriter;
 import com.android.tools.r8.origin.Origin;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
-import com.android.tools.r8.shaking.MainDexInfo;
+import com.android.tools.r8.shaking.MainDexClasses;
 import com.android.tools.r8.synthesis.CommittedItems;
 import com.android.tools.r8.synthesis.SyntheticItems;
 import com.android.tools.r8.utils.BooleanBox;
@@ -20,7 +20,7 @@
 
   private final DexApplication app;
   private final DexItemFactory dexItemFactory;
-  private final MainDexInfo mainDexInfo;
+  private final MainDexClasses mainDexClasses;
   private final SyntheticItems syntheticItems;
 
   // Set when a new AppInfo replaces a previous one. All public methods should verify that the
@@ -28,36 +28,37 @@
   private final BooleanBox obsolete;
 
   public static AppInfo createInitialAppInfo(DexApplication application) {
-    return createInitialAppInfo(application, MainDexInfo.createEmptyMainDexClasses());
+    return createInitialAppInfo(application, MainDexClasses.createEmptyMainDexClasses());
   }
 
-  public static AppInfo createInitialAppInfo(DexApplication application, MainDexInfo mainDexInfo) {
-    return new AppInfo(SyntheticItems.createInitialSyntheticItems(application), mainDexInfo);
+  public static AppInfo createInitialAppInfo(
+      DexApplication application, MainDexClasses mainDexClasses) {
+    return new AppInfo(SyntheticItems.createInitialSyntheticItems(application), mainDexClasses);
   }
 
-  public AppInfo(CommittedItems committedItems, MainDexInfo mainDexInfo) {
+  public AppInfo(CommittedItems committedItems, MainDexClasses mainDexClasses) {
     this(
         committedItems.getApplication(),
         committedItems.toSyntheticItems(),
-        mainDexInfo,
+        mainDexClasses,
         new BooleanBox());
   }
 
   // For desugaring.
   // This is a view onto the app info and is the only place the pending synthetics are shared.
   AppInfo(AppInfoWithClassHierarchy.CreateDesugaringViewOnAppInfo witness, AppInfo appInfo) {
-    this(appInfo.app, appInfo.syntheticItems, appInfo.mainDexInfo, appInfo.obsolete);
+    this(appInfo.app, appInfo.syntheticItems, appInfo.mainDexClasses, appInfo.obsolete);
     assert witness != null;
   }
 
   private AppInfo(
       DexApplication application,
       SyntheticItems syntheticItems,
-      MainDexInfo mainDexInfo,
+      MainDexClasses mainDexClasses,
       BooleanBox obsolete) {
     this.app = application;
     this.dexItemFactory = application.dexItemFactory;
-    this.mainDexInfo = mainDexInfo;
+    this.mainDexClasses = mainDexClasses;
     this.syntheticItems = syntheticItems;
     this.obsolete = obsolete;
   }
@@ -71,12 +72,7 @@
     }
     return new AppInfo(
         getSyntheticItems().commitPrunedItems(prunedItems),
-        getMainDexInfo().withoutPrunedItems(prunedItems));
-  }
-
-  public AppInfo rebuildWithMainDexInfo(MainDexInfo mainDexInfo) {
-    assert checkIfObsolete();
-    return new AppInfo(app, syntheticItems, mainDexInfo, new BooleanBox());
+        getMainDexClasses().withoutPrunedItems(prunedItems));
   }
 
   protected InternalOptions options() {
@@ -111,29 +107,19 @@
     return dexItemFactory;
   }
 
-  public MainDexInfo getMainDexInfo() {
-    assert checkIfObsolete();
-    return mainDexInfo;
+  public MainDexClasses getMainDexClasses() {
+    return mainDexClasses;
   }
 
   public SyntheticItems getSyntheticItems() {
-    assert checkIfObsolete();
     return syntheticItems;
   }
 
-  public void addSynthesizedClass(DexProgramClass clazz, boolean addToMainDex) {
+  public void addSynthesizedClass(DexProgramClass clazz, boolean addToMainDexClasses) {
     assert checkIfObsolete();
     syntheticItems.addLegacySyntheticClass(clazz);
-    if (addToMainDex) {
-      mainDexInfo.addSyntheticClass(clazz);
-    }
-  }
-
-  public void addSynthesizedClass(DexProgramClass clazz, ProgramDefinition context) {
-    assert checkIfObsolete();
-    syntheticItems.addLegacySyntheticClass(clazz);
-    if (context != null) {
-      mainDexInfo.addLegacySyntheticClass(clazz, context);
+    if (addToMainDexClasses && !mainDexClasses.isEmpty()) {
+      mainDexClasses.add(clazz);
     }
   }
 
diff --git a/src/main/java/com/android/tools/r8/graph/AppInfoWithClassHierarchy.java b/src/main/java/com/android/tools/r8/graph/AppInfoWithClassHierarchy.java
index cd90e16..8ff4d1e 100644
--- a/src/main/java/com/android/tools/r8/graph/AppInfoWithClassHierarchy.java
+++ b/src/main/java/com/android/tools/r8/graph/AppInfoWithClassHierarchy.java
@@ -18,7 +18,7 @@
 import com.android.tools.r8.ir.analysis.type.InterfaceCollection;
 import com.android.tools.r8.ir.analysis.type.InterfaceCollection.Builder;
 import com.android.tools.r8.ir.desugar.LambdaDescriptor;
-import com.android.tools.r8.shaking.MainDexInfo;
+import com.android.tools.r8.shaking.MainDexClasses;
 import com.android.tools.r8.shaking.MissingClasses;
 import com.android.tools.r8.synthesis.CommittedItems;
 import com.android.tools.r8.synthesis.SyntheticItems;
@@ -56,11 +56,11 @@
   public static AppInfoWithClassHierarchy createInitialAppInfoWithClassHierarchy(
       DexApplication application,
       ClassToFeatureSplitMap classToFeatureSplitMap,
-      MainDexInfo mainDexInfo) {
+      MainDexClasses mainDexClasses) {
     return new AppInfoWithClassHierarchy(
         SyntheticItems.createInitialSyntheticItems(application),
         classToFeatureSplitMap,
-        mainDexInfo,
+        mainDexClasses,
         MissingClasses.empty());
   }
 
@@ -74,9 +74,9 @@
   protected AppInfoWithClassHierarchy(
       CommittedItems committedItems,
       ClassToFeatureSplitMap classToFeatureSplitMap,
-      MainDexInfo mainDexInfo,
+      MainDexClasses mainDexClasses,
       MissingClasses missingClasses) {
-    super(committedItems, mainDexInfo);
+    super(committedItems, mainDexClasses);
     this.classToFeatureSplitMap = classToFeatureSplitMap;
     this.missingClasses = missingClasses;
   }
@@ -97,36 +97,27 @@
 
   public final AppInfoWithClassHierarchy rebuildWithClassHierarchy(CommittedItems commit) {
     return new AppInfoWithClassHierarchy(
-        commit, getClassToFeatureSplitMap(), getMainDexInfo(), getMissingClasses());
+        commit, getClassToFeatureSplitMap(), getMainDexClasses(), getMissingClasses());
   }
 
   public final AppInfoWithClassHierarchy rebuildWithClassHierarchy(MissingClasses missingClasses) {
     return new AppInfoWithClassHierarchy(
         getSyntheticItems().commit(app()),
         getClassToFeatureSplitMap(),
-        getMainDexInfo(),
+        getMainDexClasses(),
         missingClasses);
   }
 
   public AppInfoWithClassHierarchy rebuildWithClassHierarchy(
       Function<DexApplication, DexApplication> fn) {
-    assert checkIfObsolete();
     return new AppInfoWithClassHierarchy(
         getSyntheticItems().commit(fn.apply(app())),
         getClassToFeatureSplitMap(),
-        getMainDexInfo(),
+        getMainDexClasses(),
         getMissingClasses());
   }
 
   @Override
-  public AppInfoWithClassHierarchy rebuildWithMainDexInfo(MainDexInfo mainDexInfo) {
-    assert getClass() == AppInfoWithClassHierarchy.class;
-    assert checkIfObsolete();
-    return new AppInfoWithClassHierarchy(
-        getSyntheticItems().commit(app()), classToFeatureSplitMap, mainDexInfo, missingClasses);
-  }
-
-  @Override
   public AppInfoWithClassHierarchy prunedCopyFrom(PrunedItems prunedItems) {
     assert getClass() == AppInfoWithClassHierarchy.class;
     assert checkIfObsolete();
@@ -137,7 +128,7 @@
     return new AppInfoWithClassHierarchy(
         getSyntheticItems().commitPrunedItems(prunedItems),
         getClassToFeatureSplitMap().withoutPrunedItems(prunedItems),
-        getMainDexInfo().withoutPrunedItems(prunedItems),
+        getMainDexClasses().withoutPrunedItems(prunedItems),
         getMissingClasses());
   }
 
diff --git a/src/main/java/com/android/tools/r8/graph/AppView.java b/src/main/java/com/android/tools/r8/graph/AppView.java
index 954580b..96df40d 100644
--- a/src/main/java/com/android/tools/r8/graph/AppView.java
+++ b/src/main/java/com/android/tools/r8/graph/AppView.java
@@ -31,7 +31,7 @@
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.shaking.KeepInfoCollection;
 import com.android.tools.r8.shaking.LibraryModeledPredicate;
-import com.android.tools.r8.shaking.MainDexInfo;
+import com.android.tools.r8.shaking.MainDexClasses;
 import com.android.tools.r8.shaking.ProguardCompatibilityActions;
 import com.android.tools.r8.shaking.RootSetUtils.MainDexRootSet;
 import com.android.tools.r8.shaking.RootSetUtils.RootSet;
@@ -97,6 +97,7 @@
   private EnumDataMap unboxedEnums = EnumDataMap.empty();
   // TODO(b/169115389): Remove
   private Set<DexMethod> cfByteCodePassThrough = ImmutableSet.of();
+
   private Map<DexType, DexValueString> sourceDebugExtensions = new IdentityHashMap<>();
 
   // When input has been (partially) desugared these are the classes which has been library
@@ -156,16 +157,16 @@
   }
 
   public static AppView<AppInfoWithClassHierarchy> createForR8(DexApplication application) {
-    return createForR8(application, MainDexInfo.createEmptyMainDexClasses());
+    return createForR8(application, MainDexClasses.createEmptyMainDexClasses());
   }
 
   public static AppView<AppInfoWithClassHierarchy> createForR8(
-      DexApplication application, MainDexInfo mainDexInfo) {
+      DexApplication application, MainDexClasses mainDexClasses) {
     ClassToFeatureSplitMap classToFeatureSplitMap =
         ClassToFeatureSplitMap.createInitialClassToFeatureSplitMap(application.options);
     AppInfoWithClassHierarchy appInfo =
         AppInfoWithClassHierarchy.createInitialAppInfoWithClassHierarchy(
-            application, classToFeatureSplitMap, mainDexInfo);
+            application, classToFeatureSplitMap, mainDexClasses);
     return new AppView<>(
         appInfo, WholeProgramOptimizations.ON, defaultPrefixRewritingMapper(appInfo));
   }
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
index d97f12f..47a1625 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
@@ -29,8 +29,7 @@
 import com.android.tools.r8.horizontalclassmerging.policies.NotMatchedByNoHorizontalClassMerging;
 import com.android.tools.r8.horizontalclassmerging.policies.NotVerticallyMergedIntoSubtype;
 import com.android.tools.r8.horizontalclassmerging.policies.PreserveMethodCharacteristics;
-import com.android.tools.r8.horizontalclassmerging.policies.PreventMergeIntoDifferentMainDexGroups;
-import com.android.tools.r8.horizontalclassmerging.policies.PreventMergeIntoMainDexList;
+import com.android.tools.r8.horizontalclassmerging.policies.PreventMergeIntoMainDex;
 import com.android.tools.r8.horizontalclassmerging.policies.PreventMethodImplementation;
 import com.android.tools.r8.horizontalclassmerging.policies.RespectPackageBoundaries;
 import com.android.tools.r8.horizontalclassmerging.policies.SameFeatureSplit;
@@ -40,6 +39,7 @@
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.shaking.FieldAccessInfoCollectionModifier;
 import com.android.tools.r8.shaking.KeepInfoCollection;
+import com.android.tools.r8.shaking.MainDexTracingResult;
 import com.android.tools.r8.shaking.RuntimeTypeCheckInfo;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Iterables;
@@ -60,11 +60,12 @@
   // TODO(b/165577835): replace Collection<DexProgramClass> with MergeGroup
   public HorizontalClassMergerResult run(
       DirectMappedDexApplication.Builder appBuilder,
+      MainDexTracingResult mainDexTracingResult,
       RuntimeTypeCheckInfo runtimeTypeCheckInfo) {
     MergeGroup initialGroup = new MergeGroup(appView.appInfo().classesWithDeterministicOrder());
 
     // Run the policies on all program classes to produce a final grouping.
-    List<Policy> policies = getPolicies(runtimeTypeCheckInfo);
+    List<Policy> policies = getPolicies(mainDexTracingResult, runtimeTypeCheckInfo);
     Collection<MergeGroup> groups =
         new SimplePolicyExecutor().run(Collections.singletonList(initialGroup), policies);
 
@@ -78,6 +79,8 @@
         new HorizontallyMergedClasses.Builder();
     HorizontalClassMergerGraphLens.Builder lensBuilder =
         new HorizontalClassMergerGraphLens.Builder();
+    MainDexTracingResult.Builder mainDexTracingResultBuilder =
+        mainDexTracingResult.extensionBuilder(appView.appInfo());
 
     // Set up a class merger for each group.
     List<ClassMerger> classMergers =
@@ -88,7 +91,9 @@
 
     // Merge the classes.
     SyntheticArgumentClass syntheticArgumentClass =
-        new SyntheticArgumentClass.Builder(appBuilder, appView).build(allMergeClasses);
+        new SyntheticArgumentClass.Builder(
+                appBuilder, appView, mainDexTracingResult, mainDexTracingResultBuilder)
+            .build(allMergeClasses);
     applyClassMergers(classMergers, syntheticArgumentClass);
 
     // Generate the graph lens.
@@ -101,7 +106,8 @@
     KeepInfoCollection keepInfo = appView.appInfo().getKeepInfo();
     keepInfo.mutate(mutator -> mutator.removeKeepInfoForPrunedItems(mergedClasses.getSources()));
 
-    return new HorizontalClassMergerResult(createFieldAccessInfoCollectionModifier(groups), lens);
+    return new HorizontalClassMergerResult(
+        createFieldAccessInfoCollectionModifier(groups), lens, mainDexTracingResultBuilder.build());
   }
 
   private FieldAccessInfoCollectionModifier createFieldAccessInfoCollectionModifier(
@@ -120,7 +126,8 @@
     return builder.build();
   }
 
-  private List<Policy> getPolicies(RuntimeTypeCheckInfo runtimeTypeCheckInfo) {
+  private List<Policy> getPolicies(
+      MainDexTracingResult mainDexTracingResult, RuntimeTypeCheckInfo runtimeTypeCheckInfo) {
     return ImmutableList.of(
         new NotMatchedByNoHorizontalClassMerging(appView),
         new SameInstanceFields(appView),
@@ -140,9 +147,8 @@
         new NoDirectRuntimeTypeChecks(runtimeTypeCheckInfo),
         new NoIndirectRuntimeTypeChecks(appView, runtimeTypeCheckInfo),
         new PreventMethodImplementation(appView),
-        new DontInlinePolicy(appView),
-        new PreventMergeIntoMainDexList(appView),
-        new PreventMergeIntoDifferentMainDexGroups(appView),
+        new DontInlinePolicy(appView, mainDexTracingResult),
+        new PreventMergeIntoMainDex(appView, mainDexTracingResult),
         new AllInstantiatedOrUninstantiated(appView),
         new SameParentClass(),
         new SameNestHost(),
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMergerResult.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMergerResult.java
index 820c24a..75aafb4 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMergerResult.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMergerResult.java
@@ -5,17 +5,21 @@
 package com.android.tools.r8.horizontalclassmerging;
 
 import com.android.tools.r8.shaking.FieldAccessInfoCollectionModifier;
+import com.android.tools.r8.shaking.MainDexTracingResult;
 
 public class HorizontalClassMergerResult {
 
   private final FieldAccessInfoCollectionModifier fieldAccessInfoCollectionModifier;
   private final HorizontalClassMergerGraphLens graphLens;
+  private final MainDexTracingResult mainDexTracingResult;
 
   HorizontalClassMergerResult(
       FieldAccessInfoCollectionModifier fieldAccessInfoCollectionModifier,
-      HorizontalClassMergerGraphLens graphLens) {
+      HorizontalClassMergerGraphLens graphLens,
+      MainDexTracingResult mainDexTracingResult) {
     this.fieldAccessInfoCollectionModifier = fieldAccessInfoCollectionModifier;
     this.graphLens = graphLens;
+    this.mainDexTracingResult = mainDexTracingResult;
   }
 
   public FieldAccessInfoCollectionModifier getFieldAccessInfoCollectionModifier() {
@@ -25,4 +29,8 @@
   public HorizontalClassMergerGraphLens getGraphLens() {
     return graphLens;
   }
+
+  public MainDexTracingResult getMainDexTracingResult() {
+    return mainDexTracingResult;
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/SyntheticArgumentClass.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/SyntheticArgumentClass.java
index 60c350a..ec64e8a 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/SyntheticArgumentClass.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/SyntheticArgumentClass.java
@@ -17,7 +17,7 @@
 import com.android.tools.r8.graph.GenericSignature.ClassSignature;
 import com.android.tools.r8.origin.SynthesizedOrigin;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
-import com.android.tools.r8.shaking.MainDexInfo;
+import com.android.tools.r8.shaking.MainDexTracingResult;
 import com.google.common.collect.Iterables;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -57,15 +57,24 @@
 
     private final DirectMappedDexApplication.Builder appBuilder;
     private final AppView<AppInfoWithLiveness> appView;
+    private final MainDexTracingResult mainDexTracingResult;
+    private final MainDexTracingResult.Builder mainDexTracingResultBuilder;
 
-    Builder(DirectMappedDexApplication.Builder appBuilder, AppView<AppInfoWithLiveness> appView) {
+    Builder(
+        DirectMappedDexApplication.Builder appBuilder,
+        AppView<AppInfoWithLiveness> appView,
+        MainDexTracingResult mainDexTracingResult,
+        MainDexTracingResult.Builder mainDexTracingResultBuilder) {
       this.appBuilder = appBuilder;
       this.appView = appView;
+      this.mainDexTracingResult = mainDexTracingResult;
+      this.mainDexTracingResultBuilder = mainDexTracingResultBuilder;
     }
 
     private DexType synthesizeClass(
         DexProgramClass context,
         boolean requiresMainDex,
+        boolean addToMainDexTracingResult,
         int index) {
       DexType syntheticClassType =
           appView
@@ -99,6 +108,10 @@
 
       appBuilder.addSynthesizedClass(clazz);
       appView.appInfo().addSynthesizedClass(clazz, requiresMainDex);
+      if (addToMainDexTracingResult) {
+        mainDexTracingResultBuilder.addRoot(clazz);
+      }
+
       return clazz.type;
     }
 
@@ -106,17 +119,20 @@
       // Find a fresh name in an existing package.
       DexProgramClass context = mergeClasses.iterator().next();
 
-      // Add as a root to the main dex tracing result if any of the merged classes is a root.
+      // Add to the main dex list if one of the merged classes is in the main dex.
+      boolean requiresMainDex = appView.appInfo().getMainDexClasses().containsAnyOf(mergeClasses);
+
+      // Also add as a root to the main dex tracing result if any of the merged classes is a root.
       // This is needed to satisfy an assertion in the inliner that verifies that we do not inline
       // methods with references to non-roots into classes that are roots.
-      MainDexInfo mainDexInfo = appView.appInfo().getMainDexInfo();
-      boolean requiresMainDex = Iterables.any(mergeClasses, mainDexInfo::isMainDex);
+      boolean addToMainDexTracingResult = Iterables.any(mergeClasses, mainDexTracingResult::isRoot);
 
       List<DexType> syntheticArgumentTypes = new ArrayList<>();
       for (int i = 0;
           i < appView.options().horizontalClassMergerOptions().getSyntheticArgumentCount();
           i++) {
-        syntheticArgumentTypes.add(synthesizeClass(context, requiresMainDex, i));
+        syntheticArgumentTypes.add(
+            synthesizeClass(context, requiresMainDex, addToMainDexTracingResult, i));
       }
 
       return new SyntheticArgumentClass(syntheticArgumentTypes);
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/DontInlinePolicy.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/DontInlinePolicy.java
index f7d121a..9612f98 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/DontInlinePolicy.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/DontInlinePolicy.java
@@ -12,16 +12,18 @@
 import com.android.tools.r8.horizontalclassmerging.SingleClassPolicy;
 import com.android.tools.r8.ir.optimize.Inliner.ConstraintWithTarget;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
-import com.android.tools.r8.shaking.MainDexInfo;
+import com.android.tools.r8.shaking.MainDexDirectReferenceTracer;
+import com.android.tools.r8.shaking.MainDexTracingResult;
 import com.google.common.collect.Iterables;
 
 public class DontInlinePolicy extends SingleClassPolicy {
   private final AppView<AppInfoWithLiveness> appView;
-  private final MainDexInfo mainDexInfo;
+  private final MainDexTracingResult mainDexTracingResult;
 
-  public DontInlinePolicy(AppView<AppInfoWithLiveness> appView) {
+  public DontInlinePolicy(
+      AppView<AppInfoWithLiveness> appView, MainDexTracingResult mainDexTracingResult) {
     this.appView = appView;
-    this.mainDexInfo = appView.appInfo().getMainDexInfo();
+    this.mainDexTracingResult = mainDexTracingResult;
   }
 
   private boolean disallowInlining(ProgramMethod method) {
@@ -44,6 +46,14 @@
       return true;
     }
 
+    // Constructors can have references beyond the root main dex classes. This can increase the
+    // size of the main dex dependent classes and we should bail out.
+    if (mainDexTracingResult.getRoots().contains(method.getHolderType())
+        && MainDexDirectReferenceTracer.hasReferencesOutsideFromCode(
+            appView.appInfo(), method, mainDexTracingResult.getRoots())) {
+      return true;
+    }
+
     return false;
   }
 
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/PreventMergeIntoDifferentMainDexGroups.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/PreventMergeIntoDifferentMainDexGroups.java
deleted file mode 100644
index 52a1d98..0000000
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/PreventMergeIntoDifferentMainDexGroups.java
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
-// for details. All rights reserved. Use of this source code is governed by a
-// BSD-style license that can be found in the LICENSE file.
-
-package com.android.tools.r8.horizontalclassmerging.policies;
-
-import com.android.tools.r8.graph.AppView;
-import com.android.tools.r8.graph.DexProgramClass;
-import com.android.tools.r8.horizontalclassmerging.MultiClassSameReferencePolicy;
-import com.android.tools.r8.shaking.AppInfoWithLiveness;
-import com.android.tools.r8.shaking.MainDexInfo;
-import com.android.tools.r8.shaking.MainDexInfo.MainDexGroup;
-
-public class PreventMergeIntoDifferentMainDexGroups
-    extends MultiClassSameReferencePolicy<MainDexGroup> {
-
-  private final MainDexInfo mainDexInfo;
-
-  public PreventMergeIntoDifferentMainDexGroups(AppView<AppInfoWithLiveness> appView) {
-    this.mainDexInfo = appView.appInfo().getMainDexInfo();
-  }
-
-  @Override
-  public MainDexGroup getMergeKey(DexProgramClass clazz) {
-    assert !mainDexInfo.isFromList(clazz);
-    return mainDexInfo.getMergeKey(clazz);
-  }
-}
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/PreventMergeIntoMainDex.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/PreventMergeIntoMainDex.java
new file mode 100644
index 0000000..783ede5
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/PreventMergeIntoMainDex.java
@@ -0,0 +1,45 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.horizontalclassmerging.policies;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.horizontalclassmerging.MultiClassSameReferencePolicy;
+import com.android.tools.r8.horizontalclassmerging.policies.PreventMergeIntoMainDex.MainDexClassification;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.shaking.MainDexClasses;
+import com.android.tools.r8.shaking.MainDexTracingResult;
+
+public class PreventMergeIntoMainDex extends MultiClassSameReferencePolicy<MainDexClassification> {
+  private final MainDexClasses mainDexClasses;
+  private final MainDexTracingResult mainDexTracingResult;
+
+  enum MainDexClassification {
+    MAIN_DEX_LIST,
+    MAIN_DEX_ROOT,
+    MAIN_DEX_DEPENDENCY,
+    NOT_IN_MAIN_DEX
+  }
+
+  public PreventMergeIntoMainDex(
+      AppView<AppInfoWithLiveness> appView, MainDexTracingResult mainDexTracingResult) {
+    this.mainDexClasses = appView.appInfo().getMainDexClasses();
+    this.mainDexTracingResult = mainDexTracingResult;
+  }
+
+  @Override
+  public MainDexClassification getMergeKey(DexProgramClass clazz) {
+    if (mainDexClasses.contains(clazz)) {
+      return MainDexClassification.MAIN_DEX_LIST;
+    }
+    if (mainDexTracingResult.isRoot(clazz)) {
+      return MainDexClassification.MAIN_DEX_ROOT;
+    }
+    if (mainDexTracingResult.isDependency(clazz)) {
+      return MainDexClassification.MAIN_DEX_DEPENDENCY;
+    }
+    return MainDexClassification.NOT_IN_MAIN_DEX;
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/PreventMergeIntoMainDexList.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/PreventMergeIntoMainDexList.java
deleted file mode 100644
index 850d02e..0000000
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/PreventMergeIntoMainDexList.java
+++ /dev/null
@@ -1,25 +0,0 @@
-// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
-// for details. All rights reserved. Use of this source code is governed by a
-// BSD-style license that can be found in the LICENSE file.
-
-package com.android.tools.r8.horizontalclassmerging.policies;
-
-import com.android.tools.r8.graph.AppView;
-import com.android.tools.r8.graph.DexProgramClass;
-import com.android.tools.r8.horizontalclassmerging.SingleClassPolicy;
-import com.android.tools.r8.shaking.AppInfoWithLiveness;
-import com.android.tools.r8.shaking.MainDexInfo;
-
-public class PreventMergeIntoMainDexList extends SingleClassPolicy {
-
-  private final MainDexInfo mainDexInfo;
-
-  public PreventMergeIntoMainDexList(AppView<AppInfoWithLiveness> appView) {
-    this.mainDexInfo = appView.appInfo().getMainDexInfo();
-  }
-
-  @Override
-  public boolean canMerge(DexProgramClass program) {
-    return mainDexInfo.canMerge(program);
-  }
-}
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index b1e68ace..22ee5c5 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -91,6 +91,7 @@
 import com.android.tools.r8.position.MethodPosition;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.shaking.LibraryMethodOverrideAnalysis;
+import com.android.tools.r8.shaking.MainDexTracingResult;
 import com.android.tools.r8.utils.Action;
 import com.android.tools.r8.utils.CfgPrinter;
 import com.android.tools.r8.utils.DescriptorUtils;
@@ -120,6 +121,7 @@
   private static final int PEEPHOLE_OPTIMIZATION_PASSES = 2;
 
   public final AppView<?> appView;
+  public final MainDexTracingResult mainDexClasses;
 
   private final Timing timing;
   private final Outliner outliner;
@@ -180,7 +182,8 @@
    * The argument `appView` is used to determine if whole program optimizations are allowed or not
    * (i.e., whether we are running R8). See {@link AppView#enableWholeProgramOptimizations()}.
    */
-  public IRConverter(AppView<?> appView, Timing timing, CfgPrinter printer) {
+  public IRConverter(
+      AppView<?> appView, Timing timing, CfgPrinter printer, MainDexTracingResult mainDexClasses) {
     assert appView.appInfo().hasLiveness() || appView.graphLens().isIdentityLens();
     assert appView.options() != null;
     assert appView.options().programConsumer != null;
@@ -189,6 +192,7 @@
     this.appView = appView;
     this.options = appView.options();
     this.printer = printer;
+    this.mainDexClasses = mainDexClasses;
     this.codeRewriter = new CodeRewriter(appView, this);
     this.constantCanonicalizer = new ConstantCanonicalizer(codeRewriter);
     this.classInitializerDefaultsOptimization =
@@ -299,7 +303,7 @@
               : null;
       this.enumUnboxer = options.enableEnumUnboxing ? new EnumUnboxer(appViewWithLiveness) : null;
       this.lensCodeRewriter = new LensCodeRewriter(appViewWithLiveness, enumUnboxer);
-      this.inliner = new Inliner(appViewWithLiveness, lensCodeRewriter);
+      this.inliner = new Inliner(appViewWithLiveness, mainDexClasses, lensCodeRewriter);
       this.outliner = new Outliner(appViewWithLiveness);
       this.memberValuePropagation =
           options.enableValuePropagation ? new MemberValuePropagation(appViewWithLiveness) : null;
@@ -311,7 +315,9 @@
         this.identifierNameStringMarker = null;
       }
       this.devirtualizer =
-          options.enableDevirtualization ? new Devirtualizer(appViewWithLiveness) : null;
+          options.enableDevirtualization
+              ? new Devirtualizer(appViewWithLiveness, mainDexClasses)
+              : null;
       this.typeChecker = new TypeChecker(appViewWithLiveness, VerifyTypesHelper.create(appView));
       this.d8NestBasedAccessDesugaring = null;
       this.serviceLoaderRewriter =
@@ -358,11 +364,16 @@
 
   /** Create an IR converter for processing methods with full program optimization disabled. */
   public IRConverter(AppView<?> appView, Timing timing) {
-    this(appView, timing, null);
+    this(appView, timing, null, MainDexTracingResult.NONE);
+  }
+
+  /** Create an IR converter for processing methods with full program optimization disabled. */
+  public IRConverter(AppView<?> appView, Timing timing, CfgPrinter printer) {
+    this(appView, timing, printer, MainDexTracingResult.NONE);
   }
 
   public IRConverter(AppInfo appInfo, Timing timing, CfgPrinter printer) {
-    this(AppView.createForD8(appInfo), timing, printer);
+    this(AppView.createForD8(appInfo), timing, printer, MainDexTracingResult.NONE);
   }
 
   private void removeLambdaDeserializationMethods() {
@@ -475,7 +486,7 @@
       appView.setAppInfo(
           new AppInfo(
               appView.appInfo().getSyntheticItems().commit(application),
-              appView.appInfo().getMainDexInfo()));
+              appView.appInfo().getMainDexClasses()));
       application = appView.appInfo().app();
     }
 
@@ -495,7 +506,7 @@
     appView.setAppInfo(
         new AppInfo(
             appView.appInfo().getSyntheticItems().commit(application),
-            appView.appInfo().getMainDexInfo()));
+            appView.appInfo().getMainDexClasses()));
   }
 
   private void convertClasses(DexApplication application, ExecutorService executorService)
@@ -1276,7 +1287,8 @@
     if (appView.appInfo().hasLiveness()) {
       // Reflection optimization 1. getClass() / forName() -> const-class
       timing.begin("Rewrite to const class");
-      ReflectionOptimizer.rewriteGetClassOrForNameToConstClass(appView.withLiveness(), code);
+      ReflectionOptimizer.rewriteGetClassOrForNameToConstClass(
+          appView.withLiveness(), code, mainDexClasses);
       timing.end();
     }
 
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/InterfaceMethodRewriter.java b/src/main/java/com/android/tools/r8/ir/desugar/InterfaceMethodRewriter.java
index db127f2..25f232d 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/InterfaceMethodRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/InterfaceMethodRewriter.java
@@ -73,6 +73,7 @@
 import com.android.tools.r8.origin.Origin;
 import com.android.tools.r8.origin.SynthesizedOrigin;
 import com.android.tools.r8.position.MethodPosition;
+import com.android.tools.r8.shaking.MainDexClasses;
 import com.android.tools.r8.synthesis.SyntheticNaming;
 import com.android.tools.r8.synthesis.SyntheticNaming.SyntheticKind;
 import com.android.tools.r8.utils.DescriptorUtils;
@@ -715,6 +716,7 @@
     // Emulated library interfaces should generate the Emulated Library EL dispatch class.
     Map<DexType, List<DexType>> emulatedInterfacesHierarchy = processEmulatedInterfaceHierarchy();
     AppInfo appInfo = appView.appInfo();
+    MainDexClasses mainDexClasses = appInfo.getMainDexClasses();
     for (DexType interfaceType : emulatedInterfaces.keySet()) {
       DexClass theInterface = appInfo.definitionFor(interfaceType);
       if (theInterface == null) {
@@ -726,7 +728,8 @@
                 theProgramInterface, emulatedInterfacesHierarchy);
         if (synthesizedClass != null) {
           builder.addSynthesizedClass(synthesizedClass);
-          appInfo.addSynthesizedClass(synthesizedClass, theProgramInterface);
+          appInfo.addSynthesizedClass(
+              synthesizedClass, mainDexClasses.contains(theProgramInterface));
         }
       }
     }
@@ -1156,6 +1159,7 @@
     // make original default methods abstract, remove bridge methods, create dispatch
     // classes if needed.
     AppInfo appInfo = appView.appInfo();
+    MainDexClasses mainDexClasses = appInfo.getMainDexClasses();
     InterfaceProcessorNestedGraphLens.Builder graphLensBuilder =
         InterfaceProcessorNestedGraphLens.builder();
     Map<DexClass, DexProgramClass> classMapping =
@@ -1170,7 +1174,10 @@
           // Don't need to optimize synthesized class since all of its methods
           // are just moved from interfaces and don't need to be re-processed.
           builder.addSynthesizedClass(synthesizedClass);
-          appInfo.addSynthesizedClass(synthesizedClass, interfaceClass.asProgramClass());
+          boolean addToMainDexClasses =
+              interfaceClass.isProgramClass()
+                  && mainDexClasses.contains(interfaceClass.asProgramClass());
+          appInfo.addSynthesizedClass(synthesizedClass, addToMainDexClasses);
         });
     new InterfaceMethodRewriterFixup(appView, graphLens).run();
     if (appView.options().isDesugaredLibraryCompilation()) {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java b/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java
index 969b5d5..79f36ad 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java
@@ -38,6 +38,7 @@
 import com.android.tools.r8.ir.optimize.inliner.WhyAreYouNotInliningReporter;
 import com.android.tools.r8.logging.Log;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.shaking.MainDexDirectReferenceTracer;
 import com.android.tools.r8.utils.BooleanUtils;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.Timing;
@@ -198,17 +199,25 @@
     // Don't inline code with references beyond root main dex classes into a root main dex class.
     // If we do this it can increase the size of the main dex dependent classes.
     if (reason != Reason.FORCE
-        && inliner.mainDexInfo.disallowInliningIntoContext(
-            appView.appInfo(), method, singleTarget)) {
+        && inlineeRefersToClassesNotInMainDex(method.getHolderType(), singleTarget)) {
       whyAreYouNotInliningReporter.reportInlineeRefersToClassesNotInMainDex();
       return false;
     }
     assert reason != Reason.FORCE
-        || !inliner.mainDexInfo.disallowInliningIntoContext(
-            appView.appInfo(), method, singleTarget);
+            || !inlineeRefersToClassesNotInMainDex(method.getHolderType(), singleTarget)
+        : MainDexDirectReferenceTracer.getFirstReferenceOutsideFromCode(
+            appView.appInfo(), singleTarget, inliner.mainDexClasses.getRoots());
     return true;
   }
 
+  private boolean inlineeRefersToClassesNotInMainDex(DexType holder, ProgramMethod target) {
+    if (inliner.mainDexClasses.isEmpty() || !inliner.mainDexClasses.getRoots().contains(holder)) {
+      return false;
+    }
+    return MainDexDirectReferenceTracer.hasReferencesOutsideFromCode(
+        appView.appInfo(), target, inliner.mainDexClasses.getRoots());
+  }
+
   private boolean satisfiesRequirementsForSimpleInlining(
       InvokeMethod invoke, ProgramMethod target) {
     // If we are looking for a simple method, only inline if actually simple.
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/Devirtualizer.java b/src/main/java/com/android/tools/r8/ir/optimize/Devirtualizer.java
index f5f4d70..7179327 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/Devirtualizer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/Devirtualizer.java
@@ -25,6 +25,7 @@
 import com.android.tools.r8.ir.code.InvokeVirtual;
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.shaking.MainDexTracingResult;
 import com.android.tools.r8.utils.InternalOptions;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
@@ -47,10 +48,13 @@
 public class Devirtualizer {
 
   private final AppView<AppInfoWithLiveness> appView;
+  private final MainDexTracingResult mainDexTracingResult;
   private final InternalOptions options;
 
-  public Devirtualizer(AppView<AppInfoWithLiveness> appView) {
+  public Devirtualizer(
+      AppView<AppInfoWithLiveness> appView, MainDexTracingResult mainDexTracingResult) {
     this.appView = appView;
+    this.mainDexTracingResult = mainDexTracingResult;
     this.options = appView.options();
   }
 
@@ -150,7 +154,7 @@
           DexClassAndMethod reboundTarget = rebindSuperInvokeToMostSpecific(invokedMethod, context);
           if (reboundTarget != null
               && reboundTarget.getReference() != invokedMethod
-              && !isRebindingNewClassIntoMainDex(context, reboundTarget.getReference())) {
+              && !isRebindingNewClassIntoMainDex(invokedMethod, reboundTarget.getReference())) {
             it.replaceCurrentInstruction(
                 new InvokeSuper(
                     reboundTarget.getReference(),
@@ -193,7 +197,7 @@
         }
 
         // Ensure that we are not adding a new main dex root
-        if (isRebindingNewClassIntoMainDex(context, target.getReference())) {
+        if (isRebindingNewClassIntoMainDex(invoke.getInvokedMethod(), target.getReference())) {
           continue;
         }
 
@@ -390,7 +394,13 @@
     return newResolutionResult.getResolvedMethod().method;
   }
 
-  private boolean isRebindingNewClassIntoMainDex(ProgramMethod context, DexMethod reboundMethod) {
-    return !appView.appInfo().getMainDexInfo().canRebindReference(context, reboundMethod);
+  private boolean isRebindingNewClassIntoMainDex(
+      DexMethod originalMethod, DexMethod reboundMethod) {
+    if (!mainDexTracingResult.isRoot(originalMethod.holder)
+        && !appView.appInfo().getMainDexClasses().contains(originalMethod.holder)) {
+      return false;
+    }
+    return !mainDexTracingResult.isRoot(reboundMethod.holder)
+        && !appView.appInfo().getMainDexClasses().contains(reboundMethod.holder);
   }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java b/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
index e034587..9e726b9 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
@@ -59,7 +59,7 @@
 import com.android.tools.r8.ir.optimize.inliner.WhyAreYouNotInliningReporter;
 import com.android.tools.r8.kotlin.Kotlin;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
-import com.android.tools.r8.shaking.MainDexInfo;
+import com.android.tools.r8.shaking.MainDexTracingResult;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.IteratorUtils;
 import com.android.tools.r8.utils.ListUtils;
@@ -84,7 +84,7 @@
   protected final AppView<AppInfoWithLiveness> appView;
   private final Set<DexMethod> extraNeverInlineMethods;
   private final LensCodeRewriter lensCodeRewriter;
-  final MainDexInfo mainDexInfo;
+  final MainDexTracingResult mainDexClasses;
 
   // State for inlining methods which are known to be called twice.
   private boolean applyDoubleInlining = false;
@@ -97,6 +97,7 @@
 
   public Inliner(
       AppView<AppInfoWithLiveness> appView,
+      MainDexTracingResult mainDexClasses,
       LensCodeRewriter lensCodeRewriter) {
     Kotlin.Intrinsics intrinsics = appView.dexItemFactory().kotlin.intrinsics;
     this.appView = appView;
@@ -105,7 +106,7 @@
             ? ImmutableSet.of()
             : ImmutableSet.of(intrinsics.throwNpe, intrinsics.throwParameterIsNullException);
     this.lensCodeRewriter = lensCodeRewriter;
-    this.mainDexInfo = appView.appInfo().getMainDexInfo();
+    this.mainDexClasses = mainDexClasses;
     availableApiExceptions =
         appView.options().canHaveDalvikCatchHandlerVerificationBug()
             ? new AvailableApiExceptions(appView.options())
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/ReflectionOptimizer.java b/src/main/java/com/android/tools/r8/ir/optimize/ReflectionOptimizer.java
index afdf56a..cc5e7ed 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/ReflectionOptimizer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/ReflectionOptimizer.java
@@ -26,6 +26,7 @@
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.optimize.Inliner.ConstraintWithTarget;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.shaking.MainDexTracingResult;
 import com.android.tools.r8.utils.DescriptorUtils;
 import com.google.common.collect.Sets;
 import java.util.Set;
@@ -36,7 +37,7 @@
   // Rewrite getClass() to const-class if the type of the given instance is effectively final.
   // Rewrite forName() to const-class if the type is resolvable, accessible and already initialized.
   public static void rewriteGetClassOrForNameToConstClass(
-      AppView<AppInfoWithLiveness> appView, IRCode code) {
+      AppView<AppInfoWithLiveness> appView, IRCode code, MainDexTracingResult mainDexClasses) {
     if (!appView.appInfo().canUseConstClassInstructions(appView.options())) {
       return;
     }
@@ -61,14 +62,14 @@
               context,
               invoke.asInvokeStatic(),
               rewriteSingleGetClassOrForNameToConstClass(
-                  appView, code, it, invoke, affectedValues));
+                  appView, code, it, invoke, affectedValues, mainDexClasses));
         } else {
           applyTypeForGetClassTo(
               appView,
               context,
               invoke.asInvokeVirtual(),
               rewriteSingleGetClassOrForNameToConstClass(
-                  appView, code, it, invoke, affectedValues));
+                  appView, code, it, invoke, affectedValues, mainDexClasses));
         }
       }
     }
@@ -84,15 +85,14 @@
       IRCode code,
       InstructionListIterator instructionIterator,
       InvokeMethod invoke,
-      Set<Value> affectedValues) {
+      Set<Value> affectedValues,
+      MainDexTracingResult mainDexClasses) {
     return (type, baseClass) -> {
       if (invoke.getInvokedMethod().match(appView.dexItemFactory().classMethods.forName)) {
         // Bail-out if the optimization could increase the size of the main dex.
         if (baseClass.isProgramClass()
-            && !appView
-                .appInfo()
-                .getMainDexInfo()
-                .canRebindReference(code.context(), baseClass.getType())) {
+            && !mainDexClasses.canReferenceItemFromContextWithoutIncreasingMainDexSize(
+                baseClass.asProgramClass(), code.context())) {
           return;
         }
 
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/UnboxedEnumMemberRelocator.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/UnboxedEnumMemberRelocator.java
index 48e545c..5a8ae19 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/UnboxedEnumMemberRelocator.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/UnboxedEnumMemberRelocator.java
@@ -5,7 +5,6 @@
 package com.android.tools.r8.ir.optimize.enums;
 
 import static com.android.tools.r8.ir.optimize.enums.EnumUnboxingRewriter.createValuesField;
-import static com.google.common.base.Predicates.alwaysTrue;
 
 import com.android.tools.r8.dex.Constants;
 import com.android.tools.r8.graph.AppView;
@@ -24,7 +23,6 @@
 import com.android.tools.r8.graph.ProgramPackageCollection;
 import com.android.tools.r8.origin.SynthesizedOrigin;
 import com.android.tools.r8.shaking.FieldAccessInfoCollectionModifier;
-import com.android.tools.r8.shaking.MainDexInfo;
 import com.android.tools.r8.utils.SetUtils;
 import com.google.common.collect.ImmutableMap;
 import java.util.ArrayList;
@@ -34,7 +32,6 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.function.Predicate;
 
 public class UnboxedEnumMemberRelocator {
 
@@ -132,8 +129,9 @@
         Set<DexProgramClass> relocatedEnums,
         DirectMappedDexApplication.Builder appBuilder,
         FieldAccessInfoCollectionModifier.Builder fieldAccessInfoCollectionModifierBuilder) {
-      DexProgramClass deterministicContext = findDeterministicContextType(contexts, alwaysTrue());
-      String descriptorString = deterministicContext.getType().toDescriptorString();
+      DexType deterministicContextType = findDeterministicContextType(contexts);
+      assert deterministicContextType.isClassType();
+      String descriptorString = deterministicContextType.toDescriptorString();
       String descriptorPrefix = descriptorString.substring(0, descriptorString.length() - 1);
       String syntheticClassDescriptor = descriptorPrefix + ENUM_UNBOXING_UTILITY_CLASS_SUFFIX + ";";
       DexType type = appView.dexItemFactory().createType(syntheticClassDescriptor);
@@ -182,25 +180,20 @@
               appView.dexItemFactory().getSkipNameValidationForTesting(),
               DexProgramClass::checksumFromType);
       appBuilder.addSynthesizedClass(syntheticClass);
-      MainDexInfo mainDexInfo = appView.appInfo().getMainDexInfo();
       appView
           .appInfo()
           .addSynthesizedClass(
-              syntheticClass, findDeterministicContextType(contexts, mainDexInfo::isMainDex));
+              syntheticClass, appView.appInfo().getMainDexClasses().containsAnyOf(contexts));
       return syntheticClass;
     }
 
-    private DexProgramClass findDeterministicContextType(
-        Set<DexProgramClass> contexts, Predicate<DexProgramClass> predicate) {
-      DexProgramClass deterministicContext = null;
+    private DexType findDeterministicContextType(Set<DexProgramClass> contexts) {
+      DexType deterministicContext = null;
       for (DexProgramClass context : contexts) {
-        if (!predicate.test(context)) {
-          continue;
-        }
         if (deterministicContext == null) {
-          deterministicContext = context;
-        } else if (context.type.compareTo(deterministicContext.type) < 0) {
-          deterministicContext = context;
+          deterministicContext = context.type;
+        } else if (context.type.compareTo(deterministicContext) < 0) {
+          deterministicContext = context.type;
         }
       }
       return deterministicContext;
diff --git a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
index 06c03e7..baa1189 100644
--- a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
+++ b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
@@ -195,7 +195,7 @@
   AppInfoWithLiveness(
       CommittedItems syntheticItems,
       ClassToFeatureSplitMap classToFeatureSplitMap,
-      MainDexInfo mainDexInfo,
+      MainDexClasses mainDexClasses,
       Set<DexType> deadProtoTypes,
       MissingClasses missingClasses,
       Set<DexType> liveTypes,
@@ -234,7 +234,7 @@
       Map<DexField, Int2ReferenceMap<DexField>> switchMaps,
       Set<DexType> lockCandidates,
       Map<DexType, Visibility> initClassReferences) {
-    super(syntheticItems, classToFeatureSplitMap, mainDexInfo, missingClasses);
+    super(syntheticItems, classToFeatureSplitMap, mainDexClasses, missingClasses);
     this.deadProtoTypes = deadProtoTypes;
     this.liveTypes = liveTypes;
     this.targetedMethods = targetedMethods;
@@ -279,7 +279,7 @@
     this(
         committedItems,
         previous.getClassToFeatureSplitMap(),
-        previous.getMainDexInfo(),
+        previous.getMainDexClasses(),
         previous.deadProtoTypes,
         previous.getMissingClasses(),
         CollectionUtils.mergeSets(previous.liveTypes, committedItems.getCommittedProgramTypes()),
@@ -324,7 +324,7 @@
     this(
         previous.getSyntheticItems().commitPrunedItems(prunedItems),
         previous.getClassToFeatureSplitMap().withoutPrunedItems(prunedItems),
-        previous.getMainDexInfo().withoutPrunedItems(prunedItems),
+        previous.getMainDexClasses().withoutPrunedItems(prunedItems),
         previous.deadProtoTypes,
         previous.getMissingClasses(),
         prunedItems.hasRemovedClasses()
@@ -376,52 +376,6 @@
     return true;
   }
 
-  @Override
-  public AppInfoWithLiveness rebuildWithMainDexInfo(MainDexInfo mainDexInfo) {
-    return new AppInfoWithLiveness(
-        getSyntheticItems().commit(app()),
-        getClassToFeatureSplitMap(),
-        mainDexInfo,
-        deadProtoTypes,
-        getMissingClasses(),
-        liveTypes,
-        targetedMethods,
-        failedMethodResolutionTargets,
-        failedFieldResolutionTargets,
-        bootstrapMethods,
-        methodsTargetedByInvokeDynamic,
-        virtualMethodsTargetedByInvokeDirect,
-        liveMethods,
-        fieldAccessInfoCollection,
-        methodAccessInfoCollection,
-        objectAllocationInfoCollection,
-        callSites,
-        keepInfo,
-        mayHaveSideEffects,
-        noSideEffects,
-        assumedValues,
-        alwaysInline,
-        forceInline,
-        neverInline,
-        neverInlineDueToSingleCaller,
-        whyAreYouNotInlining,
-        keepConstantArguments,
-        keepUnusedArguments,
-        reprocess,
-        neverReprocess,
-        alwaysClassInline,
-        neverClassInline,
-        noClassMerging,
-        noVerticalClassMerging,
-        noHorizontalClassMerging,
-        neverPropagateValue,
-        identifierNameStrings,
-        prunedTypes,
-        switchMaps,
-        lockCandidates,
-        initClassReferences);
-  }
-
   private static KeepInfoCollection extendPinnedItems(
       AppInfoWithLiveness previous, Collection<? extends DexReference> additionalPinnedItems) {
     if (additionalPinnedItems == null || additionalPinnedItems.isEmpty()) {
@@ -464,7 +418,7 @@
     super(
         previous.getSyntheticItems().commit(previous.app()),
         previous.getClassToFeatureSplitMap(),
-        previous.getMainDexInfo(),
+        previous.getMainDexClasses(),
         previous.getMissingClasses());
     this.deadProtoTypes = previous.deadProtoTypes;
     this.liveTypes = previous.liveTypes;
@@ -1045,7 +999,7 @@
     return new AppInfoWithLiveness(
         committedItems,
         getClassToFeatureSplitMap().rewrittenWithLens(lens),
-        getMainDexInfo().rewrittenWithLens(lens),
+        getMainDexClasses().rewrittenWithLens(lens),
         deadProtoTypes,
         getMissingClasses().commitSyntheticItems(committedItems),
         lens.rewriteTypes(liveTypes),
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index 9c672b1..d734eb5 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -393,13 +393,15 @@
   private final Map<DexMethod, ProgramMethod> methodsWithTwrCloseResource = new IdentityHashMap<>();
   private final Set<DexProgramClass> classesWithSerializableLambdas = Sets.newIdentityHashSet();
   private final ProgramMethodSet pendingDesugaring = ProgramMethodSet.create();
+  private final MainDexTracingResult previousMainDexTracingResult;
 
   Enqueuer(
       AppView<? extends AppInfoWithClassHierarchy> appView,
       ExecutorService executorService,
       SubtypingInfo subtypingInfo,
       GraphConsumer keptGraphConsumer,
-      Mode mode) {
+      Mode mode,
+      MainDexTracingResult previousMainDexTracingResult) {
     assert appView.appServices() != null;
     InternalOptions options = appView.options();
     this.appInfo = appView.appInfo();
@@ -417,6 +419,7 @@
         mode.isInitialTreeShaking() && options.forceProguardCompatibility
             ? ProguardCompatibilityActions.builder()
             : null;
+    this.previousMainDexTracingResult = previousMainDexTracingResult;
 
     if (mode.isInitialOrFinalTreeShaking()) {
       if (options.protoShrinking().enableGeneratedMessageLiteShrinking) {
@@ -1743,7 +1746,7 @@
 
     assert !mode.isFinalMainDexTracing()
             || !options.testing.checkForNotExpandingMainDexTracingResult
-            || appView.appInfo().getMainDexInfo().isTracedRoot(clazz)
+            || previousMainDexTracingResult.isRoot(clazz)
             || clazz.toSourceString().contains(ENUM_UNBOXING_UTILITY_CLASS_SUFFIX)
         : "Class " + clazz.toSourceString() + " was not a main dex root in the first round";
 
@@ -3020,7 +3023,7 @@
   }
 
   // Returns the set of live types.
-  public MainDexInfo traceMainDex(ExecutorService executorService, Timing timing)
+  public Set<DexProgramClass> traceMainDex(ExecutorService executorService, Timing timing)
       throws ExecutionException {
     assert analyses.isEmpty();
     assert mode.isMainDexTracing();
@@ -3029,12 +3032,7 @@
     enqueueRootItems(rootSet.noShrinking);
     trace(executorService, timing);
     options.reporter.failIfPendingErrors();
-    // Calculate the automatic main dex list according to legacy multidex constraints.
-    MainDexInfo.Builder builder = MainDexInfo.builder();
-    liveTypes.getItems().forEach(builder::addRoot);
-    new MainDexListBuilder(appView, builder.getRoots(), builder).run();
-    MainDexInfo previousMainDexInfo = appInfo.getMainDexInfo();
-    return builder.build(previousMainDexInfo);
+    return liveTypes.getItems();
   }
 
   public AppInfoWithLiveness traceApplication(
@@ -3212,9 +3210,9 @@
       appBuilder.addClasspathClasses(syntheticClasspathClasses.values());
     }
 
-    void amendMainDexClasses(MainDexInfo mainDexInfo) {
+    void amendMainDexClasses(MainDexClasses mainDexClasses) {
       assert !isEmpty();
-      mainDexTypes.forEach(mainDexInfo::addSyntheticClass);
+      mainDexClasses.addAll(mainDexTypes);
     }
 
     void enqueueWorkItems(Enqueuer enqueuer) {
@@ -3298,7 +3296,7 @@
               additions.amendApplication(appBuilder);
               return appBuilder.build();
             });
-    additions.amendMainDexClasses(appInfo.getMainDexInfo());
+    additions.amendMainDexClasses(appInfo.getMainDexClasses());
     appView.setAppInfo(appInfo);
     subtypingInfo = new SubtypingInfo(appView);
 
@@ -3492,7 +3490,7 @@
         new AppInfoWithLiveness(
             appInfo.getSyntheticItems().commit(app),
             appInfo.getClassToFeatureSplitMap(),
-            appInfo.getMainDexInfo(),
+            appInfo.getMainDexClasses(),
             deadProtoTypes,
             appView.testing().enableExperimentalMissingClassesReporting
                 ? missingClassesBuilder.reportMissingClasses(appView)
diff --git a/src/main/java/com/android/tools/r8/shaking/EnqueuerFactory.java b/src/main/java/com/android/tools/r8/shaking/EnqueuerFactory.java
index 7636fe8..04fb840 100644
--- a/src/main/java/com/android/tools/r8/shaking/EnqueuerFactory.java
+++ b/src/main/java/com/android/tools/r8/shaking/EnqueuerFactory.java
@@ -19,7 +19,13 @@
       AppView<? extends AppInfoWithClassHierarchy> appView,
       ExecutorService executorService,
       SubtypingInfo subtypingInfo) {
-    return new Enqueuer(appView, executorService, subtypingInfo, null, Mode.INITIAL_TREE_SHAKING);
+    return new Enqueuer(
+        appView,
+        executorService,
+        subtypingInfo,
+        null,
+        Mode.INITIAL_TREE_SHAKING,
+        MainDexTracingResult.NONE);
   }
 
   public static Enqueuer createForFinalTreeShaking(
@@ -30,7 +36,12 @@
       Set<DexType> initialPrunedTypes) {
     Enqueuer enqueuer =
         new Enqueuer(
-            appView, executorService, subtypingInfo, keptGraphConsumer, Mode.FINAL_TREE_SHAKING);
+            appView,
+            executorService,
+            subtypingInfo,
+            keptGraphConsumer,
+            Mode.FINAL_TREE_SHAKING,
+            MainDexTracingResult.NONE);
     appView.withProtoShrinker(
         shrinker -> enqueuer.setInitialDeadProtoTypes(shrinker.getDeadProtoTypes()));
     enqueuer.setInitialPrunedTypes(initialPrunedTypes);
@@ -42,16 +53,27 @@
       ExecutorService executorService,
       SubtypingInfo subtypingInfo) {
     return new Enqueuer(
-        appView, executorService, subtypingInfo, null, Mode.INITIAL_MAIN_DEX_TRACING);
+        appView,
+        executorService,
+        subtypingInfo,
+        null,
+        Mode.INITIAL_MAIN_DEX_TRACING,
+        MainDexTracingResult.NONE);
   }
 
   public static Enqueuer createForFinalMainDexTracing(
       AppView<? extends AppInfoWithClassHierarchy> appView,
       ExecutorService executorService,
       SubtypingInfo subtypingInfo,
-      GraphConsumer keptGraphConsumer) {
+      GraphConsumer keptGraphConsumer,
+      MainDexTracingResult previousMainDexTracingResult) {
     return new Enqueuer(
-        appView, executorService, subtypingInfo, keptGraphConsumer, Mode.FINAL_MAIN_DEX_TRACING);
+        appView,
+        executorService,
+        subtypingInfo,
+        keptGraphConsumer,
+        Mode.FINAL_MAIN_DEX_TRACING,
+        previousMainDexTracingResult);
   }
 
   public static Enqueuer createForWhyAreYouKeeping(
@@ -60,6 +82,11 @@
       SubtypingInfo subtypingInfo,
       GraphConsumer keptGraphConsumer) {
     return new Enqueuer(
-        appView, executorService, subtypingInfo, keptGraphConsumer, Mode.WHY_ARE_YOU_KEEPING);
+        appView,
+        executorService,
+        subtypingInfo,
+        keptGraphConsumer,
+        Mode.WHY_ARE_YOU_KEEPING,
+        MainDexTracingResult.NONE);
   }
 }
diff --git a/src/main/java/com/android/tools/r8/shaking/MainDexClasses.java b/src/main/java/com/android/tools/r8/shaking/MainDexClasses.java
new file mode 100644
index 0000000..9c1c529
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/shaking/MainDexClasses.java
@@ -0,0 +1,118 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.shaking;
+
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.GraphLens;
+import com.android.tools.r8.graph.PrunedItems;
+import com.google.common.collect.Sets;
+import java.util.Set;
+import java.util.function.Consumer;
+
+public class MainDexClasses {
+
+  private final Set<DexType> mainDexClasses;
+
+  private MainDexClasses() {
+    this(Sets.newIdentityHashSet());
+  }
+
+  private MainDexClasses(Set<DexType> mainDexClasses) {
+    this.mainDexClasses = mainDexClasses;
+  }
+
+  public static Builder builder() {
+    return new Builder();
+  }
+
+  public static MainDexClasses createEmptyMainDexClasses() {
+    return new MainDexClasses();
+  }
+
+  public void add(DexProgramClass clazz) {
+    mainDexClasses.add(clazz.getType());
+  }
+
+  public void addAll(MainDexClasses other) {
+    mainDexClasses.addAll(other.mainDexClasses);
+  }
+
+  public void addAll(MainDexTracingResult other) {
+    mainDexClasses.addAll(other.getClasses());
+  }
+
+  public void addAll(Iterable<DexProgramClass> classes) {
+    for (DexProgramClass clazz : classes) {
+      add(clazz);
+    }
+  }
+
+  public boolean contains(DexType type) {
+    return mainDexClasses.contains(type);
+  }
+
+  public boolean contains(DexProgramClass clazz) {
+    return contains(clazz.getType());
+  }
+
+  public boolean containsAnyOf(Iterable<DexProgramClass> classes) {
+    for (DexProgramClass clazz : classes) {
+      if (contains(clazz)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  public void forEach(Consumer<DexType> fn) {
+    mainDexClasses.forEach(fn);
+  }
+
+  public boolean isEmpty() {
+    return mainDexClasses.isEmpty();
+  }
+
+  public MainDexClasses rewrittenWithLens(GraphLens lens) {
+    MainDexClasses rewrittenMainDexClasses = createEmptyMainDexClasses();
+    for (DexType mainDexClass : mainDexClasses) {
+      DexType rewrittenMainDexClass = lens.lookupType(mainDexClass);
+      rewrittenMainDexClasses.mainDexClasses.add(rewrittenMainDexClass);
+    }
+    return rewrittenMainDexClasses;
+  }
+
+  public int size() {
+    return mainDexClasses.size();
+  }
+
+  public MainDexClasses withoutPrunedItems(PrunedItems prunedItems) {
+    if (prunedItems.isEmpty()) {
+      return this;
+    }
+    MainDexClasses mainDexClassesAfterPruning = createEmptyMainDexClasses();
+    for (DexType mainDexClass : mainDexClasses) {
+      if (!prunedItems.getRemovedClasses().contains(mainDexClass)) {
+        mainDexClassesAfterPruning.mainDexClasses.add(mainDexClass);
+      }
+    }
+    return mainDexClassesAfterPruning;
+  }
+
+  public static class Builder {
+
+    private final Set<DexType> mainDexClasses = Sets.newIdentityHashSet();
+
+    private Builder() {}
+
+    public void add(DexProgramClass clazz) {
+      mainDexClasses.add(clazz.getType());
+    }
+
+    public MainDexClasses build() {
+      return new MainDexClasses(mainDexClasses);
+    }
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/shaking/MainDexDirectReferenceTracer.java b/src/main/java/com/android/tools/r8/shaking/MainDexDirectReferenceTracer.java
index 459dea6..6f7cade 100644
--- a/src/main/java/com/android/tools/r8/shaking/MainDexDirectReferenceTracer.java
+++ b/src/main/java/com/android/tools/r8/shaking/MainDexDirectReferenceTracer.java
@@ -25,7 +25,6 @@
 import com.android.tools.r8.utils.Box;
 import java.util.Set;
 import java.util.function.Consumer;
-import java.util.function.Predicate;
 
 public class MainDexDirectReferenceTracer {
   private final AnnotationDirectReferenceCollector annotationDirectReferenceCollector =
@@ -67,19 +66,19 @@
     method.registerCodeReferences(codeDirectReferenceCollector);
   }
 
-  public static boolean hasReferencesOutsideMainDexClasses(
-      AppInfoWithClassHierarchy appInfo, ProgramMethod method, Predicate<DexType> isOutside) {
-    return getFirstReferenceOutsideFromCode(appInfo, method, isOutside) != null;
+  public static boolean hasReferencesOutsideFromCode(
+      AppInfoWithClassHierarchy appInfo, ProgramMethod method, Set<DexType> classes) {
+    return getFirstReferenceOutsideFromCode(appInfo, method, classes) != null;
   }
 
   public static DexProgramClass getFirstReferenceOutsideFromCode(
-      AppInfoWithClassHierarchy appInfo, ProgramMethod method, Predicate<DexType> isOutside) {
+      AppInfoWithClassHierarchy appInfo, ProgramMethod method, Set<DexType> classes) {
     Box<DexProgramClass> result = new Box<>();
     new MainDexDirectReferenceTracer(
             appInfo,
             type -> {
               DexType baseType = type.toBaseType(appInfo.dexItemFactory());
-              if (baseType.isClassType() && isOutside.test(baseType)) {
+              if (baseType.isClassType() && !classes.contains(baseType)) {
                 DexClass cls = appInfo.definitionFor(baseType);
                 if (cls != null && cls.isProgramClass()) {
                   result.set(cls.asProgramClass());
diff --git a/src/main/java/com/android/tools/r8/shaking/MainDexInfo.java b/src/main/java/com/android/tools/r8/shaking/MainDexInfo.java
deleted file mode 100644
index c95ea3b..0000000
--- a/src/main/java/com/android/tools/r8/shaking/MainDexInfo.java
+++ /dev/null
@@ -1,359 +0,0 @@
-// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
-// for details. All rights reserved. Use of this source code is governed by a
-// BSD-style license that can be found in the LICENSE file.
-
-package com.android.tools.r8.shaking;
-
-import static com.android.tools.r8.shaking.MainDexInfo.MainDexGroup.MAIN_DEX_ROOT;
-import static com.android.tools.r8.utils.PredicateUtils.not;
-
-import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
-import com.android.tools.r8.graph.DexProgramClass;
-import com.android.tools.r8.graph.DexReference;
-import com.android.tools.r8.graph.DexType;
-import com.android.tools.r8.graph.GraphLens;
-import com.android.tools.r8.graph.ProgramDefinition;
-import com.android.tools.r8.graph.ProgramMethod;
-import com.android.tools.r8.graph.PrunedItems;
-import com.android.tools.r8.utils.ConsumerUtils;
-import com.android.tools.r8.utils.SetUtils;
-import com.google.common.collect.Sets;
-import java.util.Map;
-import java.util.Set;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.function.Consumer;
-
-public class MainDexInfo {
-
-  private static final MainDexInfo NONE = new MainDexInfo(Sets.newIdentityHashSet());
-
-  public enum MainDexGroup {
-    MAIN_DEX_LIST,
-    MAIN_DEX_ROOT,
-    MAIN_DEX_DEPENDENCY,
-    NOT_IN_MAIN_DEX
-  }
-
-  // Specific set of classes specified to be in main-dex
-  private final Set<DexType> classList;
-  private final Map<DexType, DexType> synthesizedClassesMap;
-  // Traced roots are traced main dex references.
-  private final Set<DexType> tracedRoots;
-  // Traced dependencies are those classes that are directly referenced from traced roots, but will
-  // not be loaded before loading the remaining dex files.
-  private final Set<DexType> tracedDependencies;
-
-  private MainDexInfo(Set<DexType> classList) {
-    this(
-        classList, Sets.newIdentityHashSet(), Sets.newIdentityHashSet(), Sets.newIdentityHashSet());
-  }
-
-  private MainDexInfo(
-      Set<DexType> classList,
-      Set<DexType> tracedRoots,
-      Set<DexType> tracedDependencies,
-      Set<DexType> synthesizedClasses) {
-    this.classList = classList;
-    this.tracedRoots = tracedRoots;
-    this.tracedDependencies = tracedDependencies;
-    this.synthesizedClassesMap = new ConcurrentHashMap<>();
-    synthesizedClasses.forEach(type -> synthesizedClassesMap.put(type, type));
-    assert tracedDependencies.stream().noneMatch(tracedRoots::contains);
-    assert tracedRoots.containsAll(synthesizedClasses);
-  }
-
-  public boolean isNone() {
-    assert none() == NONE;
-    return this == NONE;
-  }
-
-  public boolean isMainDex(ProgramDefinition definition) {
-    return isFromList(definition) || isTracedRoot(definition) || isDependency(definition);
-  }
-
-  public boolean isFromList(ProgramDefinition definition) {
-    return isFromList(definition.getContextType());
-  }
-
-  private boolean isFromList(DexReference reference) {
-    return classList.contains(reference.getContextType());
-  }
-
-  public boolean isTracedRoot(ProgramDefinition definition) {
-    return isTracedRoot(definition.getContextType());
-  }
-
-  private boolean isTracedRoot(DexReference reference) {
-    return tracedRoots.contains(reference.getContextType());
-  }
-
-  private boolean isDependency(ProgramDefinition definition) {
-    return isDependency(definition.getContextType());
-  }
-
-  private boolean isDependency(DexReference reference) {
-    return tracedDependencies.contains(reference.getContextType());
-  }
-
-  public boolean canRebindReference(ProgramMethod context, DexReference referenceToTarget) {
-    MainDexGroup holderGroup = getMainDexGroupInternal(context);
-    if (holderGroup == MainDexGroup.NOT_IN_MAIN_DEX
-        || holderGroup == MainDexGroup.MAIN_DEX_DEPENDENCY) {
-      // We are always free to rebind/inline into something not in main-dex or traced dependencies.
-      return true;
-    }
-    if (holderGroup == MainDexGroup.MAIN_DEX_LIST) {
-      // If the holder is in the class list, we are not allowed to make any assumptions on the
-      // holder
-      // being a root or a dependency. Therefore we cannot merge.
-      return false;
-    }
-    assert holderGroup == MAIN_DEX_ROOT;
-    // Otherwise we allow if either is both root.
-    return getMainDexGroupInternal(referenceToTarget) == MAIN_DEX_ROOT;
-  }
-
-  public boolean canMerge(ProgramDefinition candidate) {
-    return !isFromList(candidate);
-  }
-
-  public boolean canMerge(ProgramDefinition source, ProgramDefinition target) {
-    return canMerge(source.getContextType(), target.getContextType());
-  }
-
-  private boolean canMerge(DexReference source, DexReference target) {
-    MainDexGroup sourceGroup = getMainDexGroupInternal(source);
-    MainDexGroup targetGroup = getMainDexGroupInternal(target);
-    if (sourceGroup != targetGroup) {
-      return false;
-    }
-    // If the holder is in the class list, we are not allowed to make any assumptions on the holder
-    // being a root or a dependency. Therefore we cannot merge.
-    return sourceGroup != MainDexGroup.MAIN_DEX_LIST;
-  }
-
-  public MainDexGroup getMergeKey(ProgramDefinition mergeCandidate) {
-    MainDexGroup mainDexGroupInternal = getMainDexGroupInternal(mergeCandidate);
-    return mainDexGroupInternal == MainDexGroup.MAIN_DEX_LIST ? null : mainDexGroupInternal;
-  }
-
-  private MainDexGroup getMainDexGroupInternal(ProgramDefinition definition) {
-    return getMainDexGroupInternal(definition.getReference());
-  }
-
-  private MainDexGroup getMainDexGroupInternal(DexReference reference) {
-    if (isFromList(reference)) {
-      return MainDexGroup.MAIN_DEX_LIST;
-    }
-    if (isTracedRoot(reference)) {
-      return MAIN_DEX_ROOT;
-    }
-    if (isDependency(reference)) {
-      return MainDexGroup.MAIN_DEX_DEPENDENCY;
-    }
-    return MainDexGroup.NOT_IN_MAIN_DEX;
-  }
-
-  public boolean disallowInliningIntoContext(
-      AppInfoWithClassHierarchy appInfo, ProgramDefinition context, ProgramMethod method) {
-    if (context.getContextType() == method.getContextType()) {
-      return false;
-    }
-    MainDexGroup mainDexGroupInternal = getMainDexGroupInternal(context);
-    if (mainDexGroupInternal == MainDexGroup.NOT_IN_MAIN_DEX
-        || mainDexGroupInternal == MainDexGroup.MAIN_DEX_DEPENDENCY) {
-      return false;
-    }
-    if (mainDexGroupInternal == MainDexGroup.MAIN_DEX_LIST) {
-      return MainDexDirectReferenceTracer.hasReferencesOutsideMainDexClasses(
-          appInfo, method, not(this::isFromList));
-    }
-    assert mainDexGroupInternal == MAIN_DEX_ROOT;
-    return MainDexDirectReferenceTracer.hasReferencesOutsideMainDexClasses(
-        appInfo, method, not(this::isTracedRoot));
-  }
-
-  public boolean isEmpty() {
-    assert !tracedRoots.isEmpty() || tracedDependencies.isEmpty();
-    return tracedRoots.isEmpty() && classList.isEmpty();
-  }
-
-  public static MainDexInfo none() {
-    return NONE;
-  }
-
-  // TODO(b/178127572): This mutates the MainDexClasses which otherwise should be immutable.
-  public void addSyntheticClass(DexProgramClass clazz) {
-    // TODO(b/178127572): This will add a synthesized type as long as the initial set is not empty.
-    //  A better approach would be to use the context for the synthetic with a containment check.
-    assert !isNone();
-    if (!classList.isEmpty()) {
-      synthesizedClassesMap.computeIfAbsent(
-          clazz.type,
-          type -> {
-            classList.add(type);
-            return type;
-          });
-    }
-    if (!tracedRoots.isEmpty()) {
-      synthesizedClassesMap.computeIfAbsent(
-          clazz.type,
-          type -> {
-            classList.add(type);
-            return type;
-          });
-    }
-  }
-
-  public void addLegacySyntheticClass(DexProgramClass clazz, ProgramDefinition context) {
-    if (isTracedRoot(context) || isFromList(context) || isDependency(context)) {
-      addSyntheticClass(clazz);
-    }
-  }
-
-  public int size() {
-    return classList.size() + tracedRoots.size() + tracedDependencies.size();
-  }
-
-  public void forEachExcludingDependencies(Consumer<DexType> fn) {
-    // Prevent seeing duplicates in the list and roots.
-    Set<DexType> seen = Sets.newIdentityHashSet();
-    classList.forEach(ConsumerUtils.acceptIfNotSeen(fn, seen));
-    tracedRoots.forEach(ConsumerUtils.acceptIfNotSeen(fn, seen));
-  }
-
-  public void forEach(Consumer<DexType> fn) {
-    // Prevent seeing duplicates in the list and roots.
-    Set<DexType> seen = Sets.newIdentityHashSet();
-    classList.forEach(ConsumerUtils.acceptIfNotSeen(fn, seen));
-    tracedRoots.forEach(ConsumerUtils.acceptIfNotSeen(fn, seen));
-    tracedDependencies.forEach(ConsumerUtils.acceptIfNotSeen(fn, seen));
-  }
-
-  public MainDexInfo withoutPrunedItems(PrunedItems prunedItems) {
-    if (prunedItems.isEmpty()) {
-      return this;
-    }
-    Set<DexType> removedClasses = prunedItems.getRemovedClasses();
-    MainDexInfo.Builder builder = builder();
-    Set<DexType> modifiedClassList = Sets.newIdentityHashSet();
-    Set<DexType> modifiedSynthesized = Sets.newIdentityHashSet();
-    classList.forEach(type -> ifNotRemoved(type, removedClasses, modifiedClassList::add));
-    synthesizedClassesMap
-        .keySet()
-        .forEach(type -> ifNotRemoved(type, removedClasses, modifiedSynthesized::add));
-    tracedRoots.forEach(type -> ifNotRemoved(type, removedClasses, builder::addRoot));
-    tracedDependencies.forEach(type -> ifNotRemoved(type, removedClasses, builder::addDependency));
-    return builder.build(modifiedClassList, modifiedSynthesized);
-  }
-
-  private void ifNotRemoved(
-      DexType type, Set<DexType> removedClasses, Consumer<DexType> notRemoved) {
-    if (!removedClasses.contains(type)) {
-      notRemoved.accept(type);
-    }
-  }
-
-  public MainDexInfo rewrittenWithLens(GraphLens lens) {
-    MainDexInfo.Builder builder = builder();
-    Set<DexType> modifiedClassList = Sets.newIdentityHashSet();
-    Set<DexType> modifiedSynthesized = Sets.newIdentityHashSet();
-    classList.forEach(type -> modifiedClassList.add(lens.lookupType(type)));
-    synthesizedClassesMap.keySet().forEach(type -> modifiedSynthesized.add(lens.lookupType(type)));
-    tracedRoots.forEach(type -> builder.addRoot(lens.lookupType(type)));
-    tracedDependencies.forEach(type -> builder.addDependency(lens.lookupType(type)));
-    return builder.build(modifiedClassList, modifiedSynthesized);
-  }
-
-  public static Builder builder() {
-    return new Builder();
-  }
-
-  public static MainDexInfo createEmptyMainDexClasses() {
-    return new MainDexInfo(
-        Sets.newIdentityHashSet(),
-        Sets.newIdentityHashSet(),
-        Sets.newIdentityHashSet(),
-        Sets.newIdentityHashSet());
-  }
-
-  public static class Builder {
-
-    private final Set<DexType> list = Sets.newIdentityHashSet();
-    private final Set<DexType> roots = Sets.newIdentityHashSet();
-    private final Set<DexType> dependencies = Sets.newIdentityHashSet();
-
-    private Builder() {}
-
-    public void addList(DexProgramClass clazz) {
-      list.add(clazz.getType());
-    }
-
-    public void addRoot(DexProgramClass clazz) {
-      roots.add(clazz.getType());
-    }
-
-    public void addRoot(DexType type) {
-      roots.add(type);
-    }
-
-    public void addDependency(DexProgramClass clazz) {
-      addDependency(clazz.getType());
-    }
-
-    public void addDependency(DexType type) {
-      assert !roots.contains(type);
-      dependencies.add(type);
-    }
-
-    public boolean isTracedRoot(DexProgramClass clazz) {
-      return isTracedRoot(clazz.getType());
-    }
-
-    public boolean isTracedRoot(DexType type) {
-      return roots.contains(type);
-    }
-
-    public boolean isDependency(DexProgramClass clazz) {
-      return isDependency(clazz.getType());
-    }
-
-    public boolean isDependency(DexType type) {
-      return dependencies.contains(type);
-    }
-
-    public boolean contains(DexProgramClass clazz) {
-      return contains(clazz.type);
-    }
-
-    public boolean contains(DexType type) {
-      return isTracedRoot(type) || isDependency(type);
-    }
-
-    public Set<DexType> getRoots() {
-      return roots;
-    }
-
-    public MainDexInfo buildList() {
-      // When building without passing the list, the method roots and dependencies should
-      // be empty since no tracing has been done.
-      assert dependencies.isEmpty();
-      assert roots.isEmpty();
-      return new MainDexInfo(list);
-    }
-
-    public MainDexInfo build(Set<DexType> classList, Set<DexType> synthesizedClasses) {
-      // Class can contain dependencies which we should not regard as roots.
-      assert list.isEmpty();
-      return new MainDexInfo(
-          classList,
-          SetUtils.unionIdentityHashSet(roots, synthesizedClasses),
-          Sets.difference(dependencies, synthesizedClasses),
-          synthesizedClasses);
-    }
-
-    public MainDexInfo build(MainDexInfo previous) {
-      return build(previous.classList, previous.synthesizedClassesMap.keySet());
-    }
-  }
-}
diff --git a/src/main/java/com/android/tools/r8/shaking/MainDexListBuilder.java b/src/main/java/com/android/tools/r8/shaking/MainDexListBuilder.java
index afb39bf..4fa5d3a 100644
--- a/src/main/java/com/android/tools/r8/shaking/MainDexListBuilder.java
+++ b/src/main/java/com/android/tools/r8/shaking/MainDexListBuilder.java
@@ -13,6 +13,7 @@
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexProto;
 import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.utils.SetUtils;
 import java.util.IdentityHashMap;
 import java.util.Map;
 import java.util.Set;
@@ -29,7 +30,7 @@
   private final Set<DexType> roots;
   private final AppView<? extends AppInfoWithClassHierarchy> appView;
   private final Map<DexType, Boolean> annotationTypeContainEnum;
-  private final MainDexInfo.Builder mainDexInfoBuilder;
+  private final MainDexTracingResult.Builder mainDexClassesBuilder;
 
   public static void checkForAssumedLibraryTypes(AppInfo appInfo) {
     DexClass enumType = appInfo.definitionFor(appInfo.dexItemFactory().enumType);
@@ -49,14 +50,11 @@
    * @param appView the dex appplication.
    */
   public MainDexListBuilder(
-      AppView<? extends AppInfoWithClassHierarchy> appView,
-      Set<DexType> roots,
-      MainDexInfo.Builder mainDexInfoBuilder) {
+      Set<DexProgramClass> roots, AppView<? extends AppInfoWithClassHierarchy> appView) {
     this.appView = appView;
     // Only consider program classes for the root set.
-    assert roots.stream().allMatch(type -> appView.definitionFor(type).isProgramClass());
-    this.roots = roots;
-    this.mainDexInfoBuilder = mainDexInfoBuilder;
+    this.roots = SetUtils.mapIdentityHashSet(roots, DexProgramClass::getType);
+    mainDexClassesBuilder = MainDexTracingResult.builder(appView.appInfo()).addRoots(this.roots);
     annotationTypeContainEnum = new IdentityHashMap<>();
   }
 
@@ -64,17 +62,18 @@
     return appView.appInfo();
   }
 
-  public void run() {
+  public MainDexTracingResult run() {
     traceMainDexDirectDependencies();
     traceRuntimeAnnotationsWithEnumForMainDex();
+    return mainDexClassesBuilder.build();
   }
 
   private void traceRuntimeAnnotationsWithEnumForMainDex() {
     for (DexProgramClass clazz : appInfo().classes()) {
-      if (mainDexInfoBuilder.contains(clazz)) {
+      DexType dexType = clazz.type;
+      if (mainDexClassesBuilder.contains(dexType)) {
         continue;
       }
-      DexType dexType = clazz.type;
       if (isAnnotation(dexType) && isAnnotationWithEnum(dexType)) {
         addAnnotationsWithEnum(clazz);
         continue;
@@ -83,11 +82,10 @@
       // annotations with enums goes into the main dex, move annotated classes there as well.
       clazz.forEachAnnotation(
           annotation -> {
-            if (!mainDexInfoBuilder.contains(clazz)
+            if (!mainDexClassesBuilder.contains(dexType)
                 && annotation.visibility == DexAnnotation.VISIBILITY_RUNTIME
                 && isAnnotationWithEnum(annotation.annotation.type)) {
-              // Just add classes annotated with annotations with enum as direct dependencies.
-              mainDexInfoBuilder.addDependency(clazz);
+              addClassAnnotatedWithAnnotationWithEnum(dexType);
             }
           });
     }
@@ -153,10 +151,16 @@
     }
   }
 
+  private void addClassAnnotatedWithAnnotationWithEnum(DexType type) {
+    // Just add classes annotated with annotations with enum ad direct dependencies.
+    addDirectDependency(type);
+  }
+
   private void addDirectDependency(DexType type) {
     // Consider only component type of arrays
     type = type.toBaseType(appView.dexItemFactory());
-    if (!type.isClassType() || mainDexInfoBuilder.contains(type)) {
+
+    if (!type.isClassType() || mainDexClassesBuilder.contains(type)) {
       return;
     }
 
@@ -169,8 +173,9 @@
   }
 
   private void addDirectDependency(DexProgramClass dexClass) {
-    assert !mainDexInfoBuilder.contains(dexClass);
-    mainDexInfoBuilder.addDependency(dexClass);
+    DexType type = dexClass.type;
+    assert !mainDexClassesBuilder.contains(type);
+    mainDexClassesBuilder.addDependency(type);
     if (dexClass.superType != null) {
       addDirectDependency(dexClass.superType);
     }
diff --git a/src/main/java/com/android/tools/r8/shaking/MainDexTracingResult.java b/src/main/java/com/android/tools/r8/shaking/MainDexTracingResult.java
new file mode 100644
index 0000000..5bf287f
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/shaking/MainDexTracingResult.java
@@ -0,0 +1,172 @@
+// Copyright (c) 2018, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.shaking;
+
+import com.android.tools.r8.graph.AppInfo;
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.ProgramDefinition;
+import com.android.tools.r8.utils.SetUtils;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Sets;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Set;
+import java.util.function.Consumer;
+import java.util.function.Predicate;
+
+public class MainDexTracingResult {
+
+  public static MainDexTracingResult NONE =
+      new MainDexTracingResult(ImmutableSet.of(), ImmutableSet.of());
+
+  public static class Builder {
+    public final AppInfo appInfo;
+    public final Set<DexType> roots;
+    public final Set<DexType> dependencies;
+
+    private Builder(AppInfo appInfo) {
+      this(appInfo, Sets.newIdentityHashSet(), Sets.newIdentityHashSet());
+    }
+
+    private Builder(AppInfo appInfo, MainDexTracingResult mainDexTracingResult) {
+      this(
+          appInfo,
+          SetUtils.newIdentityHashSet(mainDexTracingResult.getRoots()),
+          SetUtils.newIdentityHashSet(mainDexTracingResult.getDependencies()));
+    }
+
+    private Builder(AppInfo appInfo, Set<DexType> roots, Set<DexType> dependencies) {
+      this.appInfo = appInfo;
+      this.roots = roots;
+      this.dependencies = dependencies;
+    }
+
+    public Builder addRoot(DexProgramClass clazz) {
+      roots.add(clazz.getType());
+      return this;
+    }
+
+    public Builder addRoot(DexType type) {
+      assert isProgramClass(type) : type.toSourceString();
+      roots.add(type);
+      return this;
+    }
+
+    public Builder addRoots(Collection<DexType> rootSet) {
+      assert rootSet.stream().allMatch(this::isProgramClass);
+      this.roots.addAll(rootSet);
+      return this;
+    }
+
+    public Builder addDependency(DexType type) {
+      assert isProgramClass(type);
+      dependencies.add(type);
+      return this;
+    }
+
+    public boolean contains(DexType type) {
+      return roots.contains(type) || dependencies.contains(type);
+    }
+
+    public MainDexTracingResult build() {
+      return new MainDexTracingResult(roots, dependencies);
+    }
+
+    private boolean isProgramClass(DexType dexType) {
+      DexClass clazz = appInfo.definitionFor(dexType);
+      return clazz != null && clazz.isProgramClass();
+    }
+  }
+
+  // The classes in the root set.
+  private final Set<DexType> roots;
+  // Additional dependencies (direct dependencies and runtime annotations with enums).
+  private final Set<DexType> dependencies;
+  // All main dex classes.
+  private final Set<DexType> classes;
+
+  private MainDexTracingResult(Set<DexType> roots, Set<DexType> dependencies) {
+    assert Sets.intersection(roots, dependencies).isEmpty();
+    this.roots = Collections.unmodifiableSet(roots);
+    this.dependencies = Collections.unmodifiableSet(dependencies);
+    this.classes = Sets.union(roots, dependencies);
+  }
+
+  public boolean canReferenceItemFromContextWithoutIncreasingMainDexSize(
+      ProgramDefinition item, ProgramDefinition context) {
+    // If the context is not a root, then additional references from inside the context will not
+    // increase the size of the main dex.
+    if (!isRoot(context)) {
+      return true;
+    }
+    // Otherwise, require that the item is a root itself.
+    return isRoot(item);
+  }
+
+  public boolean isEmpty() {
+    assert !roots.isEmpty() || dependencies.isEmpty();
+    return roots.isEmpty();
+  }
+
+  public Set<DexType> getRoots() {
+    return roots;
+  }
+
+  public Set<DexType> getDependencies() {
+    return dependencies;
+  }
+
+  public Set<DexType> getClasses() {
+    return classes;
+  }
+
+  public boolean contains(ProgramDefinition clazz) {
+    return contains(clazz.getContextType());
+  }
+
+  public boolean contains(DexType type) {
+    return getClasses().contains(type);
+  }
+
+  private void collectTypesMatching(
+      Set<DexType> types, Predicate<DexType> predicate, Consumer<DexType> consumer) {
+    types.forEach(
+        type -> {
+          if (predicate.test(type)) {
+            consumer.accept(type);
+          }
+        });
+  }
+
+  public boolean isRoot(ProgramDefinition definition) {
+    return getRoots().contains(definition.getContextType());
+  }
+
+  public boolean isRoot(DexType type) {
+    return getRoots().contains(type);
+  }
+
+  public boolean isDependency(ProgramDefinition definition) {
+    return getDependencies().contains(definition.getContextType());
+  }
+
+  public MainDexTracingResult prunedCopy(AppInfoWithLiveness appInfo) {
+    Builder builder = builder(appInfo);
+    Predicate<DexType> wasPruned = appInfo::wasPruned;
+    collectTypesMatching(roots, wasPruned.negate(), builder::addRoot);
+    collectTypesMatching(dependencies, wasPruned.negate(), builder::addDependency);
+    return builder.build();
+  }
+
+  public static Builder builder(AppInfo appInfo) {
+    return new Builder(appInfo);
+  }
+
+  public Builder extensionBuilder(AppInfo appInfo) {
+    return new Builder(appInfo, this);
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
index 4963f21..c33cf70 100644
--- a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
@@ -219,22 +219,23 @@
   // All the bridge methods that have been synthesized during vertical class merging.
   private final List<SynthesizedBridgeCode> synthesizedBridges = new ArrayList<>();
 
-  private final MainDexInfo mainDexInfo;
+  private final MainDexTracingResult mainDexClasses;
 
   public VerticalClassMerger(
       DexApplication application,
       AppView<AppInfoWithLiveness> appView,
       ExecutorService executorService,
-      Timing timing) {
+      Timing timing,
+      MainDexTracingResult mainDexClasses) {
     this.application = application;
     this.appInfo = appView.appInfo();
     this.appView = appView;
-    this.mainDexInfo = appInfo.getMainDexInfo();
     this.subtypingInfo = appInfo.computeSubtypingInfo();
     this.executorService = executorService;
     this.methodPoolCollection = new MethodPoolCollection(appView, subtypingInfo);
     this.lensBuilder = new VerticalClassMergerGraphLens.Builder(appView.dexItemFactory());
     this.timing = timing;
+    this.mainDexClasses = mainDexClasses;
 
     Iterable<DexProgramClass> classes = application.classesWithDeterministicOrder();
     initializePinnedTypes(classes); // Must be initialized prior to mergeCandidates.
@@ -825,8 +826,20 @@
       return;
     }
 
-    // Check with main dex classes to see if we are allowed to merge.
-    if (!mainDexInfo.canMerge(clazz, targetClass)) {
+    // For a main dex class in the dependent set only merge with other classes in either main dex
+    // set.
+    if ((mainDexClasses.getDependencies().contains(clazz.type)
+            || mainDexClasses.getDependencies().contains(targetClass.type))
+        && !(mainDexClasses.getClasses().contains(clazz.type)
+            && mainDexClasses.getClasses().contains(targetClass.type))) {
+      return;
+    }
+
+    // For a main dex class in the root set only merge with other classes in main dex root set.
+    if ((mainDexClasses.getRoots().contains(clazz.type)
+            || mainDexClasses.getRoots().contains(targetClass.type))
+        && !(mainDexClasses.getRoots().contains(clazz.type)
+            && mainDexClasses.getRoots().contains(targetClass.type))) {
       return;
     }
 
@@ -1657,7 +1670,9 @@
         }
         // Constructors can have references beyond the root main dex classes. This can increase the
         // size of the main dex dependent classes and we should bail out.
-        if (mainDexInfo.disallowInliningIntoContext(appView.appInfo(), context, method)) {
+        if (mainDexClasses.getRoots().contains(context.type)
+            && MainDexDirectReferenceTracer.hasReferencesOutsideFromCode(
+                appView.appInfo(), method, mainDexClasses.getRoots())) {
           return AbortReason.MAIN_DEX_ROOT_OUTSIDE_REFERENCE;
         }
         return null;
diff --git a/src/main/java/com/android/tools/r8/synthesis/SynthesizingContext.java b/src/main/java/com/android/tools/r8/synthesis/SynthesizingContext.java
index 1067e93..f6da35b 100644
--- a/src/main/java/com/android/tools/r8/synthesis/SynthesizingContext.java
+++ b/src/main/java/com/android/tools/r8/synthesis/SynthesizingContext.java
@@ -14,7 +14,7 @@
 import com.android.tools.r8.graph.GraphLens.NonIdentityGraphLens;
 import com.android.tools.r8.graph.ProgramDefinition;
 import com.android.tools.r8.origin.Origin;
-import com.android.tools.r8.shaking.MainDexInfo;
+import com.android.tools.r8.shaking.MainDexClasses;
 import com.android.tools.r8.synthesis.SyntheticNaming.Phase;
 import com.android.tools.r8.synthesis.SyntheticNaming.SyntheticKind;
 import java.util.Comparator;
@@ -155,13 +155,13 @@
 
   void addIfDerivedFromMainDexClass(
       DexProgramClass externalSyntheticClass,
-      MainDexInfo mainDexInfo,
+      MainDexClasses mainDexClasses,
       Set<DexType> allMainDexTypes) {
     // The input context type (not the annotated context) determines if the derived class is to be
     // in main dex.
     // TODO(b/168584485): Once resolved allMainDexTypes == mainDexClasses.
     if (allMainDexTypes.contains(inputContextType)) {
-      mainDexInfo.addSyntheticClass(externalSyntheticClass);
+      mainDexClasses.add(externalSyntheticClass);
     }
   }
 
diff --git a/src/main/java/com/android/tools/r8/synthesis/SyntheticFinalization.java b/src/main/java/com/android/tools/r8/synthesis/SyntheticFinalization.java
index 4c7432b..2d5ec60 100644
--- a/src/main/java/com/android/tools/r8/synthesis/SyntheticFinalization.java
+++ b/src/main/java/com/android/tools/r8/synthesis/SyntheticFinalization.java
@@ -23,7 +23,7 @@
 import com.android.tools.r8.graph.TreeFixerBase;
 import com.android.tools.r8.ir.code.NumberGenerator;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
-import com.android.tools.r8.shaking.MainDexInfo;
+import com.android.tools.r8.shaking.MainDexClasses;
 import com.android.tools.r8.synthesis.SyntheticNaming.SyntheticKind;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.SetUtils;
@@ -225,7 +225,7 @@
     assert !appView.appInfo().hasClassHierarchy();
     assert !appView.appInfo().hasLiveness();
     Result result = appView.getSyntheticItems().computeFinalSynthetics(appView);
-    appView.setAppInfo(new AppInfo(result.commit, appView.appInfo().getMainDexInfo()));
+    appView.setAppInfo(new AppInfo(result.commit, appView.appInfo().getMainDexClasses()));
     appView.pruneItems(result.prunedItems);
     if (result.lens != null) {
       appView.setGraphLens(result.lens);
@@ -252,7 +252,7 @@
   Result computeFinalSynthetics(AppView<?> appView) {
     assert verifyNoNestedSynthetics();
     DexApplication application;
-    MainDexInfo mainDexInfo = appView.appInfo().getMainDexInfo();
+    MainDexClasses mainDexClasses = appView.appInfo().getMainDexClasses();
     Builder lensBuilder = new Builder();
     ImmutableMap.Builder<DexType, SyntheticMethodReference> finalMethodsBuilder =
         ImmutableMap.builder();
@@ -266,6 +266,7 @@
               appView,
               computeEquivalences(appView, synthetics.getNonLegacyMethods().values(), generators),
               computeEquivalences(appView, synthetics.getNonLegacyClasses().values(), generators),
+              mainDexClasses,
               lensBuilder,
               (clazz, reference) -> {
                 finalSyntheticProgramDefinitions.add(clazz);
@@ -281,9 +282,13 @@
         finalClassesBuilder.build();
 
     handleSynthesizedClassMapping(
-        finalSyntheticProgramDefinitions, application, options, mainDexInfo, lensBuilder.typeMap);
+        finalSyntheticProgramDefinitions,
+        application,
+        options,
+        mainDexClasses,
+        lensBuilder.typeMap);
 
-    assert appView.appInfo().getMainDexInfo() == mainDexInfo;
+    assert appView.appInfo().getMainDexClasses() == mainDexClasses;
 
     Set<DexType> prunedSynthetics = Sets.newIdentityHashSet();
     synthetics.forEachNonLegacyItem(
@@ -347,13 +352,14 @@
       List<DexProgramClass> finalSyntheticClasses,
       DexApplication application,
       InternalOptions options,
-      MainDexInfo mainDexInfo,
+      MainDexClasses mainDexClasses,
       Map<DexType, DexType> derivedMainDexTypesToIgnore) {
     boolean includeSynthesizedClassMappingInOutput = shouldAnnotateSynthetics(options);
     if (includeSynthesizedClassMappingInOutput) {
       updateSynthesizedClassMapping(application, finalSyntheticClasses);
     }
-    updateMainDexListWithSynthesizedClassMap(application, mainDexInfo, derivedMainDexTypesToIgnore);
+    updateMainDexListWithSynthesizedClassMap(
+        application, mainDexClasses, derivedMainDexTypesToIgnore);
     if (!includeSynthesizedClassMappingInOutput) {
       clearSynthesizedClassMapping(application);
     }
@@ -399,13 +405,13 @@
 
   private void updateMainDexListWithSynthesizedClassMap(
       DexApplication application,
-      MainDexInfo mainDexInfo,
+      MainDexClasses mainDexClasses,
       Map<DexType, DexType> derivedMainDexTypesToIgnore) {
-    if (mainDexInfo.isEmpty()) {
+    if (mainDexClasses.isEmpty()) {
       return;
     }
     List<DexProgramClass> newMainDexClasses = new ArrayList<>();
-    mainDexInfo.forEachExcludingDependencies(
+    mainDexClasses.forEach(
         dexType -> {
           DexProgramClass programClass =
               DexProgramClass.asProgramClassOrNull(application.definitionFor(dexType));
@@ -423,7 +429,7 @@
             }
           }
         });
-    newMainDexClasses.forEach(mainDexInfo::addSyntheticClass);
+    mainDexClasses.addAll(newMainDexClasses);
   }
 
   private void clearSynthesizedClassMapping(DexApplication application) {
@@ -437,6 +443,7 @@
       AppView<?> appView,
       Map<DexType, EquivalenceGroup<SyntheticMethodDefinition>> syntheticMethodGroups,
       Map<DexType, EquivalenceGroup<SyntheticProgramClassDefinition>> syntheticClassGroups,
+      MainDexClasses mainDexClasses,
       Builder lensBuilder,
       BiConsumer<DexProgramClass, SyntheticProgramClassReference> addFinalSyntheticClass,
       BiConsumer<DexProgramClass, SyntheticMethodReference> addFinalSyntheticMethod) {
@@ -446,8 +453,7 @@
 
     // TODO(b/168584485): Remove this once class-mapping support is removed.
     Set<DexType> derivedMainDexTypes = Sets.newIdentityHashSet();
-    MainDexInfo mainDexInfo = appView.appInfo().getMainDexInfo();
-    mainDexInfo.forEachExcludingDependencies(
+    mainDexClasses.forEach(
         mainDexType -> {
           derivedMainDexTypes.add(mainDexType);
           DexProgramClass mainDexClass =
@@ -560,7 +566,7 @@
             addMainDexAndSynthesizedFromForMember(
                 member,
                 externalSyntheticClass,
-                mainDexInfo,
+                mainDexClasses,
                 derivedMainDexTypes,
                 appForLookup::programDefinitionFor);
           }
@@ -582,7 +588,7 @@
             addMainDexAndSynthesizedFromForMember(
                 member,
                 externalSyntheticClass,
-                mainDexInfo,
+                mainDexClasses,
                 derivedMainDexTypes,
                 appForLookup::programDefinitionFor);
           }
@@ -632,12 +638,12 @@
   private static void addMainDexAndSynthesizedFromForMember(
       SyntheticDefinition<?, ?, ?> member,
       DexProgramClass externalSyntheticClass,
-      MainDexInfo mainDexInfo,
+      MainDexClasses mainDexClasses,
       Set<DexType> derivedMainDexTypes,
       Function<DexType, DexProgramClass> definitions) {
     member
         .getContext()
-        .addIfDerivedFromMainDexClass(externalSyntheticClass, mainDexInfo, derivedMainDexTypes);
+        .addIfDerivedFromMainDexClass(externalSyntheticClass, mainDexClasses, derivedMainDexTypes);
     // TODO(b/168584485): Remove this once class-mapping support is removed.
     DexProgramClass from = definitions.apply(member.getContext().getSynthesizingContextType());
     if (from != null) {
diff --git a/src/main/java/com/android/tools/r8/synthesis/SyntheticItems.java b/src/main/java/com/android/tools/r8/synthesis/SyntheticItems.java
index 14f9cf8..c40dbee 100644
--- a/src/main/java/com/android/tools/r8/synthesis/SyntheticItems.java
+++ b/src/main/java/com/android/tools/r8/synthesis/SyntheticItems.java
@@ -134,7 +134,7 @@
     CommittedItems commit =
         new CommittedItems(
             synthetics.nextSyntheticId, appView.appInfo().app(), committed, ImmutableList.of());
-    appView.setAppInfo(new AppInfo(commit, appView.appInfo().getMainDexInfo()));
+    appView.setAppInfo(new AppInfo(commit, appView.appInfo().getMainDexClasses()));
   }
 
   // Internal synthetic id creation helpers.
diff --git a/src/main/java/com/android/tools/r8/tracereferences/Tracer.java b/src/main/java/com/android/tools/r8/tracereferences/Tracer.java
index 2bac333..286c83d 100644
--- a/src/main/java/com/android/tools/r8/tracereferences/Tracer.java
+++ b/src/main/java/com/android/tools/r8/tracereferences/Tracer.java
@@ -31,7 +31,7 @@
 import com.android.tools.r8.references.FieldReference;
 import com.android.tools.r8.references.MethodReference;
 import com.android.tools.r8.references.Reference;
-import com.android.tools.r8.shaking.MainDexInfo;
+import com.android.tools.r8.shaking.MainDexClasses;
 import com.android.tools.r8.tracereferences.TraceReferencesConsumer.AccessFlags;
 import com.android.tools.r8.tracereferences.TraceReferencesConsumer.ClassAccessFlags;
 import com.android.tools.r8.tracereferences.TraceReferencesConsumer.FieldAccessFlags;
@@ -236,7 +236,7 @@
         AppInfoWithClassHierarchy.createInitialAppInfoWithClassHierarchy(
             application,
             ClassToFeatureSplitMap.createEmptyClassToFeatureSplitMap(),
-            MainDexInfo.createEmptyMainDexClasses());
+            MainDexClasses.createEmptyMainDexClasses());
   }
 
   void run(TraceReferencesConsumer consumer) {
diff --git a/src/main/java/com/android/tools/r8/utils/SetUtils.java b/src/main/java/com/android/tools/r8/utils/SetUtils.java
index a278196..022f5d9 100644
--- a/src/main/java/com/android/tools/r8/utils/SetUtils.java
+++ b/src/main/java/com/android/tools/r8/utils/SetUtils.java
@@ -62,11 +62,4 @@
     }
     return out;
   }
-
-  public static <T> Set<T> unionIdentityHashSet(Set<T> one, Set<T> other) {
-    Set<T> union = Sets.newIdentityHashSet();
-    union.addAll(one);
-    union.addAll(other);
-    return union;
-  }
 }
diff --git a/src/test/java/com/android/tools/r8/TestBase.java b/src/test/java/com/android/tools/r8/TestBase.java
index 3df21b1..fe0aa6c 100644
--- a/src/test/java/com/android/tools/r8/TestBase.java
+++ b/src/test/java/com/android/tools/r8/TestBase.java
@@ -46,7 +46,7 @@
 import com.android.tools.r8.references.TypeReference;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.shaking.EnqueuerFactory;
-import com.android.tools.r8.shaking.MainDexInfo;
+import com.android.tools.r8.shaking.MainDexClasses;
 import com.android.tools.r8.shaking.NoHorizontalClassMergingRule;
 import com.android.tools.r8.shaking.NoVerticalClassMergingRule;
 import com.android.tools.r8.shaking.ProguardClassNameList;
@@ -710,7 +710,7 @@
     return AppInfoWithClassHierarchy.createInitialAppInfoWithClassHierarchy(
         readApplicationForDexOutput(app, new InternalOptions()),
         ClassToFeatureSplitMap.createEmptyClassToFeatureSplitMap(),
-        MainDexInfo.createEmptyMainDexClasses());
+        MainDexClasses.createEmptyMainDexClasses());
   }
 
   protected static AppView<AppInfoWithClassHierarchy> computeAppViewWithClassHierachy(
diff --git a/src/test/java/com/android/tools/r8/ir/IrInjectionTestBase.java b/src/test/java/com/android/tools/r8/ir/IrInjectionTestBase.java
index 55fcc74..1d7686c 100644
--- a/src/test/java/com/android/tools/r8/ir/IrInjectionTestBase.java
+++ b/src/test/java/com/android/tools/r8/ir/IrInjectionTestBase.java
@@ -14,6 +14,7 @@
 import com.android.tools.r8.ir.code.InstructionListIterator;
 import com.android.tools.r8.ir.code.NumberGenerator;
 import com.android.tools.r8.ir.conversion.IRConverter;
+import com.android.tools.r8.shaking.MainDexTracingResult;
 import com.android.tools.r8.smali.SmaliBuilder;
 import com.android.tools.r8.smali.SmaliBuilder.MethodSignature;
 import com.android.tools.r8.smali.SmaliTestBase;
@@ -112,7 +113,7 @@
 
     public String run() throws IOException {
       Timing timing = Timing.empty();
-      IRConverter converter = new IRConverter(appView, timing, null);
+      IRConverter converter = new IRConverter(appView, timing, null, MainDexTracingResult.NONE);
       converter.replaceCodeForTesting(method, code);
       AndroidApp app = writeDex();
       return runOnArtRaw(app, DEFAULT_MAIN_CLASS_NAME).stdout;
diff --git a/src/test/java/com/android/tools/r8/maindexlist/MainDexListFromGenerateMainDexInliningSpuriousRootTest.java b/src/test/java/com/android/tools/r8/maindexlist/MainDexListFromGenerateMainDexInliningSpuriousRootTest.java
deleted file mode 100644
index 5f21fee..0000000
--- a/src/test/java/com/android/tools/r8/maindexlist/MainDexListFromGenerateMainDexInliningSpuriousRootTest.java
+++ /dev/null
@@ -1,152 +0,0 @@
-// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
-// for details. All rights reserved. Use of this source code is governed by a
-// BSD-style license that can be found in the LICENSE file.
-package com.android.tools.r8.maindexlist;
-
-import static com.android.tools.r8.utils.codeinspector.CodeMatchers.invokesMethod;
-import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.junit.Assert.assertEquals;
-
-import com.android.tools.r8.NeverInline;
-import com.android.tools.r8.NoHorizontalClassMerging;
-import com.android.tools.r8.R8TestCompileResult;
-import com.android.tools.r8.TestBase;
-import com.android.tools.r8.TestParameters;
-import com.android.tools.r8.TestParametersCollection;
-import com.android.tools.r8.ToolHelper;
-import com.android.tools.r8.references.ClassReference;
-import com.android.tools.r8.utils.codeinspector.ClassSubject;
-import com.android.tools.r8.utils.codeinspector.CodeInspector;
-import com.android.tools.r8.utils.codeinspector.MethodSubject;
-import com.google.common.collect.ImmutableSet;
-import java.util.List;
-import org.junit.BeforeClass;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
-import org.junit.runners.Parameterized.Parameters;
-
-@RunWith(Parameterized.class)
-public class MainDexListFromGenerateMainDexInliningSpuriousRootTest extends TestBase {
-
-  private static List<ClassReference> mainDexList;
-  private final TestParameters parameters;
-
-  @Parameters(name = "{0}")
-  public static TestParametersCollection data() {
-    return getTestParameters()
-        .withDexRuntimes()
-        .withApiLevelsEndingAtExcluding(apiLevelWithNativeMultiDexSupport())
-        .build();
-  }
-
-  @BeforeClass
-  public static void setup() throws Exception {
-    mainDexList =
-        testForMainDexListGenerator(getStaticTemp())
-            .addInnerClasses(MainDexListFromGenerateMainDexInliningSpuriousRootTest.class)
-            .addLibraryFiles(ToolHelper.getMostRecentAndroidJar())
-            .addMainDexRules(
-                "-keep class " + Main.class.getTypeName() + " {",
-                "  public static void main(java.lang.String[]);",
-                "}")
-            .run()
-            .getMainDexList();
-  }
-
-  public MainDexListFromGenerateMainDexInliningSpuriousRootTest(TestParameters parameters) {
-    this.parameters = parameters;
-  }
-
-  @Test
-  public void test() throws Exception {
-    // The generated main dex list should contain Main (which is a root) and A (which is a direct
-    // dependency of Main).
-    assertEquals(2, mainDexList.size());
-    assertEquals(A.class.getTypeName(), mainDexList.get(0).getTypeName());
-    assertEquals(Main.class.getTypeName(), mainDexList.get(1).getTypeName());
-    R8TestCompileResult compileResult =
-        testForR8(parameters.getBackend())
-            .addInnerClasses(getClass())
-            .addInliningAnnotations()
-            .addKeepClassAndMembersRules(Main.class)
-            .addKeepMainRule(Main2.class)
-            .addMainDexListClassReferences(mainDexList)
-            .addMainDexRules(
-                "-keep class " + Main2.class.getTypeName() + " {",
-                "  public static void main(java.lang.String[]);",
-                "}")
-            .collectMainDexClasses()
-            .enableInliningAnnotations()
-            .enableNoHorizontalClassMergingAnnotations()
-            .setMinApi(parameters.getApiLevel())
-            .noMinification()
-            .compile();
-    CodeInspector inspector = compileResult.inspector();
-    ClassSubject mainClassSubject = inspector.clazz(Main.class);
-    assertThat(mainClassSubject, isPresent());
-    MethodSubject fooMethodSubject = mainClassSubject.uniqueMethodWithName("foo");
-    assertThat(fooMethodSubject, isPresent());
-    ClassSubject main2ClassSubject = inspector.clazz(Main2.class);
-    assertThat(main2ClassSubject, isPresent());
-    ClassSubject aClassSubject = inspector.clazz(A.class);
-    assertThat(aClassSubject, isPresent());
-    MethodSubject barMethodSubject = aClassSubject.uniqueMethodWithName("bar");
-    assertThat(barMethodSubject, isPresent());
-    ClassSubject bClassSubject = inspector.clazz(B.class);
-    assertThat(bClassSubject, isPresent());
-    MethodSubject bazMethodSubject = bClassSubject.uniqueMethodWithName("baz");
-    assertThat(bazMethodSubject, isPresent());
-    assertThat(fooMethodSubject, invokesMethod(barMethodSubject));
-    assertThat(barMethodSubject, invokesMethod(bazMethodSubject));
-    assertEquals(
-        ImmutableSet.of(
-            mainClassSubject.getFinalName(),
-            main2ClassSubject.getFinalName(),
-            aClassSubject.getFinalName()),
-        compileResult.getMainDexClasses());
-  }
-
-  static class Main {
-
-    public static void main(String[] args) {
-      System.out.println("Main.main()");
-    }
-
-    static void foo() {
-      A.bar();
-    }
-  }
-
-  static class Main2 {
-
-    public static void main(String[] args) {
-      if (getFalse()) {
-        A.bar();
-      }
-    }
-
-    static boolean getFalse() {
-      return false;
-    }
-  }
-
-  @NoHorizontalClassMerging
-  static class A {
-    // Must not be inlined into Main.foo(), since that would cause B to become direct dependence of
-    // Main without ending up in the main dex.
-    static void bar() {
-      B.baz();
-    }
-  }
-
-  @NoHorizontalClassMerging
-  static class B {
-
-    @NeverInline
-    static void baz() {
-      System.out.println("B.baz");
-    }
-  }
-}
diff --git a/src/test/java/com/android/tools/r8/maindexlist/MainDexListFromGenerateMainDexInliningTest.java b/src/test/java/com/android/tools/r8/maindexlist/MainDexListFromGenerateMainDexInliningTest.java
index 98ec426..a76d8cc 100644
--- a/src/test/java/com/android/tools/r8/maindexlist/MainDexListFromGenerateMainDexInliningTest.java
+++ b/src/test/java/com/android/tools/r8/maindexlist/MainDexListFromGenerateMainDexInliningTest.java
@@ -5,6 +5,7 @@
 package com.android.tools.r8.maindexlist;
 
 import static com.android.tools.r8.utils.codeinspector.CodeMatchers.invokesMethod;
+import static com.android.tools.r8.utils.codeinspector.Matchers.isAbsent;
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
@@ -20,7 +21,6 @@
 import com.android.tools.r8.utils.codeinspector.ClassSubject;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
 import com.android.tools.r8.utils.codeinspector.MethodSubject;
-import com.google.common.collect.ImmutableSet;
 import java.util.List;
 import org.junit.BeforeClass;
 import org.junit.Test;
@@ -90,10 +90,12 @@
     assertThat(fooMethodSubject, isPresent());
 
     ClassSubject aClassSubject = inspector.clazz(A.class);
-    assertThat(aClassSubject, isPresent());
+    // TODO(b/178353726): Should be present, but was inlined.
+    assertThat(aClassSubject, isAbsent());
 
     MethodSubject barMethodSubject = aClassSubject.uniqueMethodWithName("bar");
-    assertThat(barMethodSubject, isPresent());
+    // TODO(b/178353726): Should be present, but was inlined.
+    assertThat(barMethodSubject, isAbsent());
 
     ClassSubject bClassSubject = inspector.clazz(B.class);
     assertThat(bClassSubject, isPresent());
@@ -101,13 +103,16 @@
     MethodSubject bazMethodSubject = bClassSubject.uniqueMethodWithName("baz");
     assertThat(bazMethodSubject, isPresent());
 
-    assertThat(fooMethodSubject, invokesMethod(barMethodSubject));
-    assertThat(barMethodSubject, invokesMethod(bazMethodSubject));
+    // TODO(b/178353726): foo() should invoke bar() and bar() should invoke baz().
+    assertThat(fooMethodSubject, invokesMethod(bazMethodSubject));
 
-    // The main dex classes should be the same as the input main dex list.
-    assertEquals(
-        ImmutableSet.of(mainClassSubject.getFinalName(), aClassSubject.getFinalName()),
-        compileResult.getMainDexClasses());
+    // TODO(b/178353726): Main is the only class guaranteed to be in the main dex, but it has a
+    //  direct reference to B.
+    compileResult.inspectMainDexClasses(
+        mainDexClasses -> {
+          assertEquals(1, mainDexClasses.size());
+          assertEquals(mainClassSubject.getFinalName(), mainDexClasses.iterator().next());
+        });
   }
 
   static class Main {
@@ -117,8 +122,8 @@
     }
 
     static void foo() {
-      // Should not allow inlining bar into foo(), since that adds B as a direct
-      // dependence, and we don't include the direct dependencies of main dex list classes.
+      // TODO(b/178353726): Should not allow inlining bar into foo(), since that adds B as a direct
+      //  dependence, and we don't include the direct dependencies of main dex list classes.
       A.bar();
     }
   }
diff --git a/src/test/java/com/android/tools/r8/maindexlist/MainDexListFromGenerateMainDexInliningWithTracingTest.java b/src/test/java/com/android/tools/r8/maindexlist/MainDexListFromGenerateMainDexInliningWithTracingTest.java
deleted file mode 100644
index f5971f9..0000000
--- a/src/test/java/com/android/tools/r8/maindexlist/MainDexListFromGenerateMainDexInliningWithTracingTest.java
+++ /dev/null
@@ -1,158 +0,0 @@
-// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
-// for details. All rights reserved. Use of this source code is governed by a
-// BSD-style license that can be found in the LICENSE file.
-
-package com.android.tools.r8.maindexlist;
-
-import static com.android.tools.r8.utils.codeinspector.CodeMatchers.invokesMethod;
-import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.junit.Assert.assertEquals;
-
-import com.android.tools.r8.NeverInline;
-import com.android.tools.r8.NoHorizontalClassMerging;
-import com.android.tools.r8.R8TestCompileResult;
-import com.android.tools.r8.TestBase;
-import com.android.tools.r8.TestParameters;
-import com.android.tools.r8.TestParametersCollection;
-import com.android.tools.r8.ToolHelper;
-import com.android.tools.r8.references.ClassReference;
-import com.android.tools.r8.utils.codeinspector.ClassSubject;
-import com.android.tools.r8.utils.codeinspector.CodeInspector;
-import com.android.tools.r8.utils.codeinspector.MethodSubject;
-import com.google.common.collect.ImmutableSet;
-import java.util.List;
-import org.junit.BeforeClass;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
-import org.junit.runners.Parameterized.Parameters;
-
-@RunWith(Parameterized.class)
-public class MainDexListFromGenerateMainDexInliningWithTracingTest extends TestBase {
-
-  private static List<ClassReference> mainDexList;
-
-  private final TestParameters parameters;
-
-  @Parameters(name = "{0}")
-  public static TestParametersCollection data() {
-    return getTestParameters()
-        .withDexRuntimes()
-        .withApiLevelsEndingAtExcluding(apiLevelWithNativeMultiDexSupport())
-        .build();
-  }
-
-  @BeforeClass
-  public static void setup() throws Exception {
-    mainDexList =
-        testForMainDexListGenerator(getStaticTemp())
-            .addInnerClasses(MainDexListFromGenerateMainDexInliningWithTracingTest.class)
-            .addLibraryFiles(ToolHelper.getMostRecentAndroidJar())
-            .addMainDexRules(
-                "-keep class " + Main.class.getTypeName() + " {",
-                "  public static void main(java.lang.String[]);",
-                "}")
-            .run()
-            .getMainDexList();
-  }
-
-  public MainDexListFromGenerateMainDexInliningWithTracingTest(TestParameters parameters) {
-    this.parameters = parameters;
-  }
-
-  @Test
-  public void test() throws Exception {
-    // The generated main dex list should contain Main (which is a root) and A (which is a direct
-    // dependency of Main).
-    assertEquals(2, mainDexList.size());
-    assertEquals(A.class.getTypeName(), mainDexList.get(0).getTypeName());
-    assertEquals(Main.class.getTypeName(), mainDexList.get(1).getTypeName());
-
-    R8TestCompileResult compileResult =
-        testForR8(parameters.getBackend())
-            .addInnerClasses(getClass())
-            .addInliningAnnotations()
-            .addKeepClassAndMembersRules(Main.class)
-            .addMainDexListClassReferences(mainDexList)
-            .addMainDexRules(
-                "-keep class " + Main.class.getTypeName() + " {",
-                "  public static void foo(java.lang.String[]);",
-                "}")
-            .collectMainDexClasses()
-            .enableInliningAnnotations()
-            .enableNoHorizontalClassMergingAnnotations()
-            .noMinification()
-            .setMinApi(parameters.getApiLevel())
-            .compile();
-
-    CodeInspector inspector = compileResult.inspector();
-    ClassSubject mainClassSubject = inspector.clazz(Main.class);
-    assertThat(mainClassSubject, isPresent());
-
-    MethodSubject fooMethodSubject = mainClassSubject.uniqueMethodWithName("foo");
-    assertThat(fooMethodSubject, isPresent());
-
-    MethodSubject notCalledAtStartupMethodSubject =
-        mainClassSubject.uniqueMethodWithName("notCalledAtStartup");
-    assertThat(notCalledAtStartupMethodSubject, isPresent());
-
-    ClassSubject aClassSubject = inspector.clazz(A.class);
-    assertThat(aClassSubject, isPresent());
-
-    MethodSubject barMethodSubject = aClassSubject.uniqueMethodWithName("bar");
-    assertThat(barMethodSubject, isPresent());
-
-    ClassSubject bClassSubject = inspector.clazz(B.class);
-    assertThat(bClassSubject, isPresent());
-
-    MethodSubject bazMethodSubject = bClassSubject.uniqueMethodWithName("baz");
-    assertThat(bazMethodSubject, isPresent());
-
-    assertThat(notCalledAtStartupMethodSubject, invokesMethod(barMethodSubject));
-    assertThat(barMethodSubject, invokesMethod(bazMethodSubject));
-
-    // The main dex classes should be the same as the input main dex list.
-    assertEquals(
-        ImmutableSet.of(mainClassSubject.getFinalName(), aClassSubject.getFinalName()),
-        compileResult.getMainDexClasses());
-  }
-
-  static class Main {
-
-    public static void main(String[] args) {
-      System.out.println("Main.main()");
-    }
-
-    public static void notCalledAtStartup() {
-      // Should not allow inlining bar into notCalledAtStartup(), since that adds B as a direct
-      // dependence, and we don't include the direct dependencies of main dex list classes.
-      new A().bar();
-    }
-
-    // This method is traced for main dex when running with R8, to add A as a dependency.
-    static A foo(A a) {
-      if (a != null) {
-        System.out.println("Hello World");
-      }
-      return a;
-    }
-  }
-
-  @NoHorizontalClassMerging
-  static class A {
-
-    static void bar() {
-      B.baz();
-    }
-  }
-
-  @NoHorizontalClassMerging
-  static class B {
-
-    @NeverInline
-    static void baz() {
-      System.out.println("B.baz");
-    }
-  }
-}
diff --git a/src/test/java/com/android/tools/r8/maindexlist/MainDexListFromGenerateMainHorizontalMergingTest.java b/src/test/java/com/android/tools/r8/maindexlist/MainDexListFromGenerateMainHorizontalMergingTest.java
index c2c96ec..867b1ec 100644
--- a/src/test/java/com/android/tools/r8/maindexlist/MainDexListFromGenerateMainHorizontalMergingTest.java
+++ b/src/test/java/com/android/tools/r8/maindexlist/MainDexListFromGenerateMainHorizontalMergingTest.java
@@ -5,6 +5,7 @@
 package com.android.tools.r8.maindexlist;
 
 import static com.android.tools.r8.utils.codeinspector.CodeMatchers.invokesMethod;
+import static com.android.tools.r8.utils.codeinspector.Matchers.isAbsent;
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
 import static org.hamcrest.CoreMatchers.containsString;
 import static org.hamcrest.MatcherAssert.assertThat;
@@ -13,12 +14,10 @@
 
 import com.android.tools.r8.NeverClassInline;
 import com.android.tools.r8.NeverInline;
-import com.android.tools.r8.R8FullTestBuilder;
 import com.android.tools.r8.R8TestCompileResult;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
-import com.android.tools.r8.ThrowableConsumer;
 import com.android.tools.r8.ToolHelper;
 import com.android.tools.r8.references.ClassReference;
 import com.android.tools.r8.references.TypeReference;
@@ -39,6 +38,7 @@
 public class MainDexListFromGenerateMainHorizontalMergingTest extends TestBase {
 
   private static List<ClassReference> mainDexList;
+
   private final TestParameters parameters;
 
   @Parameters(name = "{0}")
@@ -66,45 +66,30 @@
   }
 
   @Test
-  public void testMainDexList() throws Exception {
+  public void test() throws Exception {
     assertEquals(3, mainDexList.size());
     Set<String> mainDexReferences =
         mainDexList.stream().map(TypeReference::getTypeName).collect(Collectors.toSet());
     assertTrue(mainDexReferences.contains(A.class.getTypeName()));
     assertTrue(mainDexReferences.contains(B.class.getTypeName()));
     assertTrue(mainDexReferences.contains(Main.class.getTypeName()));
-    runTest(builder -> builder.addMainDexListClassReferences(mainDexList));
-  }
 
-  @Test
-  public void testMainDexTracing() throws Exception {
-    runTest(
-        builder ->
-            builder.addMainDexRules(
-                "-keep class " + Main.class.getTypeName() + " { public static void foo(); }"));
-  }
+    R8TestCompileResult compileResult =
+        testForR8(parameters.getBackend())
+            .addInnerClasses(getClass())
+            .addInliningAnnotations()
+            .addKeepClassAndMembersRules(Main.class, Outside.class)
+            .addMainDexListClassReferences(mainDexList)
+            .collectMainDexClasses()
+            .enableInliningAnnotations()
+            .enableNeverClassInliningAnnotations()
+            .setMinApi(parameters.getApiLevel())
+            .addHorizontallyMergedClassesInspector(
+                horizontallyMergedClassesInspector -> {
+                  horizontallyMergedClassesInspector.assertMergedInto(B.class, A.class);
+                })
+            .compile();
 
-  private void runTest(ThrowableConsumer<R8FullTestBuilder> testBuilder) throws Exception {
-    testForR8(parameters.getBackend())
-        .addInnerClasses(getClass())
-        .addInliningAnnotations()
-        .addKeepClassAndMembersRules(Main.class, Outside.class)
-        .collectMainDexClasses()
-        .enableInliningAnnotations()
-        .enableNeverClassInliningAnnotations()
-        .setMinApi(parameters.getApiLevel())
-        .addHorizontallyMergedClassesInspector(
-            horizontallyMergedClassesInspector -> {
-              horizontallyMergedClassesInspector.assertClassesNotMerged(B.class, A.class);
-            })
-        .apply(testBuilder)
-        .compile()
-        .apply(this::inspectCompileResult)
-        .run(parameters.getRuntime(), Main.class)
-        .assertSuccessWithOutputThatMatches(containsString(Outside.class.getTypeName()));
-  }
-
-  private void inspectCompileResult(R8TestCompileResult compileResult) throws Exception {
     CodeInspector inspector = compileResult.inspector();
     ClassSubject mainClassSubject = inspector.clazz(Main.class);
     assertThat(mainClassSubject, isPresent());
@@ -115,8 +100,9 @@
     ClassSubject aClassSubject = inspector.clazz(A.class);
     assertThat(aClassSubject, isPresent());
 
+    // TODO(b/178460068): Should be present, but was merged with A.
     ClassSubject bClassSubject = inspector.clazz(B.class);
-    assertThat(bClassSubject, isPresent());
+    assertThat(bClassSubject, isAbsent());
 
     MethodSubject fooASubject = aClassSubject.uniqueMethodWithName("foo");
     assertThat(fooASubject, isPresent());
@@ -126,16 +112,19 @@
     compileResult.inspectMainDexClasses(
         mainDexClasses -> {
           assertEquals(
-              ImmutableSet.of(
-                  mainClassSubject.getFinalName(),
-                  aClassSubject.getFinalName(),
-                  bClassSubject.getFinalName()),
+              ImmutableSet.of(mainClassSubject.getFinalName(), aClassSubject.getFinalName()),
               mainDexClasses);
         });
+
+    compileResult
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputThatMatches(containsString(Outside.class.getTypeName()));
   }
 
   static class Main {
 
+    // public static B b;
+
     public static void main(String[] args) {
       B.print();
     }
diff --git a/src/test/java/com/android/tools/r8/maindexlist/MainDexListFromGenerateMainVerticalMergingTest.java b/src/test/java/com/android/tools/r8/maindexlist/MainDexListFromGenerateMainVerticalMergingTest.java
index ec50d1e..e3d49f5 100644
--- a/src/test/java/com/android/tools/r8/maindexlist/MainDexListFromGenerateMainVerticalMergingTest.java
+++ b/src/test/java/com/android/tools/r8/maindexlist/MainDexListFromGenerateMainVerticalMergingTest.java
@@ -5,6 +5,7 @@
 package com.android.tools.r8.maindexlist;
 
 import static com.android.tools.r8.utils.codeinspector.CodeMatchers.invokesMethod;
+import static com.android.tools.r8.utils.codeinspector.Matchers.isAbsent;
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
@@ -12,17 +13,16 @@
 
 import com.android.tools.r8.NeverClassInline;
 import com.android.tools.r8.NeverInline;
-import com.android.tools.r8.R8FullTestBuilder;
 import com.android.tools.r8.R8TestCompileResult;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
-import com.android.tools.r8.ThrowableConsumer;
 import com.android.tools.r8.ToolHelper;
 import com.android.tools.r8.references.ClassReference;
 import com.android.tools.r8.references.TypeReference;
 import com.android.tools.r8.utils.codeinspector.ClassSubject;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.FieldSubject;
 import com.android.tools.r8.utils.codeinspector.MethodSubject;
 import com.google.common.collect.ImmutableSet;
 import java.util.List;
@@ -66,34 +66,23 @@
   }
 
   @Test
-  public void testMainDexList() throws Exception {
+  public void test() throws Exception {
     assertEquals(3, mainDexList.size());
     Set<String> mainDexReferences =
         mainDexList.stream().map(TypeReference::getTypeName).collect(Collectors.toSet());
     assertTrue(mainDexReferences.contains(A.class.getTypeName()));
     assertTrue(mainDexReferences.contains(B.class.getTypeName()));
     assertTrue(mainDexReferences.contains(Main.class.getTypeName()));
-    runTest(builder -> builder.addMainDexListClassReferences(mainDexList));
-  }
 
-  @Test
-  public void testMainTracing() throws Exception {
-    runTest(
-        builder ->
-            builder.addMainDexRules(
-                "-keep class " + Main.class.getTypeName() + " { public static void foo(); }"));
-  }
-
-  private void runTest(ThrowableConsumer<R8FullTestBuilder> testBuilder) throws Exception {
     R8TestCompileResult compileResult =
         testForR8(parameters.getBackend())
             .addInnerClasses(getClass())
             .addInliningAnnotations()
             .addKeepClassAndMembersRules(Main.class, Outside.class)
+            .addMainDexListClassReferences(mainDexList)
             .collectMainDexClasses()
             .enableInliningAnnotations()
             .enableNeverClassInliningAnnotations()
-            .apply(testBuilder)
             .setMinApi(parameters.getApiLevel())
             .compile();
 
@@ -105,23 +94,25 @@
     assertThat(fooMethodSubject, isPresent());
 
     ClassSubject aClassSubject = inspector.clazz(A.class);
-    assertThat(aClassSubject, isPresent());
-
-    MethodSubject fooAMethodSubject = aClassSubject.uniqueMethodWithName("foo");
-    assertThat(fooAMethodSubject, isPresent());
+    // TODO(b/178460068): Should be present, but was merged with B.
+    assertThat(aClassSubject, isAbsent());
 
     ClassSubject bClassSubject = inspector.clazz(B.class);
     assertThat(bClassSubject, isPresent());
 
-    assertThat(fooMethodSubject, invokesMethod(fooAMethodSubject));
+    FieldSubject outsideFieldSubject = bClassSubject.uniqueFieldWithName("outsideField");
+    assertThat(outsideFieldSubject, isPresent());
 
+    MethodSubject fooBMethodSubject = bClassSubject.uniqueMethodWithName("foo");
+    assertThat(fooBMethodSubject, isPresent());
+
+    assertThat(fooMethodSubject, invokesMethod(fooBMethodSubject));
+
+    // TODO(b/178460068): B should not be in main dex.
     compileResult.inspectMainDexClasses(
         mainDexClasses -> {
           assertEquals(
-              ImmutableSet.of(
-                  mainClassSubject.getFinalName(),
-                  aClassSubject.getFinalName(),
-                  bClassSubject.getFinalName()),
+              ImmutableSet.of(mainClassSubject.getFinalName(), bClassSubject.getFinalName()),
               mainDexClasses);
         });
 
diff --git a/src/test/java/com/android/tools/r8/maindexlist/MainDexListInliningTest.java b/src/test/java/com/android/tools/r8/maindexlist/MainDexListInliningTest.java
index 09df9d4..82f5e7d 100644
--- a/src/test/java/com/android/tools/r8/maindexlist/MainDexListInliningTest.java
+++ b/src/test/java/com/android/tools/r8/maindexlist/MainDexListInliningTest.java
@@ -4,6 +4,7 @@
 
 package com.android.tools.r8.maindexlist;
 
+import static com.android.tools.r8.utils.codeinspector.Matchers.isAbsent;
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertFalse;
@@ -54,9 +55,10 @@
     ClassSubject mainClassSubject = inspector.clazz(Main.class);
     assertThat(mainClassSubject, isPresent());
 
-    // A is not allowed to be inlined and is therefore present.
+    // A is absent due to inlining.
+    // TODO(b/178353726): Inlining should be prohibited.
     ClassSubject aClassSubject = inspector.clazz(A.class);
-    assertThat(aClassSubject, isPresent());
+    assertThat(aClassSubject, isAbsent());
 
     // B should be referenced from Main.main.
     ClassSubject bClassSubject = inspector.clazz(B.class);
@@ -65,9 +67,6 @@
     compileResult.inspectMainDexClasses(
         mainDexClasses -> {
           assertTrue(mainDexClasses.contains(mainClassSubject.getFinalName()));
-          // Since we passed a main-dex list the traced references A and B are not automagically
-          // included.
-          assertFalse(mainDexClasses.contains(aClassSubject.getFinalName()));
           assertFalse(mainDexClasses.contains(bClassSubject.getFinalName()));
         });
   }