Revert "Remove synthesized-class-map support for main-dex computation."

This reverts commit bc4f54ff2c45b64ce9e8c7306a2951c42499507b.

Reason for revert: b/181083776

Bug: 181083776
Bug: 180074885
Bug: 179465550
Bug: 176781940
Change-Id: Iee8da68ae7d41e2c773768d70d67e3f88c7d279a
diff --git a/src/main/java/com/android/tools/r8/annotations/SynthesizedClassMap.java b/src/main/java/com/android/tools/r8/annotations/SynthesizedClassMap.java
new file mode 100644
index 0000000..b46a2ba
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/annotations/SynthesizedClassMap.java
@@ -0,0 +1,15 @@
+// Copyright (c) 2017, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.annotations;
+
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+
+@Retention(RetentionPolicy.CLASS)
+@Target(ElementType.TYPE)
+public @interface SynthesizedClassMap {
+  Class<?>[] value() default {};
+}
diff --git a/src/main/java/com/android/tools/r8/graph/DexAnnotation.java b/src/main/java/com/android/tools/r8/graph/DexAnnotation.java
index f2959bb..3483fbb 100644
--- a/src/main/java/com/android/tools/r8/graph/DexAnnotation.java
+++ b/src/main/java/com/android/tools/r8/graph/DexAnnotation.java
@@ -5,6 +5,7 @@
 
 import com.android.tools.r8.dex.IndexedItemCollection;
 import com.android.tools.r8.dex.MixedSectionCollection;
+import com.android.tools.r8.errors.CompilationError;
 import com.android.tools.r8.graph.DexValue.DexValueAnnotation;
 import com.android.tools.r8.graph.DexValue.DexValueArray;
 import com.android.tools.r8.graph.DexValue.DexValueInt;
@@ -21,7 +22,10 @@
 import com.android.tools.r8.utils.structural.StructuralMapping;
 import com.android.tools.r8.utils.structural.StructuralSpecification;
 import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
+import java.util.TreeSet;
 import java.util.function.Function;
 
 public class DexAnnotation extends DexItem implements StructuralItem<DexAnnotation> {
@@ -92,7 +96,8 @@
     }
     if (annotation == options.itemFactory.dalvikFastNativeAnnotation
         || annotation == options.itemFactory.dalvikCriticalNativeAnnotation
-        || annotation == options.itemFactory.annotationSynthesizedClass) {
+        || annotation == options.itemFactory.annotationSynthesizedClass
+        || annotation == options.itemFactory.annotationSynthesizedClassMap) {
       return true;
     }
     if (options.processCovariantReturnTypeAnnotations) {
@@ -358,6 +363,43 @@
     return new DexValueString(factory.createString(string));
   }
 
+  public static Collection<DexType> readAnnotationSynthesizedClassMap(
+      DexProgramClass clazz, DexItemFactory dexItemFactory) {
+    DexAnnotation foundAnnotation =
+        clazz.annotations().getFirstMatching(dexItemFactory.annotationSynthesizedClassMap);
+    if (foundAnnotation != null) {
+      if (foundAnnotation.annotation.elements.length != 1) {
+        throw new CompilationError(getInvalidSynthesizedClassMapMessage(clazz, foundAnnotation));
+      }
+      DexAnnotationElement value = foundAnnotation.annotation.elements[0];
+      if (value.name != dexItemFactory.valueString) {
+        throw new CompilationError(getInvalidSynthesizedClassMapMessage(clazz, foundAnnotation));
+      }
+      DexValueArray existingList = value.value.asDexValueArray();
+      if (existingList == null) {
+        throw new CompilationError(getInvalidSynthesizedClassMapMessage(clazz, foundAnnotation));
+      }
+      Collection<DexType> synthesized = new ArrayList<>(existingList.values.length);
+      for (DexValue element : existingList.getValues()) {
+        if (!element.isDexValueType()) {
+          throw new CompilationError(getInvalidSynthesizedClassMapMessage(clazz, foundAnnotation));
+        }
+        synthesized.add(element.asDexValueType().value);
+      }
+      return synthesized;
+    }
+    return Collections.emptyList();
+  }
+
+  private static String getInvalidSynthesizedClassMapMessage(
+      DexProgramClass annotatedClass,
+      DexAnnotation invalidAnnotation) {
+    return annotatedClass.toSourceString()
+        + " is annotated with invalid "
+        + invalidAnnotation.annotation.type.toString()
+        + ": " + invalidAnnotation.toString();
+  }
+
   public static DexAnnotation createAnnotationSynthesizedClass(
       SyntheticKind kind, DexType synthesizingContext, DexItemFactory dexItemFactory) {
     DexAnnotationElement kindElement =
@@ -414,6 +456,26 @@
     return new Pair<>(kind, valueElement.value.asDexValueType().getValue());
   }
 
