[KeepAnno] Check for checkdiscard support in rule extraction

Bug: b/321674067
Change-Id: I6c0a5f1509dc395798c713d62a74855082baa9e5
diff --git a/src/keepanno/java/com/android/tools/r8/keepanno/keeprules/KeepRuleExtractor.java b/src/keepanno/java/com/android/tools/r8/keepanno/keeprules/KeepRuleExtractor.java
index 1c554fc..815f6ba 100644
--- a/src/keepanno/java/com/android/tools/r8/keepanno/keeprules/KeepRuleExtractor.java
+++ b/src/keepanno/java/com/android/tools/r8/keepanno/keeprules/KeepRuleExtractor.java
@@ -72,14 +72,17 @@
     ruleConsumer.accept(builder.toString());
   }
 
-  private static List<PgRule> split(KeepDeclaration declaration) {
+  private List<PgRule> split(KeepDeclaration declaration) {
     if (declaration.isKeepCheck()) {
       return generateCheckRules(declaration.asKeepCheck());
     }
     return doSplit(KeepEdgeNormalizer.normalize(declaration.asKeepEdge()));
   }
 
-  private static List<PgRule> generateCheckRules(KeepCheck check) {
+  private List<PgRule> generateCheckRules(KeepCheck check) {
+    if (!options.hasCheckDiscardSupport()) {
+      return Collections.emptyList();
+    }
     KeepItemPattern itemPattern = check.getItemPattern();
     boolean isRemovedPattern = check.getKind() == KeepCheckKind.REMOVED;
     List<PgRule> rules = new ArrayList<>(isRemovedPattern ? 2 : 1);
diff --git a/src/keepanno/java/com/android/tools/r8/keepanno/keeprules/KeepRuleExtractorOptions.java b/src/keepanno/java/com/android/tools/r8/keepanno/keeprules/KeepRuleExtractorOptions.java
index e428358..c30b430 100644
--- a/src/keepanno/java/com/android/tools/r8/keepanno/keeprules/KeepRuleExtractorOptions.java
+++ b/src/keepanno/java/com/android/tools/r8/keepanno/keeprules/KeepRuleExtractorOptions.java
@@ -8,8 +8,10 @@
 
 public class KeepRuleExtractorOptions {
 
-  private static final KeepRuleExtractorOptions PG_OPTIONS = new KeepRuleExtractorOptions(false);
-  private static final KeepRuleExtractorOptions R8_OPTIONS = new KeepRuleExtractorOptions(true);
+  private static final KeepRuleExtractorOptions PG_OPTIONS =
+      new KeepRuleExtractorOptions(false, false);
+  private static final KeepRuleExtractorOptions R8_OPTIONS =
+      new KeepRuleExtractorOptions(true, true);
 
   public static KeepRuleExtractorOptions getPgOptions() {
     return PG_OPTIONS;
@@ -19,13 +21,20 @@
     return R8_OPTIONS;
   }
 
+  private final boolean allowCheckDiscard;
   private final boolean allowAccessModificationOption;
   private final boolean allowAnnotationRemovalOption = false;
 
-  private KeepRuleExtractorOptions(boolean allowAccessModificationOption) {
+  private KeepRuleExtractorOptions(
+      boolean allowCheckDiscard, boolean allowAccessModificationOption) {
+    this.allowCheckDiscard = allowCheckDiscard;
     this.allowAccessModificationOption = allowAccessModificationOption;
   }
 
+  public boolean hasCheckDiscardSupport() {
+    return allowCheckDiscard;
+  }
+
   private boolean hasAllowAccessModificationOptionSupport() {
     return allowAccessModificationOption;
   }
diff --git a/src/test/java/com/android/tools/r8/keepanno/AnnotationPatternAnyRetentionTest.java b/src/test/java/com/android/tools/r8/keepanno/AnnotationPatternAnyRetentionTest.java
index 4b11ba6..31de01d 100644
--- a/src/test/java/com/android/tools/r8/keepanno/AnnotationPatternAnyRetentionTest.java
+++ b/src/test/java/com/android/tools/r8/keepanno/AnnotationPatternAnyRetentionTest.java
@@ -8,9 +8,6 @@
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresentAndRenamed;
 import static org.hamcrest.MatcherAssert.assertThat;
 
-import com.android.tools.r8.TestBase;
-import com.android.tools.r8.TestParameters;
-import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.keepanno.annotations.AnnotationPattern;
 import com.android.tools.r8.keepanno.annotations.KeepItemKind;
 import com.android.tools.r8.keepanno.annotations.KeepTarget;
@@ -29,40 +26,29 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
 
 @RunWith(Parameterized.class)
-public class AnnotationPatternAnyRetentionTest extends TestBase {
+public class AnnotationPatternAnyRetentionTest extends KeepAnnoTestBase {
 
   static final String EXPECTED = StringUtils.lines("C1: A1");
 
-  private final TestParameters parameters;
+  @Parameter public KeepAnnoParameters parameters;
 
   @Parameterized.Parameters(name = "{0}")
-  public static TestParametersCollection data() {
-    return getTestParameters().withDefaultRuntimes().withApiLevel(AndroidApiLevel.B).build();
-  }
-
-  public AnnotationPatternAnyRetentionTest(TestParameters parameters) {
-    this.parameters = parameters;
+  public static List<KeepAnnoParameters> data() {
+    return createParameters(
+        getTestParameters().withDefaultRuntimes().withApiLevel(AndroidApiLevel.B).build());
   }
 
   @Test
-  public void testReference() throws Exception {
-    testForRuntime(parameters)
+  public void test() throws Exception {
+    testForKeepAnno(parameters)
         .addProgramClasses(getInputClasses())
-        .run(parameters.getRuntime(), TestClass.class)
-        .assertSuccessWithOutput(EXPECTED);
-  }
-
-  @Test
-  public void testR8() throws Exception {
-    testForR8(parameters.getBackend())
-        .enableExperimentalKeepAnnotations()
-        .addProgramClasses(getInputClasses())
-        .setMinApi(parameters)
-        .run(parameters.getRuntime(), TestClass.class)
+        .setExcludedOuterClass(getClass())
+        .run(TestClass.class)
         .assertSuccessWithOutput(EXPECTED)
-        .inspect(this::checkOutput);
+        .applyIf(parameters.isShrinker(), r -> r.inspect(this::checkOutput));
   }
 
   public List<Class<?>> getInputClasses() {
diff --git a/src/test/java/com/android/tools/r8/keepanno/AnnotationPatternClassRetentionTest.java b/src/test/java/com/android/tools/r8/keepanno/AnnotationPatternClassRetentionTest.java
index efc77cf..df1bee8 100644
--- a/src/test/java/com/android/tools/r8/keepanno/AnnotationPatternClassRetentionTest.java
+++ b/src/test/java/com/android/tools/r8/keepanno/AnnotationPatternClassRetentionTest.java
@@ -9,9 +9,6 @@
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresentAndRenamed;
 import static org.hamcrest.MatcherAssert.assertThat;
 
-import com.android.tools.r8.TestBase;
-import com.android.tools.r8.TestParameters;
-import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.keepanno.annotations.AnnotationPattern;
 import com.android.tools.r8.keepanno.annotations.KeepItemKind;
 import com.android.tools.r8.keepanno.annotations.KeepTarget;
@@ -30,40 +27,29 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
 
 @RunWith(Parameterized.class)
-public class AnnotationPatternClassRetentionTest extends TestBase {
+public class AnnotationPatternClassRetentionTest extends KeepAnnoTestBase {
 
   static final String EXPECTED = StringUtils.lines("C1:");
 
-  private final TestParameters parameters;
+  @Parameter public KeepAnnoParameters parameters;
 
   @Parameterized.Parameters(name = "{0}")
-  public static TestParametersCollection data() {
-    return getTestParameters().withDefaultRuntimes().withApiLevel(AndroidApiLevel.B).build();
-  }
-
-  public AnnotationPatternClassRetentionTest(TestParameters parameters) {
-    this.parameters = parameters;
+  public static List<KeepAnnoParameters> data() {
+    return createParameters(
+        getTestParameters().withDefaultRuntimes().withApiLevel(AndroidApiLevel.B).build());
   }
 
   @Test
-  public void testReference() throws Exception {
-    testForRuntime(parameters)
+  public void test() throws Exception {
+    testForKeepAnno(parameters)
         .addProgramClasses(getInputClasses())
-        .run(parameters.getRuntime(), TestClass.class)
-        .assertSuccessWithOutput(EXPECTED);
-  }
-
-  @Test
-  public void testR8() throws Exception {
-    testForR8(parameters.getBackend())
-        .enableExperimentalKeepAnnotations()
-        .addProgramClasses(getInputClasses())
-        .setMinApi(parameters)
-        .run(parameters.getRuntime(), TestClass.class)
+        .setExcludedOuterClass(getClass())
+        .run(TestClass.class)
         .assertSuccessWithOutput(EXPECTED)
-        .inspect(this::checkOutput);
+        .applyIf(parameters.isShrinker(), r -> r.inspect(this::checkOutput));
   }
 
   public List<Class<?>> getInputClasses() {
diff --git a/src/test/java/com/android/tools/r8/keepanno/AnnotationPatternMultipleTest.java b/src/test/java/com/android/tools/r8/keepanno/AnnotationPatternMultipleTest.java
index 33a4f6d..feafc7f 100644
--- a/src/test/java/com/android/tools/r8/keepanno/AnnotationPatternMultipleTest.java
+++ b/src/test/java/com/android/tools/r8/keepanno/AnnotationPatternMultipleTest.java
@@ -9,9 +9,6 @@
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresentAndRenamed;
 import static org.hamcrest.MatcherAssert.assertThat;
 
-import com.android.tools.r8.TestBase;
-import com.android.tools.r8.TestParameters;
-import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.keepanno.annotations.AnnotationPattern;
 import com.android.tools.r8.keepanno.annotations.ClassNamePattern;
 import com.android.tools.r8.keepanno.annotations.KeepItemKind;
@@ -31,40 +28,29 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
 
 @RunWith(Parameterized.class)
-public class AnnotationPatternMultipleTest extends TestBase {
+public class AnnotationPatternMultipleTest extends KeepAnnoTestBase {
 
   static final String EXPECTED = StringUtils.lines("C1: A1", "C2:");
 
-  private final TestParameters parameters;
+  @Parameter public KeepAnnoParameters parameters;
 
   @Parameterized.Parameters(name = "{0}")
-  public static TestParametersCollection data() {
-    return getTestParameters().withDefaultRuntimes().withApiLevel(AndroidApiLevel.B).build();
-  }
-
-  public AnnotationPatternMultipleTest(TestParameters parameters) {
-    this.parameters = parameters;
+  public static List<KeepAnnoParameters> data() {
+    return createParameters(
+        getTestParameters().withDefaultRuntimes().withApiLevel(AndroidApiLevel.B).build());
   }
 
   @Test
-  public void testReference() throws Exception {
-    testForRuntime(parameters)
+  public void test() throws Exception {
+    testForKeepAnno(parameters)
         .addProgramClasses(getInputClasses())
-        .run(parameters.getRuntime(), TestClass.class)
-        .assertSuccessWithOutput(EXPECTED);
-  }
-
-  @Test
-  public void testR8() throws Exception {
-    testForR8(parameters.getBackend())
-        .enableExperimentalKeepAnnotations()
-        .addProgramClasses(getInputClasses())
-        .setMinApi(parameters)
-        .run(parameters.getRuntime(), TestClass.class)
+        .setExcludedOuterClass(getClass())
+        .run(TestClass.class)
         .assertSuccessWithOutput(EXPECTED)
-        .inspect(this::checkOutput);
+        .applyIf(parameters.isShrinker(), r -> r.inspect(this::checkOutput));
   }
 
   public List<Class<?>> getInputClasses() {
diff --git a/src/test/java/com/android/tools/r8/keepanno/ArrayPatternsTest.java b/src/test/java/com/android/tools/r8/keepanno/ArrayPatternsTest.java
index a7cf54f..2f1f572 100644
--- a/src/test/java/com/android/tools/r8/keepanno/ArrayPatternsTest.java
+++ b/src/test/java/com/android/tools/r8/keepanno/ArrayPatternsTest.java
@@ -8,9 +8,6 @@
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresentAndRenamed;
 import static org.hamcrest.MatcherAssert.assertThat;
 
-import com.android.tools.r8.TestBase;
-import com.android.tools.r8.TestParameters;
-import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.keepanno.annotations.KeepItemKind;
 import com.android.tools.r8.keepanno.annotations.KeepTarget;
 import com.android.tools.r8.keepanno.annotations.TypePattern;
@@ -25,41 +22,30 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
 
 @RunWith(Parameterized.class)
-public class ArrayPatternsTest extends TestBase {
+public class ArrayPatternsTest extends KeepAnnoTestBase {
 
   static final String EXPECTED =
       StringUtils.lines("int[] [1, 2, 3]", "int[][] [[42]]", "Integer[][][] [[[333]]]");
 
-  private final TestParameters parameters;
+  @Parameter public KeepAnnoParameters parameters;
 
   @Parameterized.Parameters(name = "{0}")
-  public static TestParametersCollection data() {
-    return getTestParameters().withDefaultRuntimes().withApiLevel(AndroidApiLevel.B).build();
-  }
-
-  public ArrayPatternsTest(TestParameters parameters) {
-    this.parameters = parameters;
+  public static List<KeepAnnoParameters> data() {
+    return createParameters(
+        getTestParameters().withDefaultRuntimes().withApiLevel(AndroidApiLevel.B).build());
   }
 
   @Test
-  public void testReference() throws Exception {
-    testForRuntime(parameters)
+  public void test() throws Exception {
+    testForKeepAnno(parameters)
         .addProgramClasses(getInputClasses())
-        .run(parameters.getRuntime(), TestClass.class)
-        .assertSuccessWithOutput(EXPECTED);
-  }
-
-  @Test
-  public void testR8() throws Exception {
-    testForR8(parameters.getBackend())
-        .enableExperimentalKeepAnnotations()
-        .addProgramClasses(getInputClasses())
-        .setMinApi(parameters)
-        .run(parameters.getRuntime(), TestClass.class)
+        .setExcludedOuterClass(getClass())
+        .run(TestClass.class)
         .assertSuccessWithOutput(EXPECTED)
-        .inspect(this::checkOutput);
+        .applyIf(parameters.isShrinker(), r -> r.inspect(this::checkOutput));
   }
 
   public List<Class<?>> getInputClasses() {
diff --git a/src/test/java/com/android/tools/r8/keepanno/CheckOptimizedOutAnnotationTest.java b/src/test/java/com/android/tools/r8/keepanno/CheckOptimizedOutAnnotationTest.java
index adc446c..ab4125f 100644
--- a/src/test/java/com/android/tools/r8/keepanno/CheckOptimizedOutAnnotationTest.java
+++ b/src/test/java/com/android/tools/r8/keepanno/CheckOptimizedOutAnnotationTest.java
@@ -3,16 +3,19 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.keepanno;
 
+import static com.android.tools.r8.DiagnosticsMatcher.diagnosticMessage;
 import static com.android.tools.r8.utils.codeinspector.Matchers.isAbsent;
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.CoreMatchers.allOf;
+import static org.hamcrest.CoreMatchers.containsString;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+import static org.junit.Assume.assumeFalse;
+import static org.junit.Assume.assumeTrue;
 
 import com.android.tools.r8.DiagnosticsLevel;
 import com.android.tools.r8.DiagnosticsMatcher;
-import com.android.tools.r8.TestBase;
-import com.android.tools.r8.TestParameters;
-import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.errors.CheckDiscardDiagnostic;
 import com.android.tools.r8.keepanno.annotations.CheckOptimizedOut;
 import com.android.tools.r8.utils.AndroidApiLevel;
@@ -23,56 +26,82 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
 
 @RunWith(Parameterized.class)
-public class CheckOptimizedOutAnnotationTest extends TestBase {
+public class CheckOptimizedOutAnnotationTest extends KeepAnnoTestBase {
 
   static final String EXPECTED = StringUtils.lines("A", "B.baz");
 
-  private final TestParameters parameters;
+  @Parameter public KeepAnnoParameters parameters;
 
   @Parameterized.Parameters(name = "{0}")
-  public static TestParametersCollection data() {
-    return getTestParameters().withDefaultRuntimes().withApiLevel(AndroidApiLevel.B).build();
-  }
-
-  public CheckOptimizedOutAnnotationTest(TestParameters parameters) {
-    this.parameters = parameters;
+  public static List<KeepAnnoParameters> data() {
+    return createParameters(
+        getTestParameters().withDefaultRuntimes().withApiLevel(AndroidApiLevel.B).build());
   }
 
   @Test
-  public void testReference() throws Exception {
-    testForRuntime(parameters)
-        .addProgramClasses(getInputClasses())
-        .run(parameters.getRuntime(), TestClass.class)
-        .assertSuccessWithOutput(EXPECTED);
-  }
-
-  @Test
-  public void testWithRuleExtraction() throws Exception {
-    testForR8(parameters.getBackend())
-        .enableExperimentalKeepAnnotations()
+  public void test() throws Exception {
+    assumeFalse(parameters.isR8());
+    testForKeepAnno(parameters)
         .addProgramClasses(getInputClasses())
         .addKeepMainRule(TestClass.class)
-        .setMinApi(parameters)
-        .allowDiagnosticWarningMessages()
-        .setDiagnosticsLevelModifier(
-            (level, diagnostic) ->
-                level == DiagnosticsLevel.ERROR ? DiagnosticsLevel.WARNING : level)
-        .compileWithExpectedDiagnostics(
-            diagnostics -> {
-              diagnostics
-                  .assertOnlyWarnings()
-                  .assertWarningsMatch(
-                      DiagnosticsMatcher.diagnosticType(CheckDiscardDiagnostic.class));
-              CheckDiscardDiagnostic discard =
-                  (CheckDiscardDiagnostic) diagnostics.getWarnings().get(0);
-              // The discard error should report one error for A.toString.
-              assertEquals(discard.getDiagnosticMessage(), 1, discard.getNumberOfFailures());
-            })
-        .run(parameters.getRuntime(), TestClass.class)
+        .setExcludedOuterClass(getClass())
+        .run(TestClass.class)
         .assertSuccessWithOutput(EXPECTED)
-        .inspect(this::checkOutput);
+        .applyIf(parameters.isShrinker(), r -> r.inspect(this::checkOutput));
+  }
+
+  @Test
+  public void testR8Native() throws Throwable {
+    assumeTrue(parameters.isR8() && parameters.isNative());
+    testForKeepAnno(parameters)
+        .addProgramClasses(getInputClasses())
+        .addKeepMainRule(TestClass.class)
+        .applyIfR8Native(
+            b ->
+                b.allowDiagnosticWarningMessages()
+                    .setDiagnosticsLevelModifier(
+                        (level, diagnostic) ->
+                            level == DiagnosticsLevel.ERROR ? DiagnosticsLevel.WARNING : level)
+                    .compileWithExpectedDiagnostics(
+                        diagnostics -> {
+                          diagnostics
+                              .assertOnlyWarnings()
+                              .assertWarningsMatch(
+                                  DiagnosticsMatcher.diagnosticType(CheckDiscardDiagnostic.class));
+                          CheckDiscardDiagnostic discard =
+                              (CheckDiscardDiagnostic) diagnostics.getWarnings().get(0);
+                          // The discard error should report one error for A.toString.
+                          assertEquals(
+                              discard.getDiagnosticMessage(), 1, discard.getNumberOfFailures());
+                          assertThat(
+                              discard,
+                              diagnosticMessage(containsString("A.toString() was not discarded")));
+                        })
+                    .run(parameters.parameters().getRuntime(), TestClass.class)
+                    .assertSuccessWithOutput(EXPECTED)
+                    .inspect(this::checkOutput));
+  }
+
+  @Test
+  public void testR8Extract() throws Throwable {
+    assumeTrue(parameters.isR8() && !parameters.isNative());
+    try {
+      testForKeepAnno(parameters)
+          .addProgramClasses(getInputClasses())
+          .addKeepMainRule(TestClass.class)
+          .run(TestClass.class);
+    } catch (AssertionError e) {
+      assertThat(
+          e.getMessage(),
+          allOf(
+              containsString("Discard checks failed"),
+              containsString("A.toString() was not discarded")));
+      return;
+    }
+    fail("Expected compile failure");
   }
 
   public List<Class<?>> getInputClasses() {
@@ -88,8 +117,8 @@
     assertThat(inspector.clazz(A.class).uniqueMethodWithOriginalName("foo"), isAbsent());
     assertThat(inspector.clazz(A.class).uniqueMethodWithOriginalName("bar"), isAbsent());
 
-    // B is fully inlined and not in the residual program.
-    assertThat(inspector.clazz(B.class), isAbsent());
+    // B is fully inlined and not in the residual program (in R8).
+    assertThat(inspector.clazz(B.class), parameters.isPG() ? isPresent() : isAbsent());
   }
 
   static class A {
diff --git a/src/test/java/com/android/tools/r8/keepanno/CheckRemovedAnnotationTest.java b/src/test/java/com/android/tools/r8/keepanno/CheckRemovedAnnotationTest.java
index 0ed35c1..ca6b9e8 100644
--- a/src/test/java/com/android/tools/r8/keepanno/CheckRemovedAnnotationTest.java
+++ b/src/test/java/com/android/tools/r8/keepanno/CheckRemovedAnnotationTest.java
@@ -3,16 +3,19 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.keepanno;
 
+import static com.android.tools.r8.DiagnosticsMatcher.diagnosticMessage;
 import static com.android.tools.r8.utils.codeinspector.Matchers.isAbsent;
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.CoreMatchers.allOf;
+import static org.hamcrest.CoreMatchers.containsString;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+import static org.junit.Assume.assumeFalse;
+import static org.junit.Assume.assumeTrue;
 
 import com.android.tools.r8.DiagnosticsLevel;
 import com.android.tools.r8.DiagnosticsMatcher;
-import com.android.tools.r8.TestBase;
-import com.android.tools.r8.TestParameters;
-import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.errors.CheckDiscardDiagnostic;
 import com.android.tools.r8.keepanno.annotations.CheckRemoved;
 import com.android.tools.r8.utils.AndroidApiLevel;
@@ -23,56 +26,86 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
 
 @RunWith(Parameterized.class)
-public class CheckRemovedAnnotationTest extends TestBase {
+public class CheckRemovedAnnotationTest extends KeepAnnoTestBase {
 
   static final String EXPECTED = StringUtils.lines("A.foo", "B.baz");
 
-  private final TestParameters parameters;
+  @Parameter public KeepAnnoParameters parameters;
 
   @Parameterized.Parameters(name = "{0}")
-  public static TestParametersCollection data() {
-    return getTestParameters().withDefaultRuntimes().withApiLevel(AndroidApiLevel.B).build();
-  }
-
-  public CheckRemovedAnnotationTest(TestParameters parameters) {
-    this.parameters = parameters;
+  public static List<KeepAnnoParameters> data() {
+    return createParameters(
+        getTestParameters().withDefaultRuntimes().withApiLevel(AndroidApiLevel.B).build());
   }
 
   @Test
-  public void testReference() throws Exception {
-    testForRuntime(parameters)
+  public void test() throws Exception {
+    assumeFalse(parameters.isR8());
+    testForKeepAnno(parameters)
         .addProgramClasses(getInputClasses())
-        .run(parameters.getRuntime(), TestClass.class)
+        .addKeepMainRule(TestClass.class)
+        .setExcludedOuterClass(getClass())
+        .run(TestClass.class)
         .assertSuccessWithOutput(EXPECTED);
   }
 
   @Test
-  public void testWithRuleExtraction() throws Exception {
-    testForR8(parameters.getBackend())
-        .enableExperimentalKeepAnnotations()
+  public void testR8Native() throws Exception {
+    assumeTrue(parameters.isR8() && parameters.isNative());
+    testForKeepAnno(parameters)
         .addProgramClasses(getInputClasses())
         .addKeepMainRule(TestClass.class)
-        .setMinApi(parameters)
-        .allowDiagnosticWarningMessages()
-        .setDiagnosticsLevelModifier(
-            (level, diagnostic) ->
-                level == DiagnosticsLevel.ERROR ? DiagnosticsLevel.WARNING : level)
-        .compileWithExpectedDiagnostics(
-            diagnostics -> {
-              diagnostics
-                  .assertOnlyWarnings()
-                  .assertWarningsMatch(
-                      DiagnosticsMatcher.diagnosticType(CheckDiscardDiagnostic.class));
-              CheckDiscardDiagnostic discard =
-                  (CheckDiscardDiagnostic) diagnostics.getWarnings().get(0);
-              // The discard error should report for both the method A.foo and the class B.
-              assertEquals(discard.getDiagnosticMessage(), 2, discard.getNumberOfFailures());
-            })
-        .run(parameters.getRuntime(), TestClass.class)
-        .assertSuccessWithOutput(EXPECTED)
-        .inspect(this::checkOutput);
+        .applyIfR8Native(
+            b ->
+                b.allowDiagnosticWarningMessages()
+                    .setDiagnosticsLevelModifier(
+                        (level, diagnostic) ->
+                            level == DiagnosticsLevel.ERROR ? DiagnosticsLevel.WARNING : level)
+                    .compileWithExpectedDiagnostics(
+                        diagnostics -> {
+                          diagnostics
+                              .assertOnlyWarnings()
+                              .assertWarningsMatch(
+                                  DiagnosticsMatcher.diagnosticType(CheckDiscardDiagnostic.class));
+                          CheckDiscardDiagnostic discard =
+                              (CheckDiscardDiagnostic) diagnostics.getWarnings().get(0);
+                          // The discard error should report for both the method A.foo and the class
+                          // B.
+                          assertEquals(
+                              discard.getDiagnosticMessage(), 2, discard.getNumberOfFailures());
+                          assertThat(
+                              discard,
+                              diagnosticMessage(
+                                  allOf(
+                                      containsString("A.foo() was not discarded"),
+                                      containsString("B was not discarded"))));
+                        })
+                    .run(parameters.parameters().getRuntime(), TestClass.class)
+                    .assertSuccessWithOutput(EXPECTED)
+                    .inspect(this::checkOutput));
+  }
+
+  @Test
+  public void testR8Extract() throws Throwable {
+    assumeTrue(parameters.isR8() && !parameters.isNative());
+    try {
+      testForKeepAnno(parameters)
+          .addProgramClasses(getInputClasses())
+          .addKeepMainRule(TestClass.class)
+          .run(TestClass.class);
+    } catch (AssertionError e) {
+      assertThat(
+          e.getMessage(),
+          allOf(
+              containsString("Discard checks failed"),
+              containsString("A.foo() was not discarded"),
+              containsString("B was not discarded")));
+      return;
+    }
+    fail("Expected compile failure");
   }
 
   public List<Class<?>> getInputClasses() {
diff --git a/src/test/java/com/android/tools/r8/keepanno/KeepAnnoTestBuilder.java b/src/test/java/com/android/tools/r8/keepanno/KeepAnnoTestBuilder.java
index cd475f0..737c95e 100644
--- a/src/test/java/com/android/tools/r8/keepanno/KeepAnnoTestBuilder.java
+++ b/src/test/java/com/android/tools/r8/keepanno/KeepAnnoTestBuilder.java
@@ -13,7 +13,7 @@
 import com.android.tools.r8.TestBuilder;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestShrinkerBuilder;
-import com.android.tools.r8.examples.sync.Sync.Consumer;
+import com.android.tools.r8.ThrowableConsumer;
 import com.android.tools.r8.keepanno.keeprules.KeepRuleExtractorOptions;
 import java.io.IOException;
 import java.util.List;
@@ -55,15 +55,15 @@
   public abstract SingleTestRunResult<?> run(Class<?> mainClass) throws Exception;
 
   public KeepAnnoTestBuilder applyIfR8(
-      Consumer<TestShrinkerBuilder<?, ?, ?, ?, ?>> builderConsumer) {
+      ThrowableConsumer<TestShrinkerBuilder<?, ?, ?, ?, ?>> builderConsumer) {
     return this;
   }
 
-  public KeepAnnoTestBuilder applyIfR8Native(Consumer<R8TestBuilder<?>> builderConsumer) {
+  public KeepAnnoTestBuilder applyIfR8Native(ThrowableConsumer<R8TestBuilder<?>> builderConsumer) {
     return this;
   }
 
-  public KeepAnnoTestBuilder applyIfPG(Consumer<ProguardTestBuilder> builderConsumer) {
+  public KeepAnnoTestBuilder applyIfPG(ThrowableConsumer<ProguardTestBuilder> builderConsumer) {
     return this;
   }
 
@@ -121,14 +121,15 @@
 
     @Override
     public KeepAnnoTestBuilder applyIfR8(
-        Consumer<TestShrinkerBuilder<?, ?, ?, ?, ?>> builderConsumer) {
-      builderConsumer.accept(builder);
+        ThrowableConsumer<TestShrinkerBuilder<?, ?, ?, ?, ?>> builderConsumer) {
+      builderConsumer.acceptWithRuntimeException(builder);
       return this;
     }
 
     @Override
-    public KeepAnnoTestBuilder applyIfR8Native(Consumer<R8TestBuilder<?>> builderConsumer) {
-      builderConsumer.accept(builder);
+    public KeepAnnoTestBuilder applyIfR8Native(
+        ThrowableConsumer<R8TestBuilder<?>> builderConsumer) {
+      builderConsumer.acceptWithRuntimeException(builder);
       return this;
     }
 
@@ -167,8 +168,8 @@
 
     @Override
     public KeepAnnoTestBuilder applyIfR8(
-        Consumer<TestShrinkerBuilder<?, ?, ?, ?, ?>> builderConsumer) {
-      builderConsumer.accept(builder);
+        ThrowableConsumer<TestShrinkerBuilder<?, ?, ?, ?, ?>> builderConsumer) {
+      builderConsumer.acceptWithRuntimeException(builder);
       return this;
     }
 
@@ -207,8 +208,8 @@
     }
 
     @Override
-    public KeepAnnoTestBuilder applyIfPG(Consumer<ProguardTestBuilder> builderConsumer) {
-      builderConsumer.accept(builder);
+    public KeepAnnoTestBuilder applyIfPG(ThrowableConsumer<ProguardTestBuilder> builderConsumer) {
+      builderConsumer.acceptWithRuntimeException(builder);
       return this;
     }