Apply reflective identification to R8 partial excluded classes

Bug: b/414763615
Change-Id: I69aaaa6b8cc7e72bf0cbf10537e3101c6a44e791
diff --git a/src/main/java/com/android/tools/r8/partial/R8PartialUseCollector.java b/src/main/java/com/android/tools/r8/partial/R8PartialUseCollector.java
index bc4d025..71ad575 100644
--- a/src/main/java/com/android/tools/r8/partial/R8PartialUseCollector.java
+++ b/src/main/java/com/android/tools/r8/partial/R8PartialUseCollector.java
@@ -23,10 +23,13 @@
 import com.android.tools.r8.shaking.ProguardConfigurationParser.IdentifierPatternWithWildcards;
 import com.android.tools.r8.shaking.ProguardTypeMatcher;
 import com.android.tools.r8.shaking.ProguardTypeMatcher.ClassOrType;
+import com.android.tools.r8.shaking.reflectiveidentification.KeepAllReflectiveIdentificationEventConsumer;
+import com.android.tools.r8.shaking.reflectiveidentification.ReflectiveIdentification;
 import com.android.tools.r8.tracereferences.TraceReferencesConsumer;
 import com.android.tools.r8.tracereferences.UseCollector;
 import com.android.tools.r8.utils.ListUtils;
 import com.android.tools.r8.utils.NopDiagnosticsHandler;
+import com.android.tools.r8.utils.timing.Timing;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutionException;
@@ -35,6 +38,8 @@
 
 public abstract class R8PartialUseCollector extends UseCollector {
 
+  private final ReflectiveIdentification reflectiveIdentification;
+
   private final Set<DexReference> seenAllowObfuscation = ConcurrentHashMap.newKeySet();
   private final Set<DexReference> seenDisallowObfuscation = ConcurrentHashMap.newKeySet();
   private final Set<String> packagesToKeep = ConcurrentHashMap.newKeySet();
@@ -45,6 +50,9 @@
         new MissingReferencesConsumer(),
         new NopDiagnosticsHandler(),
         getTargetPredicate(appView));
+    this.reflectiveIdentification =
+        new ReflectiveIdentification(
+            appView, new KeepAllReflectiveIdentificationEventConsumer(this));
   }
 
   public static Predicate<DexType> getTargetPredicate(
@@ -56,6 +64,7 @@
     R8PartialR8SubCompilationConfiguration r8SubCompilationConfiguration =
         appView.options().partialSubCompilationConfiguration.asR8();
     traceClasses(r8SubCompilationConfiguration.getDexingOutputClasses(), executorService);
+    reflectiveIdentification.processWorklist(Timing.empty());
     commitPackagesToKeep();
   }
 
@@ -80,7 +89,7 @@
             .build());
   }
 
-  protected abstract void keep(
+  public abstract void keep(
       Definition definition, DefinitionContext referencedFrom, boolean allowObfuscation);
 
   @Override
@@ -123,6 +132,11 @@
     packagesToKeep.add(definition.getContextType().getPackageName());
   }
 
