Do not finalize classes passed to Mockito.mock() / Mockito.spy()
Bug: b/389166093
Change-Id: I13c7406e4c1628336d0a5b199094288bdcd4bc29
diff --git a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
index 0cf6dec..d02a4cd 100644
--- a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
+++ b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
@@ -400,6 +400,8 @@
// Method names used on MethodHandles.
public final DexString lookupString = createString("lookup");
public final DexString privateLookupInString = createString("privateLookupIn");
+ public final DexString mockString = createString("mock");
+ public final DexString spyString = createString("spy");
public final DexType booleanType = createStaticallyKnownType(booleanDescriptor);
public final DexType byteType = createStaticallyKnownType(byteDescriptor);
@@ -890,6 +892,7 @@
createStaticallyKnownType(desugarVarHandleDescriptorString);
public final DexType desugarMethodHandlesLookupType =
createStaticallyKnownType(desugarMethodHandlesLookupDescriptorString);
+ public final DexType mockitoType = createStaticallyKnownType("Lorg/mockito/Mockito;");
public final ObjectMethodsMembers objectMethodsMembers = new ObjectMethodsMembers();
public final ServiceLoaderMethods serviceLoaderMethods = new ServiceLoaderMethods();
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index 86724f3..4baf2a5 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -1607,6 +1607,8 @@
} else if (dexItemFactory.serviceLoaderMethods.isLoadMethod(invokedMethod)) {
// Handling of application services.
pendingReflectiveUses.add(context);
+ } else if (EnqueuerMockitoSupport.isReflectiveMockInvoke(dexItemFactory, invokedMethod)) {
+ pendingReflectiveUses.add(context);
}
markTypeAsLive(invokedMethod.getHolderType(), context);
MethodResolutionResult resolutionResult =
@@ -5344,6 +5346,10 @@
handleServiceLoaderInvocation(method, invoke);
return;
}
+ if (EnqueuerMockitoSupport.isReflectiveMockInvoke(appView.dexItemFactory(), invokedMethod)) {
+ EnqueuerMockitoSupport.handleReflectiveMockInvoke(appView, keepInfo, method, invoke);
+ return;
+ }
if (!isReflectionMethod(dexItemFactory, invokedMethod)) {
return;
}
diff --git a/src/main/java/com/android/tools/r8/shaking/EnqueuerMockitoSupport.java b/src/main/java/com/android/tools/r8/shaking/EnqueuerMockitoSupport.java
new file mode 100644
index 0000000..19be910
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/shaking/EnqueuerMockitoSupport.java
@@ -0,0 +1,82 @@
+// Copyright (c) 2025, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.shaking;
+
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexDefinitionSupplier;
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.analysis.type.ArrayTypeElement;
+import com.android.tools.r8.ir.analysis.type.ClassTypeElement;
+import com.android.tools.r8.ir.code.InvokeMethod;
+import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.shaking.KeepInfoCollection.MutableKeepInfoCollection;
+
+class EnqueuerMockitoSupport {
+
+ static boolean isReflectiveMockInvoke(DexItemFactory dexItemFactory, DexMethod invokedMethod) {
+ return invokedMethod.holder.isIdenticalTo(dexItemFactory.mockitoType)
+ && (invokedMethod.getName().isIdenticalTo(dexItemFactory.mockString)
+ || invokedMethod.getName().isIdenticalTo(dexItemFactory.spyString));
+ }
+
+ /** Ensure classes passed to Mockito.mock() and Mockito.spy() are not marked as "final". */
+ static void handleReflectiveMockInvoke(
+ DexDefinitionSupplier appView,
+ MutableKeepInfoCollection keepInfo,
+ ProgramMethod context,
+ InvokeMethod invoke) {
+ DexMethod method = invoke.getInvokedMethod();
+ DexItemFactory dexItemFactory = appView.dexItemFactory();
+
+ DexType mockedType;
+ if (method.getParameter(0).isIdenticalTo(dexItemFactory.classType)) {
+ // Given an explicit const-cast
+ Value classValue = invoke.getFirstArgument();
+ if (!classValue.isConstClass()) {
+ return;
+ }
+ mockedType = classValue.getDefinition().asConstClass().getType();
+ } else if (method.getParameter(method.getArity() - 1).isArrayType()) {
+ // This should always be an empty array of the mocked type.
+ Value arrayValue = invoke.getLastArgument();
+ ArrayTypeElement arrayType = arrayValue.getType().asArrayType();
+ if (arrayType == null) {
+ // Should never happen.
+ return;
+ }
+ ClassTypeElement memberType = arrayType.getMemberType().asClassType();
+ if (memberType == null) {
+ return;
+ }
+ mockedType = memberType.getClassType();
+ } else {
+ // Should be Mockito.spy(Object).
+ if (method.getArity() != 1
+ || !method.getParameter(0).isIdenticalTo(dexItemFactory.objectType)) {
+ return;
+ }
+ Value objectValue = invoke.getFirstArgument();
+ if (objectValue == null || objectValue.isPhi()) {
+ return;
+ }
+ ClassTypeElement classType = objectValue.getType().asClassType();
+ if (classType == null) {
+ return;
+ }
+ mockedType = classType.toDexType(dexItemFactory);
+ }
+
+ DexClass dexClass = appView.definitionFor(mockedType, context);
+ if (dexClass == null || !dexClass.isProgramClass()) {
+ return;
+ }
+
+ // Make sure the type is not made final so that it can still be subclassed by Mockito.
+ keepInfo.joinClass(dexClass.asProgramClass(), joiner -> joiner.disallowOptimization());
+ }
+}
diff --git a/src/test/java/com/android/tools/r8/shaking/reflection/MockitoTest.java b/src/test/java/com/android/tools/r8/shaking/reflection/MockitoTest.java
new file mode 100644
index 0000000..346fa69
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/shaking/reflection/MockitoTest.java
@@ -0,0 +1,217 @@
+// Copyright (c) 2025, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.shaking.reflection;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isFinal;
+import static com.android.tools.r8.utils.codeinspector.Matchers.isInterface;
+import static org.hamcrest.CoreMatchers.not;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.shaking.reflection.MockitoTest.Helpers.ShouldNotBeMergedImpl;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+/** Tests for MockitoStub.mock() can MockitoStub.spy(). */
+@RunWith(Parameterized.class)
+public class MockitoTest extends TestBase {
+
+ @Parameter(0)
+ public TestParameters parameters;
+
+ @Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withDefaultRuntimes().withMaximumApiLevel().build();
+ }
+
+ private static final List<String> EXPECTED_OUTPUT =
+ Arrays.asList("A", "B", "C", "D", "E", "did thing");
+
+ public static class MockitoStub {
+ public static <T> T mock(Class<T> classToMock) {
+ return null;
+ }
+
+ public static <T> T mock(String name, T... reified) {
+ return null;
+ }
+
+ public static <T> T spy(Class<T> classToMock) {
+ return null;
+ }
+
+ public static <T> T spy(T... reified) {
+ return null;
+ }
+
+ public static <T> T spy(T classToMock) {
+ return null;
+ }
+ }
+
+ public static class Helpers {
+ public static class A {
+ @Override
+ public String toString() {
+ return "A";
+ }
+ }
+
+ public static class B {
+ @Override
+ public String toString() {
+ return "B";
+ }
+ }
+
+ public static class C {
+ @Override
+ public String toString() {
+ return "C";
+ }
+ }
+
+ public static class D {
+ @Override
+ public String toString() {
+ return "D";
+ }
+ }
+
+ public static class E {
+ @Override
+ public String toString() {
+ return "E";
+ }
+ }
+
+ public interface ShouldNotBeMerged {
+ void doThing();
+ }
+
+ public static class ShouldNotBeMergedImpl implements ShouldNotBeMerged {
+ @Override
+ public void doThing() {
+ System.out.println("did thing");
+ }
+ }
+ }
+
+ public static class TestMain {
+
+ @NeverInline
+ private static void mock1() {
+ MockitoStub.mock(Helpers.A.class);
+ System.out.println(new Helpers.A());
+ }
+
+ @NeverInline
+ private static void mock2() {
+ Helpers.B b = MockitoStub.mock("");
+ if (b == null) {
+ System.out.println(new Helpers.B());
+ }
+ }
+
+ @NeverInline
+ private static void spy1() {
+ MockitoStub.spy(Helpers.C.class);
+ System.out.println(new Helpers.C());
+ }
+
+ @NeverInline
+ private static void spy2() {
+ Helpers.D d = MockitoStub.spy();
+ if (d == null) {
+ System.out.println(new Helpers.D());
+ }
+ }
+
+ @NeverInline
+ private static void spy3() {
+ MockitoStub.spy(new Helpers.E());
+ System.out.println(new Helpers.E());
+ }
+
+ @NeverInline
+ private static void mockInterface() {
+ Helpers.ShouldNotBeMerged iface = MockitoStub.mock(Helpers.ShouldNotBeMerged.class);
+ if (iface == null) {
+ new ShouldNotBeMergedImpl().doThing();
+ }
+ }
+
+ public static void main(String[] args) {
+ // Use different methods to ensure Enqueuer.traceInvokeStatic() triggers for each one.
+ mock1();
+ mock2();
+ spy1();
+ spy2();
+ spy3();
+ mockInterface();
+ }
+ }
+
+ private static final String MOCKITO_DESCRIPTOR = "Lorg/mockito/Mockito;";
+
+ private static byte[] rewriteTestMain() throws IOException {
+ return transformer(TestMain.class)
+ .replaceClassDescriptorInMethodInstructions(
+ descriptor(MockitoStub.class), MOCKITO_DESCRIPTOR)
+ .transform();
+ }
+
+ private static byte[] rewriteMockito() throws IOException {
+ return transformer(MockitoStub.class).setClassDescriptor(MOCKITO_DESCRIPTOR).transform();
+ }
+
+ @Test
+ public void testRuntime() throws Exception {
+ byte[] mockitoClassBytes = rewriteMockito();
+ testForRuntime(parameters)
+ .addProgramClassesAndInnerClasses(Helpers.class)
+ .addProgramClassFileData(rewriteTestMain())
+ .addClasspathClassFileData(mockitoClassBytes)
+ .addRunClasspathFiles(buildOnDexRuntime(parameters, mockitoClassBytes))
+ .run(parameters.getRuntime(), TestMain.class)
+ .assertSuccessWithOutputLines(EXPECTED_OUTPUT);
+ }
+
+ @Test
+ public void testR8() throws Exception {
+ byte[] mockitoClassBytes = rewriteMockito();
+ testForR8(parameters.getBackend())
+ .setMinApi(parameters)
+ .addProgramClassesAndInnerClasses(Helpers.class)
+ .addProgramClassFileData(rewriteTestMain())
+ .addClasspathClassFileData(mockitoClassBytes)
+ .enableInliningAnnotations()
+ .addKeepMainRule(TestMain.class)
+ .compile()
+ .inspect(
+ inspector -> {
+ assertThat(inspector.clazz(Helpers.ShouldNotBeMerged.class), isInterface());
+ inspector.forAllClasses(
+ clazz -> {
+ String className = clazz.getOriginalTypeName();
+ if (!className.endsWith("TestMain") && !className.endsWith("Impl")) {
+ assertThat(clazz.getOriginalTypeName(), clazz, not(isFinal()));
+ }
+ });
+ })
+ .addRunClasspathClassFileData(mockitoClassBytes)
+ .run(parameters.getRuntime(), TestMain.class)
+ .assertSuccessWithOutputLines(EXPECTED_OUTPUT);
+ }
+}