+  public static DexAnnotation createAnnotationSynthesizedClassMap(
+      TreeSet<DexType> synthesized,
+      DexItemFactory dexItemFactory) {
+    DexValueType[] values = synthesized.stream()
+        .map(DexValueType::new)
+        .toArray(DexValueType[]::new);
+    DexValueArray array = new DexValueArray(values);
+    DexAnnotationElement pair =
+        new DexAnnotationElement(dexItemFactory.createString("value"), array);
+    return new DexAnnotation(
+        VISIBILITY_BUILD,
+        new DexEncodedAnnotation(
+            dexItemFactory.annotationSynthesizedClassMap, new DexAnnotationElement[]{pair}));
+  }
+
+  public static boolean isSynthesizedClassMapAnnotation(DexAnnotation annotation,
+      DexItemFactory factory) {
+    return annotation.annotation.type == factory.annotationSynthesizedClassMap;
+  }
+
   public DexAnnotation rewrite(Function<DexEncodedAnnotation, DexEncodedAnnotation> rewriter) {
     DexEncodedAnnotation rewritten = rewriter.apply(annotation);
     if (rewritten == annotation) {
diff --git a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
index 3f6eff3..0752d20 100644
--- a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
+++ b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
@@ -605,6 +605,8 @@
   public final DexType annotationThrows = createStaticallyKnownType("Ldalvik/annotation/Throws;");
   public final DexType annotationSynthesizedClass =
       createStaticallyKnownType("Lcom/android/tools/r8/annotations/SynthesizedClass;");
+  public final DexType annotationSynthesizedClassMap =
+      createStaticallyKnownType("Lcom/android/tools/r8/annotations/SynthesizedClassMap;");
   public final DexType annotationCovariantReturnType =
       createStaticallyKnownType("Ldalvik/annotation/codegen/CovariantReturnType;");
   public final DexType annotationCovariantReturnTypes =
diff --git a/src/main/java/com/android/tools/r8/shaking/AnnotationRemover.java b/src/main/java/com/android/tools/r8/shaking/AnnotationRemover.java
index 057a80f..762617e 100644
--- a/src/main/java/com/android/tools/r8/shaking/AnnotationRemover.java
+++ b/src/main/java/com/android/tools/r8/shaking/AnnotationRemover.java
@@ -115,6 +115,10 @@
         return isAnnotationTypeLive;
 
       case DexAnnotation.VISIBILITY_BUILD:
+        if (DexAnnotation.isSynthesizedClassMapAnnotation(annotation, dexItemFactory)) {
+          // TODO(sgjesse) When should these be removed?
+          return true;
+        }
         if (!config.runtimeInvisibleAnnotations) {
           return false;
         }
diff --git a/src/main/java/com/android/tools/r8/shaking/MainDexInfo.java b/src/main/java/com/android/tools/r8/shaking/MainDexInfo.java
index a60967f..58315d4 100644
--- a/src/main/java/com/android/tools/r8/shaking/MainDexInfo.java
+++ b/src/main/java/com/android/tools/r8/shaking/MainDexInfo.java
@@ -91,11 +91,6 @@
     return this == NONE;
   }
 
-  public boolean isMainDexTypeThatShouldIncludeDependencies(DexType type) {
-    // Dependencies of 'type' are only needed if 'type' is a direct/executed main-dex type.
-    return classList.contains(type) || tracedRoots.contains(type);
-  }
-
   public boolean isMainDex(ProgramDefinition definition) {
     return isFromList(definition) || isTracedRoot(definition) || isDependency(definition);
   }
diff --git a/src/main/java/com/android/tools/r8/shaking/MissingClasses.java b/src/main/java/com/android/tools/r8/shaking/MissingClasses.java
index ba25312..3632f93 100644
--- a/src/main/java/com/android/tools/r8/shaking/MissingClasses.java
+++ b/src/main/java/com/android/tools/r8/shaking/MissingClasses.java
@@ -264,6 +264,7 @@
               dexItemFactory.annotationMethodParameters,
               dexItemFactory.annotationSourceDebugExtension,
               dexItemFactory.annotationSynthesizedClass,
+              dexItemFactory.annotationSynthesizedClassMap,
               dexItemFactory.annotationThrows,
               dexItemFactory.serializedLambdaType)
           .addAll(dexItemFactory.getJavaConversionTypes())
diff --git a/src/main/java/com/android/tools/r8/synthesis/SynthesizingContext.java b/src/main/java/com/android/tools/r8/synthesis/SynthesizingContext.java
index bff9202..82f3204 100644
--- a/src/main/java/com/android/tools/r8/synthesis/SynthesizingContext.java
+++ b/src/main/java/com/android/tools/r8/synthesis/SynthesizingContext.java
@@ -15,6 +15,7 @@
 import com.android.tools.r8.origin.Origin;
 import com.android.tools.r8.shaking.MainDexInfo;
 import java.util.Comparator;
