Support taking custom class for R8Assistant callbacks
Bug: b/393265921
Change-Id: Iae00494550edd770e7dca4f0dbaf76f5df74f8e5
diff --git a/src/assistant/java/com/android/tools/r8/assistant/runtime/ReflectiveOracle.java b/src/assistant/java/com/android/tools/r8/assistant/runtime/ReflectiveOracle.java
index ac46d0a..e653b72 100644
--- a/src/assistant/java/com/android/tools/r8/assistant/runtime/ReflectiveOracle.java
+++ b/src/assistant/java/com/android/tools/r8/assistant/runtime/ReflectiveOracle.java
@@ -17,13 +17,19 @@
// TODO(b/393249304): Support injecting alternative receiver.
synchronized (instanceLock) {
if (INSTANCE == null) {
- INSTANCE = new ReflectiveOperationLogger();
+ INSTANCE = getReceiver();
}
}
}
return INSTANCE;
}
+ // Might be rewritten to call new instance on a custom receiver.
+ private static ReflectiveOperationReceiver getReceiver() {
+ // Default, might be replaced, don't change this without changing the instrumentation
+ return new ReflectiveOperationLogger();
+ }
+
@KeepForApi
public static class Stack {
diff --git a/src/main/java/com/android/tools/r8/R8Assistant.java b/src/main/java/com/android/tools/r8/R8Assistant.java
index 2e26ebe..30092f1 100644
--- a/src/main/java/com/android/tools/r8/R8Assistant.java
+++ b/src/main/java/com/android/tools/r8/R8Assistant.java
@@ -65,6 +65,10 @@
reflectiveInstrumentation.instrumentClasses();
// Convert cf classes
converter.convert(appView, executorService);
+ if (command.getReflectiveReceiverDescriptor() != null) {
+ reflectiveInstrumentation.updateReflectiveReceiver(
+ command.getReflectiveReceiverDescriptor());
+ }
SyntheticFinalization.finalize(appView, timing, executorService);
ApplicationWriter writer = ApplicationWriter.create(appView, options.getMarker());
writer.write(executorService);
diff --git a/src/main/java/com/android/tools/r8/R8AssistantCommand.java b/src/main/java/com/android/tools/r8/R8AssistantCommand.java
index f5dba70..f89a8e2 100644
--- a/src/main/java/com/android/tools/r8/R8AssistantCommand.java
+++ b/src/main/java/com/android/tools/r8/R8AssistantCommand.java
@@ -16,6 +16,7 @@
import com.android.tools.r8.origin.SynthesizedOrigin;
import com.android.tools.r8.utils.AndroidApiLevel;
import com.android.tools.r8.utils.AndroidApp;
+import com.android.tools.r8.utils.DescriptorUtils;
import com.android.tools.r8.utils.DumpInputFlags;
import com.android.tools.r8.utils.InternalOptions;
import com.android.tools.r8.utils.InternalOptions.DesugarState;
@@ -30,12 +31,15 @@
@KeepForApi
public class R8AssistantCommand extends BaseCompilerCommand {
+ private final String reflectiveReceiverDescriptor;
+
public R8AssistantCommand(
AndroidApp app,
CompilationMode mode,
ProgramConsumer programConsumer,
int minApiLevel,
- Reporter reporter) {
+ Reporter reporter,
+ String reflectiveReceiverDescriptor) {
super(
app,
mode,
@@ -58,6 +62,7 @@
Collections.emptyList(),
null,
null);
+ this.reflectiveReceiverDescriptor = reflectiveReceiverDescriptor;
}
public static Builder builder(DiagnosticsHandler reporter) {
@@ -83,6 +88,10 @@
return options;
}
+ public String getReflectiveReceiverDescriptor() {
+ return reflectiveReceiverDescriptor;
+ }
+
/**
* This is an experimental API for injecting reflective identification callbacks into dex code.
* This API is subject to change.
@@ -90,6 +99,8 @@
@KeepForApi
public static class Builder extends BaseCompilerCommand.Builder<R8AssistantCommand, Builder> {
+ private String reflectiveReceiverDescriptor;
+
private Builder() {
this(new DiagnosticsHandler() {});
}
@@ -104,6 +115,20 @@
}
}
+ public Builder addReflectiveOperationReceiverInput(ProgramResourceProvider provider) {
+ // This code is simply added to the program input, will be called by the ReflectiveOracle.
+ addProgramResourceProvider(provider);
+ return self();
+ }
+
+ public Builder setReflectiveReceiverClassDescriptor(String descriptor) {
+ if (!DescriptorUtils.isClassDescriptor(descriptor)) {
+ getReporter().error("Not a valid descriptor " + descriptor);
+ }
+ this.reflectiveReceiverDescriptor = descriptor;
+ return self();
+ }
+
@Override
CompilationMode defaultCompilationMode() {
return CompilationMode.RELEASE;
@@ -138,7 +163,8 @@
getMode(),
getProgramConsumer(),
getMinApiLevel(),
- getReporter());
+ getReporter(),
+ reflectiveReceiverDescriptor);
}
}
}
diff --git a/src/main/java/com/android/tools/r8/assistant/ReflectiveInstrumentation.java b/src/main/java/com/android/tools/r8/assistant/ReflectiveInstrumentation.java
index 1df0015..d7f9326 100644
--- a/src/main/java/com/android/tools/r8/assistant/ReflectiveInstrumentation.java
+++ b/src/main/java/com/android/tools/r8/assistant/ReflectiveInstrumentation.java
@@ -3,22 +3,26 @@
// BSD-style license that can be found in the LICENSE file.
package com.android.tools.r8.assistant;
-import com.android.tools.r8.assistant.runtime.ReflectiveOracle;
import com.android.tools.r8.graph.AppInfo;
import com.android.tools.r8.graph.AppView;
import com.android.tools.r8.graph.DexItemFactory;
import com.android.tools.r8.graph.DexMethod;
import com.android.tools.r8.graph.DexProgramClass;
import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.ProgramMethod;
import com.android.tools.r8.ir.code.BasicBlockInstructionListIterator;
import com.android.tools.r8.ir.code.BasicBlockIterator;
import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.IRCodeInstructionListIterator;
import com.android.tools.r8.ir.code.Instruction;
+import com.android.tools.r8.ir.code.InvokeDirect;
import com.android.tools.r8.ir.code.InvokeStatic;
import com.android.tools.r8.ir.code.InvokeVirtual;
+import com.android.tools.r8.ir.code.NewInstance;
+import com.android.tools.r8.ir.code.Return;
+import com.android.tools.r8.ir.code.Value;
import com.android.tools.r8.ir.conversion.PrimaryD8L8IRConverter;
import com.android.tools.r8.ir.optimize.info.OptimizationFeedback;
-import com.android.tools.r8.utils.DescriptorUtils;
import com.android.tools.r8.utils.Timing;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
@@ -29,16 +33,41 @@
private final PrimaryD8L8IRConverter converter;
private final DexItemFactory dexItemFactory;
private final Timing timing;
- private final DexType reflectiveOracleType;
+ private final ReflectiveReferences reflectiveReferences;
public ReflectiveInstrumentation(
AppView<AppInfo> appView, PrimaryD8L8IRConverter converter, Timing timing) {
this.appView = appView;
this.dexItemFactory = appView.dexItemFactory();
+ this.reflectiveReferences = new ReflectiveReferences(dexItemFactory);
this.converter = converter;
this.timing = timing;
- this.reflectiveOracleType =
- dexItemFactory.createType(DescriptorUtils.javaClassToDescriptor(ReflectiveOracle.class));
+ }
+
+ public void updateReflectiveReceiver(String customReflectiveReceiverDescriptor) {
+ ProgramMethod getReceiver =
+ appView.definitionFor(reflectiveReferences.getReceiverMethod).asProgramMethod();
+ IRCode code = getReceiver.buildIR(appView);
+ assert code.streamInstructions().count() == 3;
+ DexType replacementType = dexItemFactory.createType(customReflectiveReceiverDescriptor);
+ IRCodeInstructionListIterator instructionListIterator = code.instructionListIterator();
+ instructionListIterator.next();
+ NewInstance newInstanceReplacement =
+ NewInstance.builder()
+ .setType(replacementType)
+ .setFreshOutValue(code, replacementType.toNonNullTypeElement(appView))
+ .build();
+ Value newInstanceOutValue = newInstanceReplacement.outValue();
+ instructionListIterator.replaceCurrentInstruction(newInstanceReplacement);
+ instructionListIterator.next();
+ DexMethod method = dexItemFactory.createInstanceInitializer(replacementType);
+ InvokeDirect invokeDirect =
+ InvokeDirect.builder().setMethod(method).setArguments(newInstanceOutValue).build();
+ instructionListIterator.replaceCurrentInstruction(invokeDirect);
+ instructionListIterator.next();
+ Return newReturn = Return.builder().setReturnValue(newInstanceOutValue).build();
+ instructionListIterator.replaceCurrentInstruction(newReturn);
+ converter.removeDeadCodeAndFinalizeIR(code, OptimizationFeedback.getIgnoreFeedback(), timing);
}
// TODO(b/394013779): Do this in parallel.
@@ -93,14 +122,14 @@
private DexMethod getMethodReferenceWithClassParameter(String name) {
return dexItemFactory.createMethod(
- reflectiveOracleType,
+ reflectiveReferences.reflectiveOracleType,
dexItemFactory.createProto(dexItemFactory.voidType, dexItemFactory.classType),
name);
}
private DexMethod getMethodReferenceWithClassMethodNameAndParameters(String name) {
return dexItemFactory.createMethod(
- reflectiveOracleType,
+ reflectiveReferences.reflectiveOracleType,
dexItemFactory.createProto(
dexItemFactory.voidType,
dexItemFactory.classType,
diff --git a/src/main/java/com/android/tools/r8/assistant/ReflectiveReferences.java b/src/main/java/com/android/tools/r8/assistant/ReflectiveReferences.java
new file mode 100644
index 0000000..039c2f2
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/assistant/ReflectiveReferences.java
@@ -0,0 +1,30 @@
+// Copyright (c) 2025, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.assistant;
+
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexType;
+
+public class ReflectiveReferences {
+
+ public final DexType reflectiveOracleType;
+ public final DexType reflectiveOperationReceiverType;
+ public final DexMethod getReceiverMethod;
+
+ public ReflectiveReferences(DexItemFactory factory) {
+ this.reflectiveOracleType = factory.createType(getDescriptor("ReflectiveOracle"));
+ this.reflectiveOperationReceiverType =
+ factory.createType(getDescriptor("ReflectiveOperationReceiver"));
+ this.getReceiverMethod =
+ factory.createMethod(
+ reflectiveOracleType,
+ factory.createProto(reflectiveOperationReceiverType),
+ "getReceiver");
+ }
+
+ private static String getDescriptor(String className) {
+ return "Lcom/android/tools/r8/assistant/runtime/" + className + ";";
+ }
+}
diff --git a/src/test/java/com/android/tools/r8/assistant/R8AssistentReflectiveInstrumentationTest.java b/src/test/java/com/android/tools/r8/assistant/R8AssistentReflectiveInstrumentationTest.java
index 3a578c3..15872c0 100644
--- a/src/test/java/com/android/tools/r8/assistant/R8AssistentReflectiveInstrumentationTest.java
+++ b/src/test/java/com/android/tools/r8/assistant/R8AssistentReflectiveInstrumentationTest.java
@@ -13,6 +13,8 @@
import com.android.tools.r8.TestParameters;
import com.android.tools.r8.TestParametersCollection;
import com.android.tools.r8.ToolHelper;
+import com.android.tools.r8.assistant.runtime.ReflectiveOperationReceiver;
+import com.android.tools.r8.assistant.runtime.ReflectiveOracle.Stack;
import com.android.tools.r8.utils.ZipUtils;
import com.android.tools.r8.utils.codeinspector.ClassSubject;
import com.android.tools.r8.utils.codeinspector.CodeInspector;
@@ -59,7 +61,7 @@
@Test
public void testInstrumentation() throws Exception {
testForAssistant()
- .addInnerClasses(getClass())
+ .addProgramClasses(TestClass.class, Foo.class, Bar.class)
.setMinApi(parameters)
.compile()
.inspectOriginalDex(inspector -> inspectStaticCallsInReflectOn(0, inspector))
@@ -70,6 +72,21 @@
"Reflectively got declared method callMe on " + Bar.class.getName());
}
+ @Test
+ public void testInstrumentationWithCustomOracle() throws Exception {
+ testForAssistant()
+ .addProgramClasses(TestClass.class, Foo.class, Bar.class)
+ .addInstrumentationClasses(InstrumentationClass.class)
+ .setCustomReflectiveOperationReceiver(InstrumentationClass.class)
+ .setMinApi(parameters)
+ .compile()
+ .inspectOriginalDex(inspector -> inspectStaticCallsInReflectOn(0, inspector))
+ .inspect(inspector -> inspectStaticCallsInReflectOn(2, inspector))
+ .run(parameters.getRuntime(), TestClass.class)
+ .assertSuccessWithOutputLines(
+ "Custom receiver " + Bar.class.getName(), "Custom receiver method callMe");
+ }
+
private static void inspectStaticCallsInReflectOn(int count, CodeInspector inspector) {
ClassSubject testClass = inspector.clazz(TestClass.class);
assertThat(testClass, isPresent());
@@ -80,6 +97,20 @@
assertEquals(count, codeCount);
}
+ public static class InstrumentationClass implements ReflectiveOperationReceiver {
+
+ @Override
+ public void onClassNewInstance(Stack stack, Class<?> clazz) {
+ System.out.println("Custom receiver " + clazz.getName());
+ }
+
+ @Override
+ public void onClassGetDeclaredMethod(
+ Stack stack, Class<?> clazz, String method, Class<?>... parameters) {
+ System.out.println("Custom receiver method " + method);
+ }
+ }
+
static class TestClass {
public static void main(String[] args) {
reflectOn(System.currentTimeMillis() == 0 ? Foo.class : Bar.class);
diff --git a/src/test/testbase/java/com/android/tools/r8/AssistantTestBuilder.java b/src/test/testbase/java/com/android/tools/r8/AssistantTestBuilder.java
index abecb6b..c007ec3 100644
--- a/src/test/testbase/java/com/android/tools/r8/AssistantTestBuilder.java
+++ b/src/test/testbase/java/com/android/tools/r8/AssistantTestBuilder.java
@@ -3,7 +3,9 @@
// BSD-style license that can be found in the LICENSE file.
package com.android.tools.r8;
+import static com.android.tools.r8.TestBase.descriptor;
import static com.android.tools.r8.TestBase.testForD8;
+import static com.android.tools.r8.TestBase.writeClassesToJar;
import com.android.tools.r8.TestBase.Backend;
import com.android.tools.r8.benchmarks.BenchmarkResults;
@@ -13,7 +15,10 @@
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Path;
+import java.util.ArrayList;
import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
import java.util.function.Consumer;
import java.util.function.Supplier;
@@ -27,6 +32,8 @@
private final D8TestBuilder initialCompileBuilder;
private Path output;
+ private String customReflectiveOperationReceiver = null;
+ private List<Class<?>> customReflectiveOperationInputClasses = new ArrayList<>();
private AssistantTestBuilder(TestState state) {
super(state, R8AssistantCommand.builder(state.getDiagnosticsHandler()), Backend.DEX);
@@ -52,6 +59,11 @@
throw new Unimplemented("No classpath for assistant");
}
+ public AssistantTestBuilder addInstrumentationClasses(Class<?>... classes) {
+ Collections.addAll(customReflectiveOperationInputClasses, classes);
+ return self();
+ }
+
@Override
public AssistantTestBuilder addProgramFiles(Collection<Path> files) {
initialCompileBuilder.addProgramFiles(files);
@@ -71,14 +83,23 @@
if (output == null) {
output = getState().getNewTempFile("assistant_output.jar");
}
+ if (!customReflectiveOperationInputClasses.isEmpty()) {
+ builder.addReflectiveOperationReceiverInput(
+ ArchiveProgramResourceProvider.fromArchive(
+ writeClassesToJar(customReflectiveOperationInputClasses)));
+ }
} catch (IOException e) {
throw new UncheckedIOException(e);
}
+
builder
.addProgramFiles(initialCompilation)
.setOutput(output, OutputMode.DexIndexed)
.setMinApiLevel(getMinApiLevel());
+ if (customReflectiveOperationReceiver != null) {
+ builder.setReflectiveReceiverClassDescriptor(customReflectiveOperationReceiver);
+ }
R8Assistant.run(builder.build());
return new AssistantTestCompileResult(
initialCompilation,
@@ -86,4 +107,16 @@
AndroidApp.builder().addProgramFiles(output).build(),
getMinApiLevel());
}
+
+ public AssistantTestBuilder setCustomReflectiveOperationReceiver(
+ String customReflectiveOperationReceiver) {
+ this.customReflectiveOperationReceiver = customReflectiveOperationReceiver;
+ return self();
+ }
+
+ public AssistantTestBuilder setCustomReflectiveOperationReceiver(
+ Class<?> customReflectiveOperationReceiver) {
+ this.customReflectiveOperationReceiver = descriptor(customReflectiveOperationReceiver);
+ return self();
+ }
}
diff --git a/src/test/testbase/java/com/android/tools/r8/TestBase.java b/src/test/testbase/java/com/android/tools/r8/TestBase.java
index 2f6a150..fffa833 100644
--- a/src/test/testbase/java/com/android/tools/r8/TestBase.java
+++ b/src/test/testbase/java/com/android/tools/r8/TestBase.java
@@ -1429,9 +1429,15 @@
consumer.finished(null);
}
+
protected static Path writeClassesToJar(Class<?>... classes) throws IOException {
+ List<Class<?>> classesCollection = Arrays.asList(classes);
+ return writeClassesToJar(classesCollection);
+ }
+
+ protected static Path writeClassesToJar(List<Class<?>> classesCollection) throws IOException {
Path jar = staticTemp.newFolder().toPath().resolve("classes.jar");
- writeClassesToJar(jar, Arrays.asList(classes));
+ writeClassesToJar(jar, classesCollection);
return jar;
}