diff --git a/src/main/java/com/android/tools/r8/GenerateMainDexList.java b/src/main/java/com/android/tools/r8/GenerateMainDexList.java
index 02d6475..e5842df 100644
--- a/src/main/java/com/android/tools/r8/GenerateMainDexList.java
+++ b/src/main/java/com/android/tools/r8/GenerateMainDexList.java
@@ -9,9 +9,9 @@
 import com.android.tools.r8.graph.AppServices;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexApplication;
-import com.android.tools.r8.graph.DexReference;
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexType;
-import com.android.tools.r8.shaking.DiscardedChecker;
 import com.android.tools.r8.shaking.Enqueuer;
 import com.android.tools.r8.shaking.MainDexClasses;
 import com.android.tools.r8.shaking.MainDexListBuilder;
@@ -24,6 +24,7 @@
 import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.Timing;
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
@@ -73,16 +74,28 @@
         options.mainDexListConsumer.accept(String.join("\n", result), options.reporter);
       }
 
-      if (!mainDexRootSet.checkDiscarded.isEmpty()) {
-        new DiscardedChecker(mainDexRootSet, mainDexClasses.getClasses(), appView).run();
-      }
-      // Print -whyareyoukeeping results if any.
-      if (whyAreYouKeepingConsumer != null) {
-        for (DexReference reference : mainDexRootSet.reasonAsked) {
-          whyAreYouKeepingConsumer.printWhyAreYouKeeping(
-              enqueuer.getGraphNode(reference), System.out);
-        }
-      }
+      R8.processWhyAreYouKeepingAndCheckDiscarded(
+          mainDexRootSet,
+          () -> {
+            ArrayList<DexProgramClass> classes = new ArrayList<>();
+            // TODO(b/131668850): This is not a deterministic order!
+            mainDexClasses
+                .getClasses()
+                .forEach(
+                    type -> {
+                      DexClass clazz = appView.definitionFor(type);
+                      assert clazz.isProgramClass();
+                      classes.add(clazz.asProgramClass());
+                    });
+            return classes;
+          },
+          whyAreYouKeepingConsumer,
+          appView,
+          enqueuer,
+          true,
+          options,
+          timing,
+          executor);
 
       return result;
     } catch (ExecutionException e) {
diff --git a/src/main/java/com/android/tools/r8/R8.java b/src/main/java/com/android/tools/r8/R8.java
index 36c1cbf..036c874 100644
--- a/src/main/java/com/android/tools/r8/R8.java
+++ b/src/main/java/com/android/tools/r8/R8.java
@@ -18,6 +18,7 @@
 import com.android.tools.r8.graph.DexApplication;
 import com.android.tools.r8.graph.DexCallSite;
 import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexDefinition;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexReference;
@@ -94,6 +95,7 @@
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
+import java.util.function.Supplier;
 
 /**
  * The R8 compiler.
@@ -564,15 +566,30 @@
             enqueuer.traceMainDex(mainDexRootSet, executorService, timing);
         // Calculate the automatic main dex list according to legacy multidex constraints.
         mainDexClasses = new MainDexListBuilder(mainDexBaseClasses, application).run();
-        if (!mainDexRootSet.checkDiscarded.isEmpty()) {
-          new DiscardedChecker(mainDexRootSet, mainDexClasses.getClasses(), appView).run();
-        }
-        if (whyAreYouKeepingConsumer != null) {
-          for (DexReference reference : mainDexRootSet.reasonAsked) {
-            whyAreYouKeepingConsumer.printWhyAreYouKeeping(
-                enqueuer.getGraphNode(reference), System.out);
-          }
-        }
+        final MainDexClasses finalMainDexClasses = mainDexClasses;
+
+        processWhyAreYouKeepingAndCheckDiscarded(
+            mainDexRootSet,
+            () -> {
+              ArrayList<DexProgramClass> classes = new ArrayList<>();
+              // TODO(b/131668850): This is not a deterministic order!
+              finalMainDexClasses
+                  .getClasses()
+                  .forEach(
+                      type -> {
+                        DexClass clazz = appView.definitionFor(type);
+                        assert clazz.isProgramClass();
+                        classes.add(clazz.asProgramClass());
+                      });
+              return classes;
+            },
+            whyAreYouKeepingConsumer,
+            appView,
+            enqueuer,
+            true,
+            options,
+            timing,
+            executorService);
       }
 
       appView.setAppInfo(new AppInfoWithSubtyping(application));
@@ -610,13 +627,17 @@
                         application,
                         CollectionUtils.mergeSets(prunedTypes, pruner.getRemovedClasses())));
 
-            // Print reasons on the application after pruning, so that we reflect the actual result.
-            if (whyAreYouKeepingConsumer != null) {
-              for (DexReference reference : appView.rootSet().reasonAsked) {
-                whyAreYouKeepingConsumer.printWhyAreYouKeeping(
-                    enqueuer.getGraphNode(reference), System.out);
-              }
-            }
+            processWhyAreYouKeepingAndCheckDiscarded(
+                appView.rootSet(),
+                () -> appView.appInfo().app().classesWithDeterministicOrder(),
+                whyAreYouKeepingConsumer,
+                appView,
+                enqueuer,
+                false,
+                options,
+                timing,
+                executorService);
+
             // Remove annotations that refer to types that no longer exist.
             assert classesToRetainInnerClassAttributeFor != null;
             new AnnotationRemover(appView.withLiveness(), classesToRetainInnerClassAttributeFor)
@@ -636,11 +657,6 @@
         application = application.builder().addToMainDexList(mainDexClasses.getClasses()).build();
       }
 
-      // Only perform discard-checking if tree-shaking is turned on.
-      if (options.isShrinking() && !appView.rootSet().checkDiscarded.isEmpty()) {
-        new DiscardedChecker(appView.rootSet(), application, options).run();
-      }
-
       // Perform minification.
       NamingLens namingLens;
       if (options.getProguardConfiguration().hasApplyMappingFile()) {
@@ -740,6 +756,57 @@
     }
   }
 
+  static void processWhyAreYouKeepingAndCheckDiscarded(
+      RootSet rootSet,
+      Supplier<Iterable<DexProgramClass>> classes,
+      WhyAreYouKeepingConsumer whyAreYouKeepingConsumer,
+      AppView<? extends AppInfoWithSubtyping> appView,
+      Enqueuer enqueuer,
+      boolean forMainDex,
+      InternalOptions options,
+      Timing timing,
+      ExecutorService executorService)
+      throws ExecutionException {
+    if (whyAreYouKeepingConsumer != null) {
+      for (DexReference reference : rootSet.reasonAsked) {
+        whyAreYouKeepingConsumer.printWhyAreYouKeeping(
+            enqueuer.getGraphNode(reference), System.out);
+      }
+    }
+    if (rootSet.checkDiscarded.isEmpty()) {
+      return;
+    }
+    List<DexDefinition> failed = new DiscardedChecker(rootSet, classes.get()).run();
+    if (failed.isEmpty()) {
+      return;
+    }
+    // If there is no kept-graph info, re-run the enqueueing to compute it.
+    if (whyAreYouKeepingConsumer == null) {
+      whyAreYouKeepingConsumer = new WhyAreYouKeepingConsumer(null);
+      enqueuer = new Enqueuer(appView, options, whyAreYouKeepingConsumer);
+      if (forMainDex) {
+        enqueuer.traceMainDex(rootSet, executorService, timing);
+      } else {
+        enqueuer.traceApplication(
+            rootSet,
+            options.getProguardConfiguration().getDontWarnPatterns(),
+            executorService,
+            timing);
+      }
+    }
+    for (DexDefinition definition : failed) {
+      if (!failed.isEmpty()) {
+        ByteArrayOutputStream baos = new ByteArrayOutputStream();
+        whyAreYouKeepingConsumer.printWhyAreYouKeeping(
+            enqueuer.getGraphNode(definition.toReference()), new PrintStream(baos));
+        options.reporter.info(
+            new StringDiagnostic(
+                "Item " + definition.toSourceString() + " was not discarded.\n" + baos.toString()));
+      }
+    }
+    throw new CompilationError("Discard checks failed.");
+  }
+
   private void computeKotlinInfoForProgramClasses(DexApplication application, AppView<?> appView) {
     Kotlin kotlin = appView.dexItemFactory().kotlin;
     Reporter reporter = options.reporter;
diff --git a/src/main/java/com/android/tools/r8/shaking/DiscardedChecker.java b/src/main/java/com/android/tools/r8/shaking/DiscardedChecker.java
index abc1e3e..ca86412 100644
--- a/src/main/java/com/android/tools/r8/shaking/DiscardedChecker.java
+++ b/src/main/java/com/android/tools/r8/shaking/DiscardedChecker.java
@@ -3,62 +3,40 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.shaking;
 
-import com.android.tools.r8.errors.CompilationError;
-import com.android.tools.r8.graph.AppView;
-import com.android.tools.r8.graph.DexApplication;
-import com.android.tools.r8.graph.DexClass;
 import com.android.tools.r8.graph.DexDefinition;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexReference;
-import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.shaking.RootSetBuilder.RootSet;
-import com.android.tools.r8.utils.InternalOptions;
-import com.android.tools.r8.utils.StringDiagnostic;
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 
 public class DiscardedChecker {
 
   private final Set<DexReference> checkDiscarded;
-  private final List<DexProgramClass> classes;
-  private boolean fail = false;
-  private final InternalOptions options;
+  private final Iterable<DexProgramClass> classes;
 
-  public DiscardedChecker(RootSet rootSet, DexApplication application, InternalOptions options) {
-    this.checkDiscarded = rootSet.checkDiscarded;
-    this.classes = application.classes();
-    this.options = options;
+  public DiscardedChecker(RootSet rootSet, Iterable<DexProgramClass> classes) {
+    this.checkDiscarded = new HashSet<>(rootSet.checkDiscarded);
+    this.classes = classes;
   }
 
-  public DiscardedChecker(RootSet rootSet, Set<DexType> types, AppView<?> appView) {
-    this.checkDiscarded = rootSet.checkDiscarded;
-    this.classes = new ArrayList<>();
-    types.forEach(
-        type -> {
-          DexClass clazz = appView.definitionFor(type);
-          assert clazz.isProgramClass();
-          this.classes.add(clazz.asProgramClass());
-        });
-    this.options = appView.options();
-  }
-
-  public void run() {
+  public List<DexDefinition> run() {
+    List<DexDefinition> failed = new ArrayList<>(checkDiscarded.size());
+    // TODO(b/131668850): Lookup the definition based on the reference.
     for (DexProgramClass clazz : classes) {
-      checkItem(clazz);
-      clazz.forEachMethod(this::checkItem);
-      clazz.forEachField(this::checkItem);
+      checkItem(clazz, failed);
+      clazz.forEachMethod(method -> checkItem(method, failed));
+      clazz.forEachField(field -> checkItem(field, failed));
     }
-    if (fail) {
-      throw new CompilationError("Discard checks failed.");
-    }
+    return failed;
   }
 
-  private void checkItem(DexDefinition item) {
-    if (checkDiscarded.contains(item.toReference())) {
-      options.reporter.info(
-          new StringDiagnostic("Item " + item.toSourceString() + " was not discarded."));
-      fail = true;
+  private void checkItem(DexDefinition item, List<DexDefinition> failed) {
+    DexReference reference = item.toReference();
+    if (checkDiscarded.contains(reference)) {
+      failed.add(item);
     }
   }
 }
diff --git a/src/main/java/com/android/tools/r8/shaking/RootSetBuilder.java b/src/main/java/com/android/tools/r8/shaking/RootSetBuilder.java
index 7029734..4249b72 100644
--- a/src/main/java/com/android/tools/r8/shaking/RootSetBuilder.java
+++ b/src/main/java/com/android/tools/r8/shaking/RootSetBuilder.java
@@ -69,7 +69,7 @@
   private final LinkedHashMap<DexReference, DexReference> reasonAsked = new LinkedHashMap<>();
   private final Set<ProguardConfigurationRule> rulesThatUseExtendsOrImplementsWrong =
       Sets.newIdentityHashSet();
-  private final Set<DexReference> checkDiscarded = Sets.newIdentityHashSet();
+  private final LinkedHashMap<DexReference, DexReference> checkDiscarded = new LinkedHashMap<>();
   private final Set<DexMethod> alwaysInline = Sets.newIdentityHashSet();
   private final Set<DexMethod> forceInline = Sets.newIdentityHashSet();
   private final Set<DexMethod> neverInline = Sets.newIdentityHashSet();
@@ -264,7 +264,7 @@
         noOptimization,
         noObfuscation,
         ImmutableList.copyOf(reasonAsked.values()),
-        checkDiscarded,
+        ImmutableList.copyOf(checkDiscarded.values()),
         alwaysInline,
         forceInline,
         neverInline,
@@ -952,7 +952,7 @@
     } else if (context instanceof ProguardAssumeValuesRule) {
       assumedValues.put(item.toReference(), rule);
     } else if (context instanceof ProguardCheckDiscardRule) {
-      checkDiscarded.add(item.toReference());
+      checkDiscarded.computeIfAbsent(item.toReference(), i -> i);
     } else if (context instanceof InlineRule) {
       if (item.isDexEncodedMethod()) {
         switch (((InlineRule) context).getType()) {
@@ -1032,7 +1032,7 @@
     public final Set<DexReference> noOptimization;
     private final Set<DexReference> noObfuscation;
     public final ImmutableList<DexReference> reasonAsked;
-    public final Set<DexReference> checkDiscarded;
+    public final ImmutableList<DexReference> checkDiscarded;
     public final Set<DexMethod> alwaysInline;
     public final Set<DexMethod> forceInline;
     public final Set<DexMethod> neverInline;
@@ -1054,7 +1054,7 @@
         Set<DexReference> noOptimization,
         Set<DexReference> noObfuscation,
         ImmutableList<DexReference> reasonAsked,
-        Set<DexReference> checkDiscarded,
+        ImmutableList<DexReference> checkDiscarded,
         Set<DexMethod> alwaysInline,
         Set<DexMethod> forceInline,
         Set<DexMethod> neverInline,
@@ -1073,7 +1073,7 @@
       this.noOptimization = noOptimization;
       this.noObfuscation = noObfuscation;
       this.reasonAsked = reasonAsked;
-      this.checkDiscarded = Collections.unmodifiableSet(checkDiscarded);
+      this.checkDiscarded = checkDiscarded;
       this.alwaysInline = Collections.unmodifiableSet(alwaysInline);
       this.forceInline = Collections.unmodifiableSet(forceInline);
       this.neverInline = neverInline;
diff --git a/src/test/java/com/android/tools/r8/checkdiscarded/CheckDiscardedTest.java b/src/test/java/com/android/tools/r8/checkdiscarded/CheckDiscardedTest.java
index 6593110..b0c7ecf 100644
--- a/src/test/java/com/android/tools/r8/checkdiscarded/CheckDiscardedTest.java
+++ b/src/test/java/com/android/tools/r8/checkdiscarded/CheckDiscardedTest.java
@@ -3,33 +3,67 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.checkdiscarded;
 
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.core.StringContains.containsString;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
 import com.android.tools.r8.CompilationFailedException;
+import com.android.tools.r8.Diagnostic;
+import com.android.tools.r8.R8FullTestBuilder;
+import com.android.tools.r8.R8TestCompileResult;
 import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestDiagnosticMessages;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.checkdiscarded.testclasses.Main;
 import com.android.tools.r8.checkdiscarded.testclasses.UnusedClass;
 import com.android.tools.r8.checkdiscarded.testclasses.UsedClass;
 import com.android.tools.r8.checkdiscarded.testclasses.WillBeGone;
 import com.android.tools.r8.checkdiscarded.testclasses.WillStay;
 import com.android.tools.r8.utils.InternalOptions;
-import com.google.common.collect.ImmutableList;
 import java.util.List;
-import org.junit.Assert;
+import java.util.function.Consumer;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
 
+@RunWith(Parameterized.class)
 public class CheckDiscardedTest extends TestBase {
 
-  private void run(boolean obfuscate, Class annotation, boolean checkMembers, boolean shouldFail)
-      throws Exception {
-    List<Class<?>> classes = ImmutableList.of(UnusedClass.class, UsedClass.class, Main.class);
-    String proguardConfig = keepMainProguardConfiguration(Main.class, true, obfuscate)
-        + checkDiscardRule(checkMembers, annotation);
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withNoneRuntime().build();
+  }
+
+  private final TestParameters parameters;
+
+  public CheckDiscardedTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  private void compile(
+      boolean obfuscate,
+      Class annotation,
+      boolean checkMembers,
+      Consumer<TestDiagnosticMessages> onCompilationFailure) {
+    R8FullTestBuilder builder = testForR8(Backend.DEX);
+    TestDiagnosticMessages diagnostics = builder.getState().getDiagnosticsMessages();
     try {
-      compileWithR8(classes, proguardConfig, this::noInlining);
+      R8TestCompileResult result =
+          builder
+              .addProgramClasses(UnusedClass.class, UsedClass.class, Main.class)
+              .addKeepMainRule(Main.class)
+              .addKeepRules(checkDiscardRule(checkMembers, annotation))
+              .minification(obfuscate)
+              .addOptionsModification(this::noInlining)
+              .compile();
+      assertNull(onCompilationFailure);
+      result.assertNoMessages();
     } catch (CompilationFailedException e) {
-      Assert.assertTrue(shouldFail);
-      return;
+      onCompilationFailure.accept(diagnostics);
     }
-    Assert.assertFalse(shouldFail);
   }
 
   private void noInlining(InternalOptions options) {
@@ -45,27 +79,46 @@
   }
 
   @Test
-  public void classesAreGone() throws Exception {
-    run(false, WillBeGone.class, false, false);
-    run(true, WillBeGone.class, false, false);
+  public void classesAreGone() {
+    compile(false, WillBeGone.class, false, null);
+    compile(true, WillBeGone.class, false, null);
   }
 
   @Test
-  public void classesAreNotGone() throws Exception {
-    run(false, WillStay.class, false, true);
-    run(true, WillStay.class, false, true);
+  public void classesAreNotGone() {
+    Consumer<TestDiagnosticMessages> check =
+        diagnostics -> {
+          List<Diagnostic> infos = diagnostics.getInfos();
+          assertEquals(2, infos.size());
+          String messageUsedClass = infos.get(1).getDiagnosticMessage();
+          assertThat(messageUsedClass, containsString("UsedClass was not discarded"));
+          assertThat(messageUsedClass, containsString("is instantiated in"));
+          String messageMain = infos.get(0).getDiagnosticMessage();
+          assertThat(messageMain, containsString("Main was not discarded"));
+          assertThat(messageMain, containsString("is referenced in keep rule"));
+        };
+    compile(false, WillStay.class, false, check);
+    compile(true, WillStay.class, false, check);
   }
 
   @Test
-  public void membersAreGone() throws Exception {
-    run(false, WillBeGone.class, true, false);
-    run(true, WillBeGone.class, true, false);
+  public void membersAreGone() {
+    compile(false, WillBeGone.class, true, null);
+    compile(true, WillBeGone.class, true, null);
   }
 
   @Test
-  public void membersAreNotGone() throws Exception {
-    run(false, WillStay.class, true, true);
-    run(true, WillStay.class, true, true);
+  public void membersAreNotGone() {
+    Consumer<TestDiagnosticMessages> check =
+        diagnostics -> {
+          List<Diagnostic> infos = diagnostics.getInfos();
+          assertEquals(1, infos.size());
+          String message = infos.get(0).getDiagnosticMessage();
+          assertThat(message, containsString("was not discarded"));
+          assertThat(message, containsString("is invoked from"));
+        };
+    compile(false, WillStay.class, true, check);
+    compile(true, WillStay.class, true, check);
   }
 
 }