+  @Override
+  protected void notifyReflectiveIdentification(DexMethod invokedMethod, ProgramMethod method) {
+    reflectiveIdentification.scanInvoke(invokedMethod, method);
+  }
+
   private static class MissingReferencesConsumer implements TraceReferencesConsumer {
 
     @Override
diff --git a/src/main/java/com/android/tools/r8/shaking/RootSetUtils.java b/src/main/java/com/android/tools/r8/shaking/RootSetUtils.java
index 58861b2..39b35aa 100644
--- a/src/main/java/com/android/tools/r8/shaking/RootSetUtils.java
+++ b/src/main/java/com/android/tools/r8/shaking/RootSetUtils.java
@@ -261,7 +261,7 @@
                     .build();
 
             @Override
-            protected synchronized void keep(
+            public synchronized void keep(
                 Definition definition, DefinitionContext referencedFrom, boolean allowObfuscation) {
               if (definition.isProgramDefinition()) {
                 ReferencedFromExcludedClassInR8PartialRule rule =
diff --git a/src/main/java/com/android/tools/r8/shaking/reflectiveidentification/KeepAllReflectiveIdentificationEventConsumer.java b/src/main/java/com/android/tools/r8/shaking/reflectiveidentification/KeepAllReflectiveIdentificationEventConsumer.java
new file mode 100644
index 0000000..e3afab1
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/shaking/reflectiveidentification/KeepAllReflectiveIdentificationEventConsumer.java
@@ -0,0 +1,100 @@
+// Copyright (c) 2025, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.shaking.reflectiveidentification;
+
+import com.android.tools.r8.diagnostic.DefinitionContext;
+import com.android.tools.r8.diagnostic.internal.DefinitionContextUtils;
+import com.android.tools.r8.graph.Definition;
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.ProgramField;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.partial.R8PartialUseCollector;
+import java.util.Collection;
+import java.util.Set;
+
+public class KeepAllReflectiveIdentificationEventConsumer
+    implements ReflectiveIdentificationEventConsumer {
+
+  private final R8PartialUseCollector useCollector;
+
+  public KeepAllReflectiveIdentificationEventConsumer(R8PartialUseCollector useCollector) {
+    this.useCollector = useCollector;
+  }
+
+  private void keep(Definition definition, ProgramMethod context) {
+    DefinitionContext referencedFrom = DefinitionContextUtils.create(context);
+    useCollector.keep(definition, referencedFrom, false);
+  }
+
+  @Override
+  public void onJavaLangClassForName(DexClass clazz, ProgramMethod context) {
+    keep(clazz, context);
+  }
+
+  @Override
+  public void onJavaLangClassGetField(ProgramField field, ProgramMethod context) {
+    keep(field, context);
+  }
+
+  @Override
+  public void onJavaLangClassGetMethod(ProgramMethod method, ProgramMethod context) {
+    keep(method, context);
+  }
+
+  @Override
+  public void onJavaLangClassNewInstance(DexProgramClass clazz, ProgramMethod context) {
+    ProgramMethod defaultInitializer = clazz.getProgramDefaultInitializer();
+    if (defaultInitializer != null) {
+      keep(defaultInitializer, context);
+    }
+  }
+
+  @Override
+  public void onJavaLangReflectConstructorNewInstance(
+      ProgramMethod initializer, ProgramMethod context) {
+    keep(initializer, context);
+  }
+
+  @Override
+  public void onJavaLangReflectProxyNewProxyInstance(
+      Set<DexProgramClass> classes, ProgramMethod context) {
+    for (DexProgramClass clazz : classes) {
+      keep(clazz, context);
+    }
+  }
+
+  @Override
+  public void onJavaUtilConcurrentAtomicAtomicIntegerFieldUpdaterNewUpdater(
+      ProgramField field, ProgramMethod context) {
+    keep(field, context);
+  }
+
+  @Override
+  public void onJavaUtilConcurrentAtomicAtomicLongFieldUpdaterNewUpdater(
+      ProgramField field, ProgramMethod context) {
+    keep(field, context);
+  }
+
+  @Override
+  public void onJavaUtilConcurrentAtomicAtomicReferenceFieldUpdaterNewUpdater(
+      ProgramField field, ProgramMethod context) {
+    keep(field, context);
+  }
+
+  @Override
+  public void onJavaUtilServiceLoaderLoad(
+      DexProgramClass serviceClass,
+      Collection<DexProgramClass> implementationClasses,
+      ProgramMethod context) {
+    keep(serviceClass, context);
+    for (DexProgramClass implementationClass : implementationClasses) {
+      keep(implementationClass, context);
+      ProgramMethod defaultInitializer = implementationClass.getProgramDefaultInitializer();
+      if (defaultInitializer != null) {
+        keep(defaultInitializer, context);
+      }
+    }
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/tracereferences/UseCollector.java b/src/main/java/com/android/tools/r8/tracereferences/UseCollector.java
index e88878e..2171eac 100644
--- a/src/main/java/com/android/tools/r8/tracereferences/UseCollector.java
+++ b/src/main/java/com/android/tools/r8/tracereferences/UseCollector.java
@@ -101,6 +101,10 @@
     return this;
   }
 
+  protected void notifyReflectiveIdentification(DexMethod invokedMethod, ProgramMethod method) {
+    // Intentionally empty. Overridden in R8PartialUseCollector.
+  }
+
   public void traceClasses(Collection<DexProgramClass> classes) {
     for (DexProgramClass clazz : classes) {
       traceClass(clazz);
@@ -616,6 +620,18 @@
     }
   }
 
+  private void handleInvoke(
+      DexMethod method,
+      MethodResolutionResult resolutionResult,
+      Function<SingleResolutionResult<?>, DexClassAndMethod> getResult,
+      ProgramMethod context,
+      DefinitionContext referencedFrom,
+      UseCollectorEventConsumer eventConsumer) {
+    handleMethodResolution(
+        method, resolutionResult, getResult, context, referencedFrom, eventConsumer);
+    notifyReflectiveIdentification(method, context);
+  }
+
   private void handleMethodResolution(
       DexMethod method,
       MethodResolutionResult resolutionResult,
@@ -701,7 +717,7 @@
     @Override
     public void registerInvokeDirect(DexMethod method) {
       if (getContext().getHolder().originatesFromDexResource()) {
-        handleMethodResolution(
+        handleInvoke(
             method,
             appInfo().unsafeResolveMethodDueToDexFormat(method),
             SingleResolutionResult::getResolutionPair,
@@ -734,7 +750,7 @@
 
     @Override
     public void registerInvokeStatic(DexMethod method) {
-      handleMethodResolution(
+      handleInvoke(
           method,
           appInfo().unsafeResolveMethodDueToDexFormat(method),
           SingleResolutionResult::getResolutionPair,
@@ -745,7 +761,7 @@
 
     @Override
     public void registerInvokeSuper(DexMethod method) {
-      handleMethodResolution(
+      handleInvoke(
           method,
           appInfo().unsafeResolveMethodDueToDexFormat(method),
           result -> result.lookupInvokeSuperTarget(getContext().getHolder(), appView, appInfo()),
@@ -766,7 +782,7 @@
         return;
       }
       assert invokeType.isInterface() || invokeType.isVirtual();
-      handleMethodResolution(
+      handleInvoke(
           method,
           invokeType.isInterface()
               ? appInfo().resolveMethodOnInterfaceHolder(method)
diff --git a/src/test/java/com/android/tools/r8/partial/PartialCompilationReflectiveIdentificationTest.java b/src/test/java/com/android/tools/r8/partial/PartialCompilationReflectiveIdentificationTest.java
index 7c78081..bb4772d 100644
--- a/src/test/java/com/android/tools/r8/partial/PartialCompilationReflectiveIdentificationTest.java
+++ b/src/test/java/com/android/tools/r8/partial/PartialCompilationReflectiveIdentificationTest.java
@@ -31,8 +31,7 @@
         .addR8ExcludedClasses(ExcludedMain.class)
         .compile()
         .run(parameters.getRuntime(), ExcludedMain.class)
-        // TODO(b/414763615): Enable reflective identification for excluded classes.
-        .assertFailureWithErrorThatThrows(ClassNotFoundException.class);
+        .assertSuccessWithEmptyOutput();
   }
 
   static class ExcludedMain {