Support taking custom class for R8Assistant callbacks

Bug: b/393265921
Change-Id: Iae00494550edd770e7dca4f0dbaf76f5df74f8e5
diff --git a/src/assistant/java/com/android/tools/r8/assistant/runtime/ReflectiveOracle.java b/src/assistant/java/com/android/tools/r8/assistant/runtime/ReflectiveOracle.java
index ac46d0a..e653b72 100644
--- a/src/assistant/java/com/android/tools/r8/assistant/runtime/ReflectiveOracle.java
+++ b/src/assistant/java/com/android/tools/r8/assistant/runtime/ReflectiveOracle.java
@@ -17,13 +17,19 @@
       // TODO(b/393249304): Support injecting alternative receiver.
       synchronized (instanceLock) {
         if (INSTANCE == null) {
-          INSTANCE = new ReflectiveOperationLogger();
+          INSTANCE = getReceiver();
         }
       }
     }
     return INSTANCE;
   }
 
+  // Might be rewritten to call new instance on a custom receiver.
+  private static ReflectiveOperationReceiver getReceiver() {
+    // Default, might be replaced, don't change this without changing the instrumentation
+    return new ReflectiveOperationLogger();
+  }
+
   @KeepForApi
   public static class Stack {
 
diff --git a/src/main/java/com/android/tools/r8/R8Assistant.java b/src/main/java/com/android/tools/r8/R8Assistant.java
index 2e26ebe..30092f1 100644
--- a/src/main/java/com/android/tools/r8/R8Assistant.java
+++ b/src/main/java/com/android/tools/r8/R8Assistant.java
@@ -65,6 +65,10 @@
       reflectiveInstrumentation.instrumentClasses();
       // Convert cf classes
       converter.convert(appView, executorService);
+      if (command.getReflectiveReceiverDescriptor() != null) {
+        reflectiveInstrumentation.updateReflectiveReceiver(
+            command.getReflectiveReceiverDescriptor());
+      }
       SyntheticFinalization.finalize(appView, timing, executorService);
       ApplicationWriter writer = ApplicationWriter.create(appView, options.getMarker());
       writer.write(executorService);
diff --git a/src/main/java/com/android/tools/r8/R8AssistantCommand.java b/src/main/java/com/android/tools/r8/R8AssistantCommand.java
index f5dba70..f89a8e2 100644
--- a/src/main/java/com/android/tools/r8/R8AssistantCommand.java
+++ b/src/main/java/com/android/tools/r8/R8AssistantCommand.java
@@ -16,6 +16,7 @@
 import com.android.tools.r8.origin.SynthesizedOrigin;
 import com.android.tools.r8.utils.AndroidApiLevel;
 import com.android.tools.r8.utils.AndroidApp;
+import com.android.tools.r8.utils.DescriptorUtils;
 import com.android.tools.r8.utils.DumpInputFlags;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.InternalOptions.DesugarState;
@@ -30,12 +31,15 @@
 @KeepForApi
 public class R8AssistantCommand extends BaseCompilerCommand {
 
+  private final String reflectiveReceiverDescriptor;
+
   public R8AssistantCommand(
       AndroidApp app,
       CompilationMode mode,
       ProgramConsumer programConsumer,
       int minApiLevel,
-      Reporter reporter) {
+      Reporter reporter,
+      String reflectiveReceiverDescriptor) {
     super(
         app,
         mode,
@@ -58,6 +62,7 @@
         Collections.emptyList(),
         null,
         null);
+    this.reflectiveReceiverDescriptor = reflectiveReceiverDescriptor;
   }
 
   public static Builder builder(DiagnosticsHandler reporter) {
@@ -83,6 +88,10 @@
     return options;
   }
 
+  public String getReflectiveReceiverDescriptor() {
+    return reflectiveReceiverDescriptor;
+  }
+
   /**
    * This is an experimental API for injecting reflective identification callbacks into dex code.
    * This API is subject to change.
@@ -90,6 +99,8 @@
   @KeepForApi
   public static class Builder extends BaseCompilerCommand.Builder<R8AssistantCommand, Builder> {
 
+    private String reflectiveReceiverDescriptor;
+
     private Builder() {
       this(new DiagnosticsHandler() {});
     }
@@ -104,6 +115,20 @@
       }
     }
 
+    public Builder addReflectiveOperationReceiverInput(ProgramResourceProvider provider) {
+      // This code is simply added to the program input, will be called by the ReflectiveOracle.
+      addProgramResourceProvider(provider);
+      return self();
+    }
+
+    public Builder setReflectiveReceiverClassDescriptor(String descriptor) {
+      if (!DescriptorUtils.isClassDescriptor(descriptor)) {
+        getReporter().error("Not a valid descriptor " + descriptor);
+      }
+      this.reflectiveReceiverDescriptor = descriptor;
+      return self();
+    }
+
     @Override
     CompilationMode defaultCompilationMode() {
       return CompilationMode.RELEASE;
@@ -138,7 +163,8 @@
           getMode(),
           getProgramConsumer(),
           getMinApiLevel(),
-          getReporter());
+          getReporter(),
+          reflectiveReceiverDescriptor);
     }
   }
 }
diff --git a/src/main/java/com/android/tools/r8/assistant/ReflectiveInstrumentation.java b/src/main/java/com/android/tools/r8/assistant/ReflectiveInstrumentation.java
index 1df0015..d7f9326 100644
--- a/src/main/java/com/android/tools/r8/assistant/ReflectiveInstrumentation.java
+++ b/src/main/java/com/android/tools/r8/assistant/ReflectiveInstrumentation.java
@@ -3,22 +3,26 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.assistant;
 
-import com.android.tools.r8.assistant.runtime.ReflectiveOracle;
 import com.android.tools.r8.graph.AppInfo;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.ir.code.BasicBlockInstructionListIterator;
 import com.android.tools.r8.ir.code.BasicBlockIterator;
 import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.IRCodeInstructionListIterator;
 import com.android.tools.r8.ir.code.Instruction;
+import com.android.tools.r8.ir.code.InvokeDirect;
 import com.android.tools.r8.ir.code.InvokeStatic;
 import com.android.tools.r8.ir.code.InvokeVirtual;
+import com.android.tools.r8.ir.code.NewInstance;
+import com.android.tools.r8.ir.code.Return;
+import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.conversion.PrimaryD8L8IRConverter;
 import com.android.tools.r8.ir.optimize.info.OptimizationFeedback;
-import com.android.tools.r8.utils.DescriptorUtils;
 import com.android.tools.r8.utils.Timing;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
@@ -29,16 +33,41 @@
   private final PrimaryD8L8IRConverter converter;
   private final DexItemFactory dexItemFactory;
   private final Timing timing;
-  private final DexType reflectiveOracleType;
+  private final ReflectiveReferences reflectiveReferences;
 
   public ReflectiveInstrumentation(
       AppView<AppInfo> appView, PrimaryD8L8IRConverter converter, Timing timing) {
     this.appView = appView;
     this.dexItemFactory = appView.dexItemFactory();
+    this.reflectiveReferences = new ReflectiveReferences(dexItemFactory);
     this.converter = converter;
     this.timing = timing;
-    this.reflectiveOracleType =
-        dexItemFactory.createType(DescriptorUtils.javaClassToDescriptor(ReflectiveOracle.class));
+  }
+
+  public void updateReflectiveReceiver(String customReflectiveReceiverDescriptor) {
+    ProgramMethod getReceiver =
+        appView.definitionFor(reflectiveReferences.getReceiverMethod).asProgramMethod();
+    IRCode code = getReceiver.buildIR(appView);
+    assert code.streamInstructions().count() == 3;
+    DexType replacementType = dexItemFactory.createType(customReflectiveReceiverDescriptor);
+    IRCodeInstructionListIterator instructionListIterator = code.instructionListIterator();
+    instructionListIterator.next();
+    NewInstance newInstanceReplacement =
+        NewInstance.builder()
+            .setType(replacementType)
+            .setFreshOutValue(code, replacementType.toNonNullTypeElement(appView))
+            .build();
+    Value newInstanceOutValue = newInstanceReplacement.outValue();
+    instructionListIterator.replaceCurrentInstruction(newInstanceReplacement);
+    instructionListIterator.next();
+    DexMethod method = dexItemFactory.createInstanceInitializer(replacementType);
+    InvokeDirect invokeDirect =
+        InvokeDirect.builder().setMethod(method).setArguments(newInstanceOutValue).build();
+    instructionListIterator.replaceCurrentInstruction(invokeDirect);
+    instructionListIterator.next();
+    Return newReturn = Return.builder().setReturnValue(newInstanceOutValue).build();
+    instructionListIterator.replaceCurrentInstruction(newReturn);
+    converter.removeDeadCodeAndFinalizeIR(code, OptimizationFeedback.getIgnoreFeedback(), timing);
   }
 
   // TODO(b/394013779): Do this in parallel.
@@ -93,14 +122,14 @@
 
   private DexMethod getMethodReferenceWithClassParameter(String name) {
     return dexItemFactory.createMethod(
-        reflectiveOracleType,
+        reflectiveReferences.reflectiveOracleType,
         dexItemFactory.createProto(dexItemFactory.voidType, dexItemFactory.classType),
         name);
   }
 
   private DexMethod getMethodReferenceWithClassMethodNameAndParameters(String name) {
     return dexItemFactory.createMethod(
-        reflectiveOracleType,
+        reflectiveReferences.reflectiveOracleType,
         dexItemFactory.createProto(
             dexItemFactory.voidType,
             dexItemFactory.classType,
diff --git a/src/main/java/com/android/tools/r8/assistant/ReflectiveReferences.java b/src/main/java/com/android/tools/r8/assistant/ReflectiveReferences.java
new file mode 100644
index 0000000..039c2f2
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/assistant/ReflectiveReferences.java
@@ -0,0 +1,30 @@
+// Copyright (c) 2025, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.assistant;
+
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexType;
+
+public class ReflectiveReferences {
+
+  public final DexType reflectiveOracleType;
+  public final DexType reflectiveOperationReceiverType;
+  public final DexMethod getReceiverMethod;
+
+  public ReflectiveReferences(DexItemFactory factory) {
+    this.reflectiveOracleType = factory.createType(getDescriptor("ReflectiveOracle"));
+    this.reflectiveOperationReceiverType =
+        factory.createType(getDescriptor("ReflectiveOperationReceiver"));
+    this.getReceiverMethod =
+        factory.createMethod(
+            reflectiveOracleType,
+            factory.createProto(reflectiveOperationReceiverType),
+            "getReceiver");
+  }
+
+  private static String getDescriptor(String className) {
+    return "Lcom/android/tools/r8/assistant/runtime/" + className + ";";
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/assistant/R8AssistentReflectiveInstrumentationTest.java b/src/test/java/com/android/tools/r8/assistant/R8AssistentReflectiveInstrumentationTest.java
index 3a578c3..15872c0 100644
--- a/src/test/java/com/android/tools/r8/assistant/R8AssistentReflectiveInstrumentationTest.java
+++ b/src/test/java/com/android/tools/r8/assistant/R8AssistentReflectiveInstrumentationTest.java
@@ -13,6 +13,8 @@
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.ToolHelper;
+import com.android.tools.r8.assistant.runtime.ReflectiveOperationReceiver;
+import com.android.tools.r8.assistant.runtime.ReflectiveOracle.Stack;
 import com.android.tools.r8.utils.ZipUtils;
 import com.android.tools.r8.utils.codeinspector.ClassSubject;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
@@ -59,7 +61,7 @@
   @Test
   public void testInstrumentation() throws Exception {
     testForAssistant()
-        .addInnerClasses(getClass())
+        .addProgramClasses(TestClass.class, Foo.class, Bar.class)
         .setMinApi(parameters)
         .compile()
         .inspectOriginalDex(inspector -> inspectStaticCallsInReflectOn(0, inspector))
@@ -70,6 +72,21 @@
             "Reflectively got declared method callMe on " + Bar.class.getName());
   }
 
+  @Test
+  public void testInstrumentationWithCustomOracle() throws Exception {
+    testForAssistant()
+        .addProgramClasses(TestClass.class, Foo.class, Bar.class)
+        .addInstrumentationClasses(InstrumentationClass.class)
+        .setCustomReflectiveOperationReceiver(InstrumentationClass.class)
+        .setMinApi(parameters)
+        .compile()
+        .inspectOriginalDex(inspector -> inspectStaticCallsInReflectOn(0, inspector))
+        .inspect(inspector -> inspectStaticCallsInReflectOn(2, inspector))
+        .run(parameters.getRuntime(), TestClass.class)
+        .assertSuccessWithOutputLines(
+            "Custom receiver " + Bar.class.getName(), "Custom receiver method callMe");
+  }
+
   private static void inspectStaticCallsInReflectOn(int count, CodeInspector inspector) {
     ClassSubject testClass = inspector.clazz(TestClass.class);
     assertThat(testClass, isPresent());
@@ -80,6 +97,20 @@
     assertEquals(count, codeCount);
   }
 
+  public static class InstrumentationClass implements ReflectiveOperationReceiver {
+
+    @Override
+    public void onClassNewInstance(Stack stack, Class<?> clazz) {
+      System.out.println("Custom receiver " + clazz.getName());
+    }
+
+    @Override
+    public void onClassGetDeclaredMethod(
+        Stack stack, Class<?> clazz, String method, Class<?>... parameters) {
+      System.out.println("Custom receiver method " + method);
+    }
+  }
+
   static class TestClass {
     public static void main(String[] args) {
       reflectOn(System.currentTimeMillis() == 0 ? Foo.class : Bar.class);
diff --git a/src/test/testbase/java/com/android/tools/r8/AssistantTestBuilder.java b/src/test/testbase/java/com/android/tools/r8/AssistantTestBuilder.java
index abecb6b..c007ec3 100644
--- a/src/test/testbase/java/com/android/tools/r8/AssistantTestBuilder.java
+++ b/src/test/testbase/java/com/android/tools/r8/AssistantTestBuilder.java
@@ -3,7 +3,9 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8;
 
+import static com.android.tools.r8.TestBase.descriptor;
 import static com.android.tools.r8.TestBase.testForD8;
+import static com.android.tools.r8.TestBase.writeClassesToJar;
 
 import com.android.tools.r8.TestBase.Backend;
 import com.android.tools.r8.benchmarks.BenchmarkResults;
@@ -13,7 +15,10 @@
 import java.io.IOException;
 import java.io.UncheckedIOException;
 import java.nio.file.Path;
+import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
 
@@ -27,6 +32,8 @@
 
   private final D8TestBuilder initialCompileBuilder;
   private Path output;
+  private String customReflectiveOperationReceiver = null;
+  private List<Class<?>> customReflectiveOperationInputClasses = new ArrayList<>();
 
   private AssistantTestBuilder(TestState state) {
     super(state, R8AssistantCommand.builder(state.getDiagnosticsHandler()), Backend.DEX);
@@ -52,6 +59,11 @@
     throw new Unimplemented("No classpath for assistant");
   }
 
+  public AssistantTestBuilder addInstrumentationClasses(Class<?>... classes) {
+    Collections.addAll(customReflectiveOperationInputClasses, classes);
+    return self();
+  }
+
   @Override
   public AssistantTestBuilder addProgramFiles(Collection<Path> files) {
     initialCompileBuilder.addProgramFiles(files);
@@ -71,14 +83,23 @@
       if (output == null) {
         output = getState().getNewTempFile("assistant_output.jar");
       }
+      if (!customReflectiveOperationInputClasses.isEmpty()) {
+        builder.addReflectiveOperationReceiverInput(
+            ArchiveProgramResourceProvider.fromArchive(
+                writeClassesToJar(customReflectiveOperationInputClasses)));
+      }
     } catch (IOException e) {
       throw new UncheckedIOException(e);
     }
+
     builder
         .addProgramFiles(initialCompilation)
         .setOutput(output, OutputMode.DexIndexed)
         .setMinApiLevel(getMinApiLevel());
 
+    if (customReflectiveOperationReceiver != null) {
+      builder.setReflectiveReceiverClassDescriptor(customReflectiveOperationReceiver);
+    }
     R8Assistant.run(builder.build());
     return new AssistantTestCompileResult(
         initialCompilation,
@@ -86,4 +107,16 @@
         AndroidApp.builder().addProgramFiles(output).build(),
         getMinApiLevel());
   }
+
+  public AssistantTestBuilder setCustomReflectiveOperationReceiver(
+      String customReflectiveOperationReceiver) {
+    this.customReflectiveOperationReceiver = customReflectiveOperationReceiver;
+    return self();
+  }
+
+  public AssistantTestBuilder setCustomReflectiveOperationReceiver(
+      Class<?> customReflectiveOperationReceiver) {
+    this.customReflectiveOperationReceiver = descriptor(customReflectiveOperationReceiver);
+    return self();
+  }
 }
diff --git a/src/test/testbase/java/com/android/tools/r8/TestBase.java b/src/test/testbase/java/com/android/tools/r8/TestBase.java
index 2f6a150..fffa833 100644
--- a/src/test/testbase/java/com/android/tools/r8/TestBase.java
+++ b/src/test/testbase/java/com/android/tools/r8/TestBase.java
@@ -1429,9 +1429,15 @@
     consumer.finished(null);
   }
 
+
   protected static Path writeClassesToJar(Class<?>... classes) throws IOException {
+    List<Class<?>> classesCollection = Arrays.asList(classes);
+    return writeClassesToJar(classesCollection);
+  }
+
+  protected static Path writeClassesToJar(List<Class<?>> classesCollection) throws IOException {
     Path jar = staticTemp.newFolder().toPath().resolve("classes.jar");
-    writeClassesToJar(jar, Arrays.asList(classes));
+    writeClassesToJar(jar, classesCollection);
     return jar;
   }