Support classpath in keep annotation matcher

Bug: b/392828287
Change-Id: I996689dccf6d6161939261ed403a7e2a875cfb65
diff --git a/src/main/java/com/android/tools/r8/graph/DexClass.java b/src/main/java/com/android/tools/r8/graph/DexClass.java
index 1e64e9e..e19f197 100644
--- a/src/main/java/com/android/tools/r8/graph/DexClass.java
+++ b/src/main/java/com/android/tools/r8/graph/DexClass.java
@@ -239,6 +239,18 @@
     return Iterables.concat(fields(predicate), methods(predicate));
   }
 
+  public boolean hasFields() {
+    return fieldCollection.size() > 0;
+  }
+
+  public boolean hasMethods() {
+    return methodCollection.size() > 0;
+  }
+
+  public boolean hasMethodsOrFields() {
+    return hasMethods() || hasFields();
+  }
+
   public FieldCollection getFieldCollection() {
     return fieldCollection;
   }
@@ -338,6 +350,11 @@
     methodCollection.forEachMethod(consumer);
   }
 
+  public void forEachClassMember(Consumer<? super DexClassAndMember<?, ?>> consumer) {
+    forEachClassField(consumer);
+    forEachClassMethod(consumer);
+  }
+
   public List<DexEncodedField> allFieldsSorted() {
     return fieldCollection.allFieldsSorted();
   }
diff --git a/src/main/java/com/android/tools/r8/graph/DexProgramClass.java b/src/main/java/com/android/tools/r8/graph/DexProgramClass.java
index b5f8a5c..ceb1ee5 100644
--- a/src/main/java/com/android/tools/r8/graph/DexProgramClass.java
+++ b/src/main/java/com/android/tools/r8/graph/DexProgramClass.java
@@ -679,18 +679,6 @@
     return false;
   }
 
