[KeepAnno] Add native interpretation of check removed/optimized-out

Bug: b/323816623
Change-Id: I4df8b0c1eb2c70a65f791703c75d60453da5d96f
diff --git a/src/keepanno/java/com/android/tools/r8/keepanno/ast/KeepCheck.java b/src/keepanno/java/com/android/tools/r8/keepanno/ast/KeepCheck.java
index 2ea942e..1b2e42d 100644
--- a/src/keepanno/java/com/android/tools/r8/keepanno/ast/KeepCheck.java
+++ b/src/keepanno/java/com/android/tools/r8/keepanno/ast/KeepCheck.java
@@ -69,4 +69,9 @@
   public KeepItemPattern getItemPattern() {
     return itemPattern;
   }
+
+  @Override
+  public String toString() {
+    return "KeepCheck{kind=" + kind + ", item=" + itemPattern + "}";
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/shaking/KeepInfoCollection.java b/src/main/java/com/android/tools/r8/shaking/KeepInfoCollection.java
index 32835e2..e524b2a 100644
--- a/src/main/java/com/android/tools/r8/shaking/KeepInfoCollection.java
+++ b/src/main/java/com/android/tools/r8/shaking/KeepInfoCollection.java
@@ -348,6 +348,7 @@
       if (prunedItems.hasRemovedClasses() || prunedItems.hasRemovedMembers()) {
         keepMethodInfo.keySet().removeIf(prunedItems::isRemoved);
       }
+      materializedRules.pruneItems(prunedItems);
     }
 
     @Override
diff --git a/src/main/java/com/android/tools/r8/shaking/rules/KeepAnnotationMatcher.java b/src/main/java/com/android/tools/r8/shaking/rules/KeepAnnotationMatcher.java
index 1ef228f..fa7fb7d 100644
--- a/src/main/java/com/android/tools/r8/shaking/rules/KeepAnnotationMatcher.java
+++ b/src/main/java/com/android/tools/r8/shaking/rules/KeepAnnotationMatcher.java
@@ -4,7 +4,6 @@
 
 package com.android.tools.r8.shaking.rules;
 
-import com.android.tools.r8.errors.Unimplemented;
 import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexProgramClass;
@@ -14,6 +13,8 @@
 import com.android.tools.r8.keepanno.ast.KeepAnnotationPattern;
 import com.android.tools.r8.keepanno.ast.KeepBindingReference;
 import com.android.tools.r8.keepanno.ast.KeepBindings.KeepBindingSymbol;
+import com.android.tools.r8.keepanno.ast.KeepCheck;
+import com.android.tools.r8.keepanno.ast.KeepCheck.KeepCheckKind;
 import com.android.tools.r8.keepanno.ast.KeepClassItemPattern;
 import com.android.tools.r8.keepanno.ast.KeepCondition;
 import com.android.tools.r8.keepanno.ast.KeepConstraint.Annotation;
@@ -32,7 +33,6 @@
 import com.android.tools.r8.keepanno.ast.KeepConstraintVisitor;
 import com.android.tools.r8.keepanno.ast.KeepConstraints;
 import com.android.tools.r8.keepanno.ast.KeepDeclaration;
-import com.android.tools.r8.keepanno.ast.KeepEdge;
 import com.android.tools.r8.keepanno.ast.KeepItemPattern;
 import com.android.tools.r8.keepanno.ast.KeepItemReference;
 import com.android.tools.r8.keepanno.ast.KeepMemberItemPattern;
@@ -99,10 +99,45 @@
                   }
                 }),
         check -> {
-          throw new Unimplemented();
+          edgeMatcher.forEachMatch(
+              check,
+              result -> {
+                assert result.preconditions.isEmpty();
+                builder.addRootRule(
+                    keepInfoCollection ->
+                        createCheckDiscardInfo(check, result, keepInfoCollection));
+              });
         });
   }
 