+import java.util.Set;
 
 /**
  * A synthesizing context is a description of the context that gives rise to a synthetic item.
@@ -122,15 +123,14 @@
     appView.rewritePrefix.rewriteType(hygienicType, rewrittenType);
   }
 
-  // TODO(b/180074885): Remove this once main-dex is traced at the end of of compilation.
   void addIfDerivedFromMainDexClass(
-      DexProgramClass externalSyntheticClass, MainDexInfo mainDexInfo) {
-    if (mainDexInfo.isMainDex(externalSyntheticClass)) {
-      return;
-    }
+      DexProgramClass externalSyntheticClass,
+      MainDexInfo mainDexInfo,
+      Set<DexType> allMainDexTypes) {
     // The input context type (not the annotated context) determines if the derived class is to be
-    // in main dex, as it is the input context type that is traced as part of main-dex tracing.
-    if (mainDexInfo.isMainDexTypeThatShouldIncludeDependencies(inputContextType)) {
+    // in main dex.
+    // TODO(b/168584485): Once resolved allMainDexTypes == mainDexClasses.
+    if (allMainDexTypes.contains(inputContextType)) {
       mainDexInfo.addSyntheticClass(externalSyntheticClass);
     }
   }
diff --git a/src/main/java/com/android/tools/r8/synthesis/SyntheticFinalization.java b/src/main/java/com/android/tools/r8/synthesis/SyntheticFinalization.java
index 2e70db6..65de13f 100644
--- a/src/main/java/com/android/tools/r8/synthesis/SyntheticFinalization.java
+++ b/src/main/java/com/android/tools/r8/synthesis/SyntheticFinalization.java
@@ -7,6 +7,7 @@
 import com.android.tools.r8.graph.AppInfo;
 import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
 import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexAnnotation;
 import com.android.tools.r8.graph.DexApplication;
 import com.android.tools.r8.graph.DexClass;
 import com.android.tools.r8.graph.DexEncodedMethod;
@@ -31,9 +32,11 @@
 import com.android.tools.r8.utils.collections.BidirectionalManyToOneRepresentativeMap;
 import com.android.tools.r8.utils.collections.BidirectionalOneToOneHashMap;
 import com.android.tools.r8.utils.structural.RepresentativeMap;
+import com.google.common.collect.ArrayListMultimap;
 import com.google.common.collect.ImmutableCollection;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ListMultimap;
 import com.google.common.collect.Sets;
 import com.google.common.hash.HashCode;
 import java.util.ArrayList;
@@ -45,6 +48,7 @@
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Set;
+import java.util.TreeSet;
 import java.util.function.BiConsumer;
 import java.util.function.Function;
 
@@ -258,6 +262,7 @@
   Result computeFinalSynthetics(AppView<?> appView) {
     assert verifyNoNestedSynthetics();
     DexApplication application;
+    MainDexInfo mainDexInfo = appView.appInfo().getMainDexInfo();
     Builder lensBuilder = new Builder();
     ImmutableMap.Builder<DexType, SyntheticMethodReference> finalMethodsBuilder =
         ImmutableMap.builder();
@@ -285,6 +290,11 @@
     ImmutableMap<DexType, SyntheticProgramClassReference> finalClasses =
         finalClassesBuilder.build();
 
+    handleSynthesizedClassMapping(
+        finalSyntheticProgramDefinitions, application, options, mainDexInfo, lensBuilder.typeMap);
+
+    assert appView.appInfo().getMainDexInfo() == mainDexInfo;
+
     Set<DexType> prunedSynthetics = Sets.newIdentityHashSet();
     committed.forEachNonLegacyItem(
         reference -> {
@@ -344,6 +354,96 @@
     return true;
   }
 
+  private void handleSynthesizedClassMapping(
+      List<DexProgramClass> finalSyntheticClasses,
+      DexApplication application,
+      InternalOptions options,
+      MainDexInfo mainDexInfo,
+      Map<DexType, DexType> derivedMainDexTypesToIgnore) {
+    boolean includeSynthesizedClassMappingInOutput = shouldAnnotateSynthetics(options);
+    if (includeSynthesizedClassMappingInOutput) {
+      updateSynthesizedClassMapping(application, finalSyntheticClasses);
+    }
+    updateMainDexListWithSynthesizedClassMap(application, mainDexInfo, derivedMainDexTypesToIgnore);
+    if (!includeSynthesizedClassMappingInOutput) {
+      clearSynthesizedClassMapping(application);
+    }
+  }
+
+  private void updateSynthesizedClassMapping(
+      DexApplication application, List<DexProgramClass> finalSyntheticClasses) {
+    ListMultimap<DexProgramClass, DexProgramClass> originalToSynthesized =
+        ArrayListMultimap.create();
+    for (DexType type : committed.getLegacyTypes()) {
+      DexProgramClass clazz = DexProgramClass.asProgramClassOrNull(application.definitionFor(type));
+      if (clazz != null) {
+        for (DexProgramClass origin : clazz.getSynthesizedFrom()) {
+          originalToSynthesized.put(origin, clazz);
+        }
+      }
+    }
+    for (DexProgramClass clazz : finalSyntheticClasses) {
+      for (DexProgramClass origin : clazz.getSynthesizedFrom()) {
+        originalToSynthesized.put(origin, clazz);
+      }
+    }
+    for (Map.Entry<DexProgramClass, Collection<DexProgramClass>> entry :
+        originalToSynthesized.asMap().entrySet()) {
+      DexProgramClass original = entry.getKey();
+      // Use a tree set to make sure that we have an ordering on the types.
+      // These types are put in an array in annotations in the output and we
+      // need a consistent ordering on them.
+      TreeSet<DexType> synthesized = new TreeSet<>(DexType::compareTo);
+      entry.getValue().stream()
+          .map(dexProgramClass -> dexProgramClass.type)
+          .forEach(synthesized::add);
+      synthesized.addAll(
+          DexAnnotation.readAnnotationSynthesizedClassMap(original, application.dexItemFactory));
+
+      DexAnnotation updatedAnnotation =
+          DexAnnotation.createAnnotationSynthesizedClassMap(
+              synthesized, application.dexItemFactory);
+
+      original.setAnnotations(original.annotations().getWithAddedOrReplaced(updatedAnnotation));
+    }
+  }
+
+  private void updateMainDexListWithSynthesizedClassMap(
+      DexApplication application,
+      MainDexInfo mainDexInfo,
+      Map<DexType, DexType> derivedMainDexTypesToIgnore) {
+    if (mainDexInfo.isEmpty()) {
+      return;
+    }
+    List<DexProgramClass> newMainDexClasses = new ArrayList<>();
+    mainDexInfo.forEachExcludingDependencies(
+        dexType -> {
+          DexProgramClass programClass =
+              DexProgramClass.asProgramClassOrNull(application.definitionFor(dexType));
+          if (programClass != null) {
+            Collection<DexType> derived =
+                DexAnnotation.readAnnotationSynthesizedClassMap(
+                    programClass, application.dexItemFactory);
+            for (DexType type : derived) {
+              DexType mappedType = derivedMainDexTypesToIgnore.getOrDefault(type, type);
+              DexProgramClass syntheticClass =
+                  DexProgramClass.asProgramClassOrNull(application.definitionFor(mappedType));
+              if (syntheticClass != null) {
+                newMainDexClasses.add(syntheticClass);
+              }
+            }
+          }
+        });
+    newMainDexClasses.forEach(mainDexInfo::addSyntheticClass);
+  }
+
+  private void clearSynthesizedClassMapping(DexApplication application) {
+    for (DexProgramClass clazz : application.classes()) {
+      clazz.setAnnotations(
+          clazz.annotations().getWithout(application.dexItemFactory.annotationSynthesizedClassMap));
+    }
+  }
+
   private static DexApplication buildLensAndProgram(
       AppView<?> appView,
       Map<DexType, EquivalenceGroup<SyntheticMethodDefinition>> syntheticMethodGroups,
@@ -354,8 +454,23 @@
     DexApplication application = appView.appInfo().app();
     DexItemFactory factory = appView.dexItemFactory();
     List<DexProgramClass> newProgramClasses = new ArrayList<>();
-    Set<DexType> pruned = Sets.newIdentityHashSet();
 
+    // TODO(b/168584485): Remove this once class-mapping support is removed.
+    Set<DexType> derivedMainDexTypes = Sets.newIdentityHashSet();
+    MainDexInfo mainDexInfo = appView.appInfo().getMainDexInfo();
+    mainDexInfo.forEachExcludingDependencies(
+        mainDexType -> {
+          derivedMainDexTypes.add(mainDexType);
+          DexProgramClass mainDexClass =
+              DexProgramClass.asProgramClassOrNull(
+                  appView.appInfo().definitionForWithoutExistenceAssert(mainDexType));
+          if (mainDexClass != null) {
+            derivedMainDexTypes.addAll(
+                DexAnnotation.readAnnotationSynthesizedClassMap(mainDexClass, factory));
+          }
+        });
+
+    Set<DexType> pruned = Sets.newIdentityHashSet();
     syntheticMethodGroups.forEach(
         (syntheticType, syntheticGroup) -> {
           SyntheticMethodDefinition representative = syntheticGroup.getRepresentative();
@@ -456,7 +571,8 @@
             addMainDexAndSynthesizedFromForMember(
                 member,
                 externalSyntheticClass,
-                appView.appInfo().getMainDexInfo(),
+                mainDexInfo,
+                derivedMainDexTypes,
                 appForLookup::programDefinitionFor);
           }
         });
@@ -477,7 +593,8 @@
             addMainDexAndSynthesizedFromForMember(
                 member,
                 externalSyntheticClass,
-                appView.appInfo().getMainDexInfo(),
+                mainDexInfo,
+                derivedMainDexTypes,
                 appForLookup::programDefinitionFor);
           }
         });
@@ -527,8 +644,11 @@
       SyntheticDefinition<?, ?, ?> member,
       DexProgramClass externalSyntheticClass,
       MainDexInfo mainDexInfo,
+      Set<DexType> derivedMainDexTypes,
       Function<DexType, DexProgramClass> definitions) {
-    member.getContext().addIfDerivedFromMainDexClass(externalSyntheticClass, mainDexInfo);
+    member
+        .getContext()
+        .addIfDerivedFromMainDexClass(externalSyntheticClass, mainDexInfo, derivedMainDexTypes);
     // TODO(b/168584485): Remove this once class-mapping support is removed.
     DexProgramClass from = definitions.apply(member.getContext().getSynthesizingContextType());
     if (from != null) {
@@ -541,6 +661,7 @@
     // This is currently also disabled on CF to CF desugaring to avoid missing class references to
     // the annotated classes.
     // TODO(b/147485959): Find an alternative encoding for synthetics to avoid missing-class refs.
+    // TODO(b/168584485): Remove support for main-dex tracing with the class-map annotation.
     return options.intermediate && !options.cfToCfDesugar;
   }
 
diff --git a/src/test/java/com/android/tools/r8/D8IncrementalRunExamplesAndroidOTest.java b/src/test/java/com/android/tools/r8/D8IncrementalRunExamplesAndroidOTest.java
index 43bbd99..82cfebf 100644
--- a/src/test/java/com/android/tools/r8/D8IncrementalRunExamplesAndroidOTest.java
+++ b/src/test/java/com/android/tools/r8/D8IncrementalRunExamplesAndroidOTest.java
@@ -361,25 +361,15 @@
   abstract D8IncrementalTestRunner test(String testName, String packageName, String mainClass);
 
   @Override
-  protected void testIntermediateWithMainDexList(
-      String packageName,
-      Path input,
-      int expectedMainDexListSize,
-      List<String> mainDexClasses,
-      List<String> mainDexOverApproximation)
-      throws Throwable {
+  protected void testIntermediateWithMainDexList(String packageName, Path input,
+      int expectedMainDexListSize, String... mainDexClasses) throws Throwable {
     // Skip those tests.
     Assume.assumeTrue(false);
   }
 
   @Override
-  protected Path buildDexThroughIntermediate(
-      String packageName,
-      Path input,
-      OutputMode outputMode,
-      AndroidApiLevel minApi,
-      List<String> mainDexClasses)
-      throws Throwable {
+  protected Path buildDexThroughIntermediate(String packageName, Path input, OutputMode outputMode,
+      AndroidApiLevel minApi, String... mainDexClasses) throws Throwable {
     // tests using this should already been skipped.
     throw new Unreachable();
   }
diff --git a/src/test/java/com/android/tools/r8/D8TestBuilder.java b/src/test/java/com/android/tools/r8/D8TestBuilder.java
index 4913bc5..4c88d91 100644
--- a/src/test/java/com/android/tools/r8/D8TestBuilder.java
+++ b/src/test/java/com/android/tools/r8/D8TestBuilder.java
@@ -6,7 +6,6 @@
 import com.android.tools.r8.D8Command.Builder;
 import com.android.tools.r8.TestBase.Backend;
 import com.android.tools.r8.desugar.desugaredlibrary.DesugaredLibraryTestBase.KeepRuleConsumer;
-import com.android.tools.r8.origin.Origin;
 import com.android.tools.r8.utils.AndroidApiLevel;
 import com.android.tools.r8.utils.AndroidApp;
 import com.android.tools.r8.utils.InternalOptions;
@@ -87,16 +86,4 @@
     builder.addMainDexRulesFiles(mainDexRuleFiles);
     return self();
   }
-
-  public D8TestBuilder addMainDexRules(String... rules) {
-    builder.addMainDexRules(Arrays.asList(rules), Origin.unknown());
-    return self();
-  }
-
-  public D8TestBuilder addMainDexKeepClassRules(Class<?>... classes) {
-    for (Class<?> clazz : classes) {
-      addMainDexRules("-keep class " + clazz.getTypeName());
-    }
-    return self();
-  }
 }
diff --git a/src/test/java/com/android/tools/r8/RunExamplesAndroidOTest.java b/src/test/java/com/android/tools/r8/RunExamplesAndroidOTest.java
index 52480be..2773cf1 100644
--- a/src/test/java/com/android/tools/r8/RunExamplesAndroidOTest.java
+++ b/src/test/java/com/android/tools/r8/RunExamplesAndroidOTest.java
@@ -9,19 +9,15 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
 
 import com.android.tools.r8.ToolHelper.DexVm;
 import com.android.tools.r8.ToolHelper.DexVm.Version;
 import com.android.tools.r8.origin.Origin;
 import com.android.tools.r8.utils.AndroidApiLevel;
 import com.android.tools.r8.utils.AndroidApp;
-import com.android.tools.r8.utils.DescriptorUtils;
 import com.android.tools.r8.utils.InternalOptions;
-import com.android.tools.r8.utils.ListUtils;
 import com.android.tools.r8.utils.OffOrAuto;
 import com.android.tools.r8.utils.StringUtils;
-import com.android.tools.r8.utils.StringUtils.BraceType;
 import com.android.tools.r8.utils.TestDescriptionWatcher;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
 import com.android.tools.r8.utils.codeinspector.FoundClassSubject;
@@ -31,8 +27,6 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
-import com.google.common.collect.Sets;
-import com.google.common.collect.Sets.SetView;
 import com.google.common.io.ByteStreams;
 import java.io.IOException;
 import java.io.InputStream;
@@ -41,12 +35,13 @@
 import java.nio.file.Paths;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
+import java.util.concurrent.ExecutionException;
 import java.util.function.Consumer;
 import java.util.function.Predicate;
 import java.util.function.UnaryOperator;
@@ -119,24 +114,6 @@
       return withBuilderTransformation(builder -> builder.addMainDexClasses(classes));
     }
 
-    C withMainDexKeepClassRules(List<String> classes) {
-      return withBuilderTransformation(
-          builder -> {
-            if (builder instanceof D8Command.Builder) {
-              ((D8Command.Builder) builder)
-                  .addMainDexRules(
-                      ListUtils.map(classes, c -> "-keep class " + c), Origin.unknown());
-            } else if (builder instanceof R8Command.Builder) {
-              ((R8Command.Builder) builder)
-                  .addMainDexRules(
-                      ListUtils.map(classes, c -> "-keep class " + c), Origin.unknown());
-            } else {
-              fail("Unexpected builder type: " + builder.getClass());
-            }
-            return builder;
-          });
-    }
-
     C withInterfaceMethodDesugaring(OffOrAuto behavior) {
       return withOptionConsumer(o -> o.interfaceMethodDesugaring = behavior);
     }
@@ -488,20 +465,15 @@
     testIntermediateWithMainDexList(
         "lambdadesugaring",
         1,
-        ImmutableList.of("lambdadesugaring.LambdaDesugaring$I"),
-        ImmutableList.of());
+        "lambdadesugaring.LambdaDesugaring$I");
   }
 
   @Test
   public void testLambdaDesugaringWithMainDexList2() throws Throwable {
     // Main dex class has many lambdas.
-    testIntermediateWithMainDexList(
-        "lambdadesugaring",
-        // TODO(b/180074885): Over approximation not present in R8.
-        this instanceof R8RunExamplesAndroidOTest ? 51 : 52,
-        ImmutableList.of("lambdadesugaring.LambdaDesugaring$Refs$B"),
-        // TODO(b/180074885): Over approximation due to invoke-dynamic reference adds as dependency.
-        ImmutableList.of("lambdadesugaring.LambdaDesugaring$Refs$D"));
+    testIntermediateWithMainDexList("lambdadesugaring",
+        33,
+        "lambdadesugaring.LambdaDesugaring$Refs$B");
   }
 
   @Test
@@ -511,11 +483,7 @@
         "interfacemethods",
         Paths.get(ToolHelper.EXAMPLES_ANDROID_N_BUILD_DIR, "interfacemethods" + JAR_EXTENSION),
         2,
-        ImmutableList.of("interfacemethods.I1"),
-        // TODO(b/180074885): Over approximation due to including I1-CC by being derived from I1,
-        //  but after desugaring I1 does not reference I1$-CC (the static method is moved), so it
-        //  is incorrect to include I1-CC in the main dex.
-        ImmutableList.of("interfacemethods.I1$-CC"));
+        "interfacemethods.I1");
   }
 
 
@@ -526,11 +494,7 @@
         "interfacemethods",
         Paths.get(ToolHelper.EXAMPLES_ANDROID_N_BUILD_DIR, "interfacemethods" + JAR_EXTENSION),
         2,
-        ImmutableList.of("interfacemethods.I2"),
-        // TODO(b/180074885): Over approximation due to including I2$-CC by being derived from I2,
-        //  but after desugaring I2 does not reference I2$-CC (the default method is moved), so it
-        //  is incorrect to include I2$-CC in the main dex.
-        ImmutableList.of("interfacemethods.I2$-CC"));
+        "interfacemethods.I2");
   }
 
   @Test
@@ -546,23 +510,20 @@
   private void testIntermediateWithMainDexList(
       String packageName,
       int expectedMainDexListSize,
-      List<String> mainDexClasses,
-      List<String> mainDexOverApproximation)
+      String... mainDexClasses)
       throws Throwable {
     testIntermediateWithMainDexList(
         packageName,
         Paths.get(EXAMPLE_DIR, packageName + JAR_EXTENSION),
         expectedMainDexListSize,
-        mainDexClasses,
-        mainDexOverApproximation);
+        mainDexClasses);
   }
 
   protected void testIntermediateWithMainDexList(
       String packageName,
       Path input,
       int expectedMainDexListSize,
-      List<String> mainDexClasses,
-      List<String> mainDexOverApproximation)
+      String... mainDexClasses)
       throws Throwable {
     AndroidApiLevel minApi = AndroidApiLevel.K;
 
@@ -573,18 +534,16 @@
             .withMinApiLevel(minApi)
             .withOptionConsumer(option -> option.minimalMainDex = true)
             .withOptionConsumer(option -> option.enableInheritanceClassInDexDistributor = false)
-            .withMainDexKeepClassRules(mainDexClasses)
+            .withMainDexClass(mainDexClasses)
             .withKeepAll();
     Path fullDexes = temp.getRoot().toPath().resolve(packageName + "full" + ZIP_EXTENSION);
     full.build(input, fullDexes);
 
     // Builds with intermediate in both output mode.
-    Path dexesThroughIndexedIntermediate =
-        buildDexThroughIntermediate(
-            packageName, input, OutputMode.DexIndexed, minApi, mainDexClasses);
-    Path dexesThroughFilePerInputClassIntermediate =
-        buildDexThroughIntermediate(
-            packageName, input, OutputMode.DexFilePerClassFile, minApi, mainDexClasses);
+    Path dexesThroughIndexedIntermediate = buildDexThroughIntermediate(
+        packageName, input, OutputMode.DexIndexed, minApi, mainDexClasses);
+    Path dexesThroughFilePerInputClassIntermediate = buildDexThroughIntermediate(
+        packageName, input, OutputMode.DexFilePerClassFile, minApi, mainDexClasses);
 
     // Collect main dex types.
     CodeInspector fullInspector = getMainDexInspector(fullDexes);
@@ -592,45 +551,20 @@
         getMainDexInspector(dexesThroughIndexedIntermediate);
     CodeInspector filePerInputClassIntermediateInspector =
         getMainDexInspector(dexesThroughFilePerInputClassIntermediate);
-    Set<String> fullMainClasses = new HashSet<>();
+    Collection<String> fullMainClasses = new HashSet<>();
     fullInspector.forAllClasses(
         clazz -> fullMainClasses.add(clazz.getFinalDescriptor()));
-    Set<String> indexedIntermediateMainClasses = new HashSet<>();
+    Collection<String> indexedIntermediateMainClasses = new HashSet<>();
     indexedIntermediateInspector.forAllClasses(
         clazz -> indexedIntermediateMainClasses.add(clazz.getFinalDescriptor()));
-    Set<String> filePerInputClassIntermediateMainClasses = new HashSet<>();
+    Collection<String> filePerInputClassIntermediateMainClasses = new HashSet<>();
     filePerInputClassIntermediateInspector.forAllClasses(
         clazz -> filePerInputClassIntermediateMainClasses.add(clazz.getFinalDescriptor()));
 
     // Check.
     Assert.assertEquals(expectedMainDexListSize, fullMainClasses.size());
-    SetView<String> adjustedFull =
-        Sets.difference(
-            fullMainClasses,
-            new HashSet<>(
-                ListUtils.map(mainDexOverApproximation, DescriptorUtils::javaTypeToDescriptor)));
-    assertEqualSets(adjustedFull, indexedIntermediateMainClasses);
-    assertEqualSets(adjustedFull, filePerInputClassIntermediateMainClasses);
-  }
-
-  <T> void assertEqualSets(Set<T> expected, Set<T> actual) {
-    SetView<T> missing = Sets.difference(expected, actual);
-    SetView<T> unexpected = Sets.difference(actual, expected);
-    if (missing.isEmpty() && unexpected.isEmpty()) {
-      return;
-    }
-    StringBuilder builder = new StringBuilder("Sets differ.");
-    if (!missing.isEmpty()) {
-      builder.append("\nMissing items: [\n  ");
-      StringUtils.append(builder, missing, "\n  ", BraceType.NONE);
-      builder.append("\n]");
-    }
-    if (!unexpected.isEmpty()) {
-      builder.append("\nUnexpected items: [\n  ");
-      StringUtils.append(builder, unexpected, "\n  ", BraceType.NONE);
-      builder.append("\n]");
-    }
-    fail(builder.toString());
+    Assert.assertEquals(fullMainClasses, indexedIntermediateMainClasses);
+    Assert.assertEquals(fullMainClasses, filePerInputClassIntermediateMainClasses);
   }
 
   protected Path buildDexThroughIntermediate(
@@ -638,7 +572,7 @@
       Path input,
       OutputMode outputMode,
       AndroidApiLevel minApi,
-      List<String> mainDexClasses)
+      String... mainDexClasses)
       throws Throwable {
     Path intermediateDex =
         temp.getRoot().toPath().resolve(packageName + "intermediate" + ZIP_EXTENSION);
@@ -657,7 +591,7 @@
         test(packageName + "dex", packageName, "N/A")
             .withOptionConsumer(option -> option.minimalMainDex = true)
             .withOptionConsumer(option -> option.enableInheritanceClassInDexDistributor = false)
-            .withMainDexKeepClassRules(mainDexClasses)
+            .withMainDexClass(mainDexClasses)
             .withMinApiLevel(minApi)
             .withKeepAll();
 
@@ -710,7 +644,8 @@
     }
   }
 
-  protected CodeInspector getMainDexInspector(Path zip) throws IOException {
+  protected CodeInspector getMainDexInspector(Path zip)
+      throws IOException, ExecutionException {
     try (ZipFile zipFile = new ZipFile(zip.toFile(), StandardCharsets.UTF_8)) {
       try (InputStream in =
           zipFile.getInputStream(zipFile.getEntry(ToolHelper.DEFAULT_DEX_FILENAME))) {
diff --git a/src/test/java/com/android/tools/r8/desugar/backports/BackportMainDexTest.java b/src/test/java/com/android/tools/r8/desugar/backports/BackportMainDexTest.java
index 81f0674..05f84af 100644
--- a/src/test/java/com/android/tools/r8/desugar/backports/BackportMainDexTest.java
+++ b/src/test/java/com/android/tools/r8/desugar/backports/BackportMainDexTest.java
@@ -153,7 +153,7 @@
     testForD8(parameters.getBackend())
         .addProgramClasses(CLASSES)
         .setMinApi(parameters.getApiLevel())
-        .addMainDexRules(keepMainProguardConfiguration(TestClass.class))
+        .addMainDexListClasses(MiniAssert.class, TestClass.class, User2.class)
         .setProgramConsumer(mainDexConsumer)
         .compile()
         .inspect(this::checkExpectedSynthetics)
@@ -185,8 +185,44 @@
     testForD8()
         .addProgramFiles(perClassOutput)
         .setMinApi(parameters.getApiLevel())
-        // Trace the classes run by main which will pick up their dependencies.
-        .addMainDexRules(keepMainProguardConfiguration(TestClass.class))
+        .addMainDexListClasses(MiniAssert.class, TestClass.class, User2.class)
+        .setProgramConsumer(mainDexConsumer)
+        .compile()
+        .inspect(this::checkExpectedSynthetics)
+        .run(parameters.getRuntime(), TestClass.class, getRunArgs())
+        .assertSuccessWithOutput(EXPECTED);
+    checkMainDex(mainDexConsumer);
+  }
+
+  // TODO(b/168584485): This test should be removed once support is dropped.
+  @Test
+  public void testD8MergingWithTraceCf() throws Exception {
+    assumeTrue(parameters.isDexRuntime());
+    Path out1 =
+        testForD8()
+            .addProgramClasses(User1.class)
+            .addClasspathClasses(CLASSES)
+            .setIntermediate(true)
+            .setMinApi(parameters.getApiLevel())
+            .compile()
+            .writeToZip();
+
+    Path out2 =
+        testForD8()
+            .addProgramClasses(User2.class)
+            .addClasspathClasses(CLASSES)
+            .setIntermediate(true)
+            .setMinApi(parameters.getApiLevel())
+            .compile()
+            .writeToZip();
+
+    MainDexConsumer mainDexConsumer = new MainDexConsumer();
+    testForD8(parameters.getBackend())
+        .addProgramClasses(TestClass.class, MiniAssert.class)
+        .addProgramFiles(out1, out2)
+        .setMinApi(parameters.getApiLevel())
+        .addMainDexListClassReferences(
+            traceMainDex(CLASSES, Collections.emptyList()).getMainDexList())
         .setProgramConsumer(mainDexConsumer)
         .compile()
         .inspect(this::checkExpectedSynthetics)
diff --git a/src/test/java/com/android/tools/r8/maindexlist/MainDexListOutputTest.java b/src/test/java/com/android/tools/r8/maindexlist/MainDexListOutputTest.java
index d3d99ee..1bbcb9e 100644
--- a/src/test/java/com/android/tools/r8/maindexlist/MainDexListOutputTest.java
+++ b/src/test/java/com/android/tools/r8/maindexlist/MainDexListOutputTest.java
@@ -116,11 +116,12 @@
 
   @Test
   public void testD8DesugaredLambdasInMainDexList() throws Exception {
+    Path mainDexList = writeTextToTempFile(testClassMainDexName);
     TestMainDexListConsumer consumer = new TestMainDexListConsumer();
     testForD8()
         .setMinApi(AndroidApiLevel.K)
         .addProgramClasses(ImmutableList.of(TestClass.class, MyConsumer.class))
-        .addMainDexListClasses(TestClass.class)
+        .addMainDexListFiles(ImmutableList.of(mainDexList))
         .setMainDexListConsumer(consumer)
         .compile();
     assertTrue(consumer.called);
@@ -128,6 +129,7 @@
 
   @Test
   public void testD8DesugaredLambdasInMainDexListMerging() throws Exception {
+    Path mainDexList = writeTextToTempFile(testClassMainDexName);
     // Build intermediate dex code first.
     Path dexOutput =
         testForD8()
@@ -141,7 +143,7 @@
     testForD8()
         .setMinApi(AndroidApiLevel.K)
         .addProgramFiles(dexOutput)
-        .addMainDexKeepClassRules(TestClass.class)
+        .addMainDexListFiles(ImmutableList.of(mainDexList))
         .setMainDexListConsumer(consumer)
         .compile();
     assertTrue(consumer.called);
diff --git a/src/test/java/com/android/tools/r8/maindexlist/MainDexWithSynthesizedClassesTest.java b/src/test/java/com/android/tools/r8/maindexlist/MainDexWithSynthesizedClassesTest.java
index 5e88d7d..9fa7fe9 100644
--- a/src/test/java/com/android/tools/r8/maindexlist/MainDexWithSynthesizedClassesTest.java
+++ b/src/test/java/com/android/tools/r8/maindexlist/MainDexWithSynthesizedClassesTest.java
@@ -72,7 +72,7 @@
     D8TestCompileResult compileResult =
         testForD8()
             .addProgramFiles(intermediateResult.writeToZip())
-            .addMainDexKeepClassRules(TestClass.class, A.class)
+            .addMainDexListClasses(TestClass.class, A.class)
             .setMinApiThreshold(parameters.getApiLevel())
             .compile();
     checkCompilationResult(compileResult);