-  public boolean hasFields() {
-    return fieldCollection.size() > 0;
-  }
-
-  public boolean hasMethods() {
-    return methodCollection.size() > 0;
-  }
-
-  public boolean hasMethodsOrFields() {
-    return hasMethods() || hasFields();
-  }
-
   /** Determine if the class or any of its methods/fields has any attributes. */
   public boolean hasClassOrMemberAnnotations() {
     return !annotations().isEmpty()
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index 626bc31..6b861fe 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -3832,8 +3832,7 @@
     assert applicableRules == ApplicableRulesEvaluator.empty();
     if (mode.isInitialTreeShaking()) {
       applicableRules =
-          KeepAnnotationMatcher.computeInitialRules(
-              appView, keepDeclarations, options.getThreadingModule(), executorService);
+          KeepAnnotationMatcher.computeInitialRules(appView, keepDeclarations, executorService);
       // Amend library methods with covariant return types.
       timing.begin("Model library");
       modelLibraryMethodsWithCovariantReturnTypes(appView);
diff --git a/src/main/java/com/android/tools/r8/shaking/rules/KeepAnnotationMatcher.java b/src/main/java/com/android/tools/r8/shaking/rules/KeepAnnotationMatcher.java
index c45a64d..0bd6498 100644
--- a/src/main/java/com/android/tools/r8/shaking/rules/KeepAnnotationMatcher.java
+++ b/src/main/java/com/android/tools/r8/shaking/rules/KeepAnnotationMatcher.java
@@ -6,10 +6,12 @@
 
 import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
 import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.Definition;
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexClassAndMember;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ProgramDefinition;
-import com.android.tools.r8.graph.ProgramMember;
 import com.android.tools.r8.keepanno.ast.KeepAnnotationPattern;
 import com.android.tools.r8.keepanno.ast.KeepBindingReference;
 import com.android.tools.r8.keepanno.ast.KeepBindings;
@@ -42,7 +44,6 @@
 import com.android.tools.r8.shaking.KeepAnnotationCollectionInfo.RetentionInfo;
 import com.android.tools.r8.shaking.KeepInfo.Joiner;
 import com.android.tools.r8.shaking.MinimumKeepInfoCollection;
-import com.android.tools.r8.threading.ThreadingModule;
 import com.android.tools.r8.utils.BooleanBox;
 import com.android.tools.r8.utils.ListUtils;
 import com.android.tools.r8.utils.ThreadUtils;
@@ -58,13 +59,13 @@
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.function.Consumer;
+import java.util.function.Predicate;
 
 public class KeepAnnotationMatcher {
 
   public static ApplicableRulesEvaluator computeInitialRules(
       AppView<? extends AppInfoWithClassHierarchy> appView,
       List<KeepDeclaration> keepDeclarations,
-      ThreadingModule threadingModule,
       ExecutorService executorService)
       throws ExecutionException {
     KeepAnnotationMatcherPredicates predicates =
@@ -73,7 +74,7 @@
     ThreadUtils.processItems(
         keepDeclarations,
         declaration -> processDeclaration(declaration, appView.appInfo(), predicates, builder),
-        threadingModule,
+        appView.options().getThreadingModule(),
         executorService);
     return builder.build(appView);
   }
@@ -329,7 +330,7 @@
             appInfo
                 .dexItemFactory()
                 .createType(classPattern.getClassNamePattern().getExactDescriptor());
-        DexProgramClass clazz = DexProgramClass.asProgramClassOrNull(appInfo.definitionFor(type));
+        DexClass clazz = appInfo.definitionFor(type);
         if (clazz == null) {
           continueWithNoClass(classIndex);
           return;
@@ -349,7 +350,7 @@
       }
     }
 
-    private void continueWithClass(int classIndex, DexProgramClass clazz) {
+    private void continueWithClass(int classIndex, DexClass clazz) {
       assignment.setClass(classIndex, clazz);
       IntList classMemberIndexList = schema.classMembers.get(classIndex);
       findMatchingMember(0, classMemberIndexList, clazz, classIndex + 1);
@@ -373,7 +374,7 @@
     private void findMatchingMember(
         int memberInHolderIndex,
         IntList memberIndexTranslation,
-        DexProgramClass holder,
+        DexClass holder,
         int nextClassIndex) {
       if (memberInHolderIndex == memberIndexTranslation.size()) {
         // All members of this class are assigned, continue search for the next class.
@@ -383,7 +384,7 @@
       int memberIndex = memberIndexTranslation.getInt(memberInHolderIndex);
       KeepMemberItemPattern memberItemPattern = schema.members.get(memberIndex);
       BooleanBox didContinue = new BooleanBox(false);
-      Consumer<ProgramDefinition> continueWithMember =
+      Consumer<Definition> continueWithMember =
           m -> {
             didContinue.isTrue();
             continueWithMember(
@@ -397,7 +398,7 @@
                   // The empty class can only match the "all member" pattern but with no assignment.
                   continueWithMember.accept(holder);
                 } else {
-                  holder.forEachProgramMember(
+                  holder.forEachClassMember(
                       m -> {
                         if (predicates.matchesGeneralMember(
                             m.getDefinition(), generalMemberPattern)) {
@@ -407,10 +408,10 @@
                 }
               },
               fieldPattern ->
-                  holder.forEachProgramFieldMatching(
+                  holder.forEachClassFieldMatching(
                       f -> predicates.matchesField(f, fieldPattern, appInfo), continueWithMember),
               methodPattern ->
-                  holder.forEachProgramMethodMatching(
+                  holder.forEachClassMethodMatching(
                       m -> predicates.matchesMethod(m, methodPattern, appInfo),
                       continueWithMember));
       if (didContinue.isFalse()) {
@@ -421,13 +422,13 @@
     }
 
     private void continueWithMember(
-        ProgramDefinition definition,
+        Definition definition,
         int memberIndex,
         int nextMemberInHolderIndex,
         IntList memberIndexTranslation,
         int nextClassIndex) {
-      if (definition.isProgramMember()) {
-        assignment.setMember(memberIndex, definition.asProgramMember());
+      if (definition.isMember()) {
+        assignment.setMember(memberIndex, definition.asMember());
       } else {
         assert definition.isProgramClass();
         assert !definition.asProgramClass().hasMethodsOrFields();
@@ -549,27 +550,27 @@
    */
   private static class Assignment {
 
-    final List<DexProgramClass> classes;
-    final List<ProgramMember<?, ?>> members;
+    final List<DexClass> classes;
+    final List<DexClassAndMember<?, ?>> members;
     boolean hasEmptyMembers = false;
 
     private Assignment(NormalizedSchema schema) {
-      classes = Arrays.asList(new DexProgramClass[schema.classes.size()]);
-      members = Arrays.asList(new ProgramMember<?, ?>[schema.members.size()]);
+      classes = Arrays.asList(new DexClass[schema.classes.size()]);
+      members = Arrays.asList(new DexClassAndMember<?, ?>[schema.members.size()]);
     }
 
-    ProgramDefinition getItemForReference(int keyReference) {
+    Definition getItemForReference(int keyReference) {
       if (NormalizedSchema.isClassKeyReference(keyReference)) {
         return classes.get(NormalizedSchema.decodeClassKeyReference(keyReference));
       }
       return members.get(NormalizedSchema.decodeMemberKeyReference(keyReference));
     }
 
-    void setClass(int index, DexProgramClass type) {
+    void setClass(int index, DexClass type) {
       classes.set(index, type);
     }
 
-    void setMember(int index, ProgramMember<?, ?> member) {
+    void setMember(int index, DexClassAndMember<?, ?> member) {
       members.set(index, member);
     }
 
@@ -583,19 +584,21 @@
           schema.declaration,
           schema.preconditions.isEmpty()
               ? Collections.emptyList()
-              : getItemList(schema.preconditions, false),
-          getItemList(schema.consequences, true),
+              : getItemList(schema.preconditions, false, Definition::isProgramDefinition),
+          getItemList(schema.consequences, true, Definition::isProgramDefinition),
           getConstraints(schema));
     }
 
-    private List<ProgramDefinition> getItemList(IntList indexReferences, boolean allowUnset) {
+    @SuppressWarnings("unchecked")
+    private <T extends Definition> List<T> getItemList(
+        IntList indexReferences, boolean allowUnset, Predicate<Definition> predicate) {
       assert !indexReferences.isEmpty();
-      List<ProgramDefinition> definitions = new ArrayList<>(indexReferences.size());
+      List<T> definitions = new ArrayList<>(indexReferences.size());
       for (int i = 0; i < indexReferences.size(); i++) {
-        ProgramDefinition item = getItemForReference(indexReferences.getInt(i));
+        Definition item = getItemForReference(indexReferences.getInt(i));
         assert item != null || hasEmptyMembers || allowUnset;
-        if (item != null) {
-          definitions.add(item);
+        if (item != null && predicate.test(item)) {
+          definitions.add((T) item);
         }
       }
       return definitions;
@@ -609,7 +612,7 @@
       // constraints, so it matches the consequence list.
       ImmutableList.Builder<KeepConstraints> builder = ImmutableList.builder();
       for (int i = 0; i < schema.consequences.size(); i++) {
-        ProgramDefinition item = getItemForReference(schema.consequences.getInt(i));
+        Definition item = getItemForReference(schema.consequences.getInt(i));
         if (item != null) {
           builder.add(schema.constraints.get(i));
         }
diff --git a/src/main/java/com/android/tools/r8/shaking/rules/KeepAnnotationMatcherPredicates.java b/src/main/java/com/android/tools/r8/shaking/rules/KeepAnnotationMatcherPredicates.java
index 9587d94..1f4d01a 100644
--- a/src/main/java/com/android/tools/r8/shaking/rules/KeepAnnotationMatcherPredicates.java
+++ b/src/main/java/com/android/tools/r8/shaking/rules/KeepAnnotationMatcherPredicates.java
@@ -13,7 +13,6 @@
 import com.android.tools.r8.graph.DexEncodedMember;
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexItemFactory;
-import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexString;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.DexTypeList;
@@ -53,7 +52,7 @@
   }
 
   public boolean matchesClass(
-      DexProgramClass clazz, KeepClassItemPattern classPattern, AppInfoWithClassHierarchy appInfo) {
+      DexClass clazz, KeepClassItemPattern classPattern, AppInfoWithClassHierarchy appInfo) {
     return matchesClassName(clazz.getType(), classPattern.getClassNamePattern())
         && matchesAnnotatedBy(clazz.annotations(), classPattern.getAnnotatedByPattern())
         && matchesInstanceOfPattern(clazz, classPattern.getInstanceOfPattern(), appInfo);
diff --git a/src/test/java/com/android/tools/r8/keepanno/KeepAnnoTestBuilder.java b/src/test/java/com/android/tools/r8/keepanno/KeepAnnoTestBuilder.java
index 4082c0a..e0dd195 100644
--- a/src/test/java/com/android/tools/r8/keepanno/KeepAnnoTestBuilder.java
+++ b/src/test/java/com/android/tools/r8/keepanno/KeepAnnoTestBuilder.java
@@ -122,6 +122,11 @@
     return this;
   }
 
+  public KeepAnnoTestBuilder applyIfR8Partial(
+      ThrowableConsumer<R8PartialTestBuilder> builderConsumer) {
+    return this;
+  }
+
   public KeepAnnoTestBuilder applyIfPG(ThrowableConsumer<ProguardTestBuilder> builderConsumer) {
     return this;
   }
@@ -346,6 +351,13 @@
     }
 
     @Override
+    public KeepAnnoTestBuilder applyIfR8Partial(
+        ThrowableConsumer<R8PartialTestBuilder> builderConsumer) {
+      builderConsumer.acceptWithRuntimeException(builder);
+      return this;
+    }
+
+    @Override
     boolean isExtractRules() {
       return config == KeepAnnoConfig.R8_PARTIAL_RULES;
     }
diff --git a/src/test/java/com/android/tools/r8/keepanno/KeepUsesReflectionAnnotationTest.java b/src/test/java/com/android/tools/r8/keepanno/KeepUsesReflectionAnnotationTest.java
index d08fd34..b51c86b 100644
--- a/src/test/java/com/android/tools/r8/keepanno/KeepUsesReflectionAnnotationTest.java
+++ b/src/test/java/com/android/tools/r8/keepanno/KeepUsesReflectionAnnotationTest.java
@@ -5,7 +5,9 @@
 
 import static com.android.tools.r8.utils.codeinspector.Matchers.isAbsent;
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresentIf;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assume.assumeTrue;
 
 import com.android.tools.r8.keepanno.annotations.KeepTarget;
 import com.android.tools.r8.keepanno.annotations.UsesReflection;
@@ -17,6 +19,7 @@
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
 
 @RunWith(Parameterized.class)
 public class KeepUsesReflectionAnnotationTest extends KeepAnnoTestBase {
@@ -25,7 +28,7 @@
 
   @Parameter public KeepAnnoParameters parameters;
 
-  @Parameterized.Parameters(name = "{0}")
+  @Parameters(name = "{0}")
   public static List<KeepAnnoParameters> data() {
     return createParameters(
         getTestParameters().withDefaultRuntimes().withMaximumApiLevel().build());
@@ -40,17 +43,35 @@
         .allowUnusedProguardConfigurationRules()
         .run(TestClass.class)
         .assertSuccessWithOutput(EXPECTED)
-        .applyIf(parameters.isShrinker(), r -> r.inspect(this::checkOutput));
+        .applyIf(
+            parameters.isShrinker(), r -> r.inspect(inspector -> checkOutput(inspector, false)));
+  }
+
+  @Test
+  public void testUsesReflectionInD8OfR8Partial() throws Exception {
+    assumeTrue(parameters.isR8Partial());
+    testForKeepAnno(parameters)
+        .addProgramClasses(getInputClasses())
+        .applyIfR8Partial(
+            builder ->
+                builder
+                    .clearR8PartialConfiguration()
+                    .setR8PartialConfiguration(b -> b.includeAll().excludeClasses(A.class)))
+        .addKeepMainRule(TestClass.class)
+        .run(TestClass.class)
+        .assertSuccessWithOutput(EXPECTED)
+        .applyIf(
+            parameters.isShrinker(), r -> r.inspect(inspector -> checkOutput(inspector, true)));
   }
 
   public List<Class<?>> getInputClasses() {
     return ImmutableList.of(TestClass.class, A.class, B.class, C.class);
   }
 
-  private void checkOutput(CodeInspector inspector) {
+  private void checkOutput(CodeInspector inspector, boolean aOnClasspath) {
     assertThat(inspector.clazz(A.class), isPresent());
     assertThat(inspector.clazz(B.class), isPresent());
-    assertThat(inspector.clazz(C.class), parameters.isPG() ? isPresent() : isAbsent());
+    assertThat(inspector.clazz(C.class), isPresentIf(parameters.isPG() || aOnClasspath));
     assertThat(inspector.clazz(B.class).method("void", "bar"), isPresent());
     assertThat(inspector.clazz(B.class).method("void", "bar", "int"), isAbsent());
   }
diff --git a/src/test/testbase/java/com/android/tools/r8/R8PartialTestBuilder.java b/src/test/testbase/java/com/android/tools/r8/R8PartialTestBuilder.java
index 7f9442f..1f1a642 100644
--- a/src/test/testbase/java/com/android/tools/r8/R8PartialTestBuilder.java
+++ b/src/test/testbase/java/com/android/tools/r8/R8PartialTestBuilder.java
@@ -77,6 +77,11 @@
     return self();
   }
 
+  public R8PartialTestBuilder clearR8PartialConfiguration() {
+    r8PartialConfiguration = R8PartialCompilationConfiguration.disabledConfiguration();
+    return this;
+  }
+
   public R8PartialTestBuilder addR8IncludedClasses(Class<?>... classes) {
     return addR8IncludedClasses(true, classes);
   }