+  private static void createCheckDiscardInfo(
+      KeepCheck check, MatchResult result, MinimumKeepInfoCollection keepInfoCollection) {
+    boolean isRemovedCheck = check.getKind() == KeepCheckKind.REMOVED;
+    ListUtils.forEachWithIndex(
+        result.consequences,
+        (item, i) -> {
+          applyCheckDiscardConstraints(isRemovedCheck, keepInfoCollection, item);
+          // If a check-discard is annotated on a type, then it applies to all members of the type.
+          if (item.isClass()) {
+            item.asProgramClass()
+                .forEachProgramMember(
+                    member ->
+                        applyCheckDiscardConstraints(isRemovedCheck, keepInfoCollection, member));
+          }
+        });
+  }
+
+  private static void applyCheckDiscardConstraints(
+      boolean isRemovedCheck,
+      MinimumKeepInfoCollection keepInfoCollection,
+      ProgramDefinition item) {
+    Joiner<?, ?, ?> joiner = keepInfoCollection.getOrCreateMinimumKeepInfoFor(item.getReference());
+    joiner.setCheckDiscarded();
+    if (isRemovedCheck) {
+      joiner.disallowOptimization();
+    }
+  }
+
   private static MinimumKeepInfoCollection createKeepInfo(
       MatchResult result,
       MinimumKeepInfoCollection minimumKeepInfoCollection,
@@ -112,7 +147,8 @@
         (item, i) -> {
           Joiner<?, ?, ?> joiner =
               minimumKeepInfoCollection.getOrCreateMinimumKeepInfoFor(item.getReference());
-          updateWithConstraints(item, joiner, result.constraints.get(i), result.edge, predicates);
+          updateWithConstraints(
+              item, joiner, result.constraints.get(i), result.declaration, predicates);
         });
     return minimumKeepInfoCollection;
   }
@@ -121,7 +157,7 @@
       ProgramDefinition item,
       Joiner<?, ?, ?> joiner,
       KeepConstraints constraints,
-      KeepEdge edge,
+      KeepDeclaration declaration,
       KeepAnnotationMatcherPredicates predicates) {
     constraints.forEachAccept(
         new KeepConstraintVisitor() {
@@ -129,7 +165,7 @@
           @Override
           public void onLookup(Lookup constraint) {
             joiner.disallowShrinking();
-            joiner.addRule(new KeepAnnotationFakeProguardRule(edge.getMetaInfo()));
+            joiner.addRule(new KeepAnnotationFakeProguardRule(declaration.getMetaInfo()));
           }
 
           @Override
@@ -224,17 +260,17 @@
   }
 
   public static class MatchResult {
-    private final KeepEdge edge;
+    private final KeepDeclaration declaration;
     private final List<ProgramDefinition> preconditions;
     private final List<ProgramDefinition> consequences;
     private final List<KeepConstraints> constraints;
 
     public MatchResult(
-        KeepEdge edge,
+        KeepDeclaration declaration,
         List<ProgramDefinition> preconditions,
         List<ProgramDefinition> consequences,
         List<KeepConstraints> constraints) {
-      this.edge = edge;
+      this.declaration = declaration;
       this.preconditions = preconditions;
       this.consequences = consequences;
       this.constraints = constraints;
@@ -256,9 +292,9 @@
       this.predicates = predicates;
     }
 
-    public void forEachMatch(KeepEdge edge, Consumer<MatchResult> callback) {
+    public void forEachMatch(KeepDeclaration declaration, Consumer<MatchResult> callback) {
       this.callback = callback;
-      schema = new NormalizedSchema(edge);
+      schema = new NormalizedSchema(declaration);
       assignment = new Assignment(schema);
       findMatchingClass(0);
       schema = null;
@@ -377,7 +413,7 @@
    */
   private static class NormalizedSchema {
 
-    final KeepEdge edge;
+    final KeepDeclaration declaration;
     final Reference2IntMap<KeepBindingSymbol> symbolToKey = new Reference2IntOpenHashMap<>();
     final List<KeepClassItemPattern> classes = new ArrayList<>();
     final List<KeepMemberItemPattern> members = new ArrayList<>();
@@ -386,10 +422,21 @@
     final IntList consequences = new IntArrayList();
     final List<KeepConstraints> constraints = new ArrayList<>();
 
-    public NormalizedSchema(KeepEdge edge) {
-      this.edge = edge;
-      edge.getPreconditions().forEach(this::addPrecondition);
-      edge.getConsequences().forEachTarget(this::addConsequence);
+    public NormalizedSchema(KeepDeclaration declaration) {
+      this.declaration = declaration;
+      declaration.match(
+          edge -> {
+            edge.getPreconditions().forEach(this::addPrecondition);
+            edge.getConsequences().forEachTarget(this::addConsequence);
+          },
+          check -> {
+            consequences.add(defineItemPattern(check.getItemPattern()));
+          });
+    }
+
+    private KeepItemPattern getItemForBinding(KeepBindingSymbol symbol) {
+      assert declaration.isKeepEdge();
+      return declaration.asKeepEdge().getBindings().get(symbol).getItem();
     }
 
     public static boolean isClassKeyReference(int keyRef) {
@@ -423,8 +470,7 @@
 
     private int defineBindingReference(KeepBindingReference reference) {
       return symbolToKey.computeIfAbsent(
-          reference.getName(),
-          symbol -> defineItemPattern(edge.getBindings().get(symbol).getItem()));
+          reference.getName(), symbol -> defineItemPattern(getItemForBinding(symbol)));
     }
 
     private int defineItemPattern(KeepItemPattern item) {
@@ -490,7 +536,7 @@
 
     public MatchResult createMatch(NormalizedSchema schema) {
       return new MatchResult(
-          schema.edge,
+          schema.declaration,
           schema.preconditions.isEmpty()
               ? Collections.emptyList()
               : getItemList(schema.preconditions),
diff --git a/src/main/java/com/android/tools/r8/shaking/rules/MaterializedConditionalRule.java b/src/main/java/com/android/tools/r8/shaking/rules/MaterializedConditionalRule.java
index 069b3c3..cbae8fe 100644
--- a/src/main/java/com/android/tools/r8/shaking/rules/MaterializedConditionalRule.java
+++ b/src/main/java/com/android/tools/r8/shaking/rules/MaterializedConditionalRule.java
@@ -5,6 +5,7 @@
 package com.android.tools.r8.shaking.rules;
 
 import com.android.tools.r8.graph.DexReference;
+import com.android.tools.r8.graph.PrunedItems;
 import com.android.tools.r8.shaking.MinimumKeepInfoCollection;
 import java.util.List;
 
@@ -23,4 +24,26 @@
   PendingConditionalRule asPendingRule() {
     return new PendingConditionalRule(preconditions, consequences);
   }
+
+  public boolean pruneItems(PrunedItems prunedItems) {
+    for (DexReference precondition : preconditions) {
+      if (precondition.isDexType()) {
+        if (prunedItems.getRemovedClasses().contains(precondition.asDexType())) {
+          return true;
+        }
+      } else if (precondition.isDexField()) {
+        if (prunedItems.getRemovedFields().contains(precondition.asDexField())) {
+          return true;
+        }
+      } else {
+        assert precondition.isDexMethod();
+        if (prunedItems.getRemovedMethods().contains(precondition.asDexMethod())) {
+          return true;
+        }
+      }
+    }
+    // Preconditions are in place, so trim down consequences.
+    consequences.pruneItems(prunedItems);
+    return consequences.isEmpty();
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/shaking/rules/MaterializedRules.java b/src/main/java/com/android/tools/r8/shaking/rules/MaterializedRules.java
index d12ce3d..c66cdf1 100644
--- a/src/main/java/com/android/tools/r8/shaking/rules/MaterializedRules.java
+++ b/src/main/java/com/android/tools/r8/shaking/rules/MaterializedRules.java
@@ -4,6 +4,7 @@
 
 package com.android.tools.r8.shaking.rules;
 
+import com.android.tools.r8.graph.PrunedItems;
 import com.android.tools.r8.graph.lens.NonIdentityGraphLens;
 import com.android.tools.r8.shaking.MinimumKeepInfoCollection;
 import com.android.tools.r8.utils.ListUtils;
@@ -45,4 +46,9 @@
         rootConsequences,
         ListUtils.map(conditionalRules, MaterializedConditionalRule::asPendingRule));
   }
+
+  public void pruneItems(PrunedItems prunedItems) {
+    rootConsequences.pruneItems(prunedItems);
+    conditionalRules.removeIf(c -> c.pruneItems(prunedItems));
+  }
 }
diff --git a/src/test/java/com/android/tools/r8/keepanno/CheckOptimizedOutAnnotationTest.java b/src/test/java/com/android/tools/r8/keepanno/CheckOptimizedOutAnnotationTest.java
index 656bb6b..f44c479 100644
--- a/src/test/java/com/android/tools/r8/keepanno/CheckOptimizedOutAnnotationTest.java
+++ b/src/test/java/com/android/tools/r8/keepanno/CheckOptimizedOutAnnotationTest.java
@@ -46,6 +46,7 @@
   public void test() throws Exception {
     assumeFalse(parameters.isR8());
     testForKeepAnno(parameters)
+        .enableNativeInterpretation()
         .addProgramClasses(getInputClasses())
         .addKeepMainRule(TestClass.class)
         .setExcludedOuterClass(getClass())
@@ -58,6 +59,7 @@
   public void testCurrentR8() throws Throwable {
     assumeTrue(parameters.isR8() && parameters.isCurrentR8());
     testForKeepAnno(parameters)
+        .enableNativeInterpretation()
         .addProgramClasses(getInputClasses())
         .addKeepMainRule(TestClass.class)
         .applyIfR8Current(
@@ -92,6 +94,7 @@
     assertTrue(parameters.isLegacyR8());
     try {
       testForKeepAnno(parameters)
+          .enableNativeInterpretation()
           .addProgramClasses(getInputClasses())
           .addKeepMainRule(TestClass.class)
           .run(TestClass.class);
diff --git a/src/test/java/com/android/tools/r8/keepanno/CheckRemovedAnnotationTest.java b/src/test/java/com/android/tools/r8/keepanno/CheckRemovedAnnotationTest.java
index 4010ef8..2d13454 100644
--- a/src/test/java/com/android/tools/r8/keepanno/CheckRemovedAnnotationTest.java
+++ b/src/test/java/com/android/tools/r8/keepanno/CheckRemovedAnnotationTest.java
@@ -46,6 +46,7 @@
   public void test() throws Exception {
     assumeFalse(parameters.isR8());
     testForKeepAnno(parameters)
+        .enableNativeInterpretation()
         .addProgramClasses(getInputClasses())
         .addKeepMainRule(TestClass.class)
         .setExcludedOuterClass(getClass())
@@ -57,6 +58,7 @@
   public void testCurrentR8() throws Exception {
     assumeTrue(parameters.isR8() && parameters.isCurrentR8());
     testForKeepAnno(parameters)
+        .enableNativeInterpretation()
         .addProgramClasses(getInputClasses())
         .addKeepMainRule(TestClass.class)
         .applyIfR8Current(
@@ -95,6 +97,7 @@
     assertTrue(parameters.isLegacyR8());
     try {
       testForKeepAnno(parameters)
+          .enableNativeInterpretation()
           .addProgramClasses(getInputClasses())
           .addKeepMainRule(TestClass.class)
           .run(TestClass.class);