Do not finalize classes passed to Mockito.mock() / Mockito.spy()

Bug: b/389166093
Change-Id: I13c7406e4c1628336d0a5b199094288bdcd4bc29
diff --git a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
index 0cf6dec..d02a4cd 100644
--- a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
+++ b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
@@ -400,6 +400,8 @@
   // Method names used on MethodHandles.
   public final DexString lookupString = createString("lookup");
   public final DexString privateLookupInString = createString("privateLookupIn");
+  public final DexString mockString = createString("mock");
+  public final DexString spyString = createString("spy");
 
   public final DexType booleanType = createStaticallyKnownType(booleanDescriptor);
   public final DexType byteType = createStaticallyKnownType(byteDescriptor);
@@ -890,6 +892,7 @@
       createStaticallyKnownType(desugarVarHandleDescriptorString);
   public final DexType desugarMethodHandlesLookupType =
       createStaticallyKnownType(desugarMethodHandlesLookupDescriptorString);
+  public final DexType mockitoType = createStaticallyKnownType("Lorg/mockito/Mockito;");
 
   public final ObjectMethodsMembers objectMethodsMembers = new ObjectMethodsMembers();
   public final ServiceLoaderMethods serviceLoaderMethods = new ServiceLoaderMethods();
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index 86724f3..4baf2a5 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -1607,6 +1607,8 @@
     } else if (dexItemFactory.serviceLoaderMethods.isLoadMethod(invokedMethod)) {
       // Handling of application services.
       pendingReflectiveUses.add(context);
+    } else if (EnqueuerMockitoSupport.isReflectiveMockInvoke(dexItemFactory, invokedMethod)) {
+      pendingReflectiveUses.add(context);
     }
     markTypeAsLive(invokedMethod.getHolderType(), context);
     MethodResolutionResult resolutionResult =
@@ -5344,6 +5346,10 @@
       handleServiceLoaderInvocation(method, invoke);
       return;
     }
+    if (EnqueuerMockitoSupport.isReflectiveMockInvoke(appView.dexItemFactory(), invokedMethod)) {
+      EnqueuerMockitoSupport.handleReflectiveMockInvoke(appView, keepInfo, method, invoke);
+      return;
+    }
     if (!isReflectionMethod(dexItemFactory, invokedMethod)) {
       return;
     }
diff --git a/src/main/java/com/android/tools/r8/shaking/EnqueuerMockitoSupport.java b/src/main/java/com/android/tools/r8/shaking/EnqueuerMockitoSupport.java
new file mode 100644
index 0000000..19be910
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/shaking/EnqueuerMockitoSupport.java
@@ -0,0 +1,82 @@
+// Copyright (c) 2025, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.shaking;
+
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexDefinitionSupplier;
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.analysis.type.ArrayTypeElement;
+import com.android.tools.r8.ir.analysis.type.ClassTypeElement;
+import com.android.tools.r8.ir.code.InvokeMethod;
+import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.shaking.KeepInfoCollection.MutableKeepInfoCollection;
+
+class EnqueuerMockitoSupport {
+
+  static boolean isReflectiveMockInvoke(DexItemFactory dexItemFactory, DexMethod invokedMethod) {
+    return invokedMethod.holder.isIdenticalTo(dexItemFactory.mockitoType)
+        && (invokedMethod.getName().isIdenticalTo(dexItemFactory.mockString)
+            || invokedMethod.getName().isIdenticalTo(dexItemFactory.spyString));
+  }
+
+  /** Ensure classes passed to Mockito.mock() and Mockito.spy() are not marked as "final". */
+  static void handleReflectiveMockInvoke(
+      DexDefinitionSupplier appView,
+      MutableKeepInfoCollection keepInfo,
+      ProgramMethod context,
+      InvokeMethod invoke) {
+    DexMethod method = invoke.getInvokedMethod();
+    DexItemFactory dexItemFactory = appView.dexItemFactory();
+
+    DexType mockedType;
+    if (method.getParameter(0).isIdenticalTo(dexItemFactory.classType)) {
+      // Given an explicit const-cast
+      Value classValue = invoke.getFirstArgument();
+      if (!classValue.isConstClass()) {
+        return;
+      }
+      mockedType = classValue.getDefinition().asConstClass().getType();
+    } else if (method.getParameter(method.getArity() - 1).isArrayType()) {
+      // This should always be an empty array of the mocked type.
+      Value arrayValue = invoke.getLastArgument();
+      ArrayTypeElement arrayType = arrayValue.getType().asArrayType();
+      if (arrayType == null) {
+        // Should never happen.
+        return;
+      }
+      ClassTypeElement memberType = arrayType.getMemberType().asClassType();
+      if (memberType == null) {
+        return;
+      }
+      mockedType = memberType.getClassType();
+    } else {
+      // Should be Mockito.spy(Object).
+      if (method.getArity() != 1
+          || !method.getParameter(0).isIdenticalTo(dexItemFactory.objectType)) {
+        return;
+      }
+      Value objectValue = invoke.getFirstArgument();
+      if (objectValue == null || objectValue.isPhi()) {
+        return;
+      }
+      ClassTypeElement classType = objectValue.getType().asClassType();
+      if (classType == null) {
+        return;
+      }
+      mockedType = classType.toDexType(dexItemFactory);
+    }
+
+    DexClass dexClass = appView.definitionFor(mockedType, context);
+    if (dexClass == null || !dexClass.isProgramClass()) {
+      return;
+    }
+
+    // Make sure the type is not made final so that it can still be subclassed by Mockito.
+    keepInfo.joinClass(dexClass.asProgramClass(), joiner -> joiner.disallowOptimization());
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/shaking/reflection/MockitoTest.java b/src/test/java/com/android/tools/r8/shaking/reflection/MockitoTest.java
new file mode 100644
index 0000000..346fa69
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/shaking/reflection/MockitoTest.java
@@ -0,0 +1,217 @@
+// Copyright (c) 2025, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.shaking.reflection;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isFinal;
+import static com.android.tools.r8.utils.codeinspector.Matchers.isInterface;
+import static org.hamcrest.CoreMatchers.not;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.shaking.reflection.MockitoTest.Helpers.ShouldNotBeMergedImpl;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+/** Tests for MockitoStub.mock() can MockitoStub.spy(). */
+@RunWith(Parameterized.class)
+public class MockitoTest extends TestBase {
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withDefaultRuntimes().withMaximumApiLevel().build();
+  }
+
+  private static final List<String> EXPECTED_OUTPUT =
+      Arrays.asList("A", "B", "C", "D", "E", "did thing");
+
+  public static class MockitoStub {
+    public static <T> T mock(Class<T> classToMock) {
+      return null;
+    }
+
+    public static <T> T mock(String name, T... reified) {
+      return null;
+    }
+
+    public static <T> T spy(Class<T> classToMock) {
+      return null;
+    }
+
+    public static <T> T spy(T... reified) {
+      return null;
+    }
+
+    public static <T> T spy(T classToMock) {
+      return null;
+    }
+  }
+
+  public static class Helpers {
+    public static class A {
+      @Override
+      public String toString() {
+        return "A";
+      }
+    }
+
+    public static class B {
+      @Override
+      public String toString() {
+        return "B";
+      }
+    }
+
+    public static class C {
+      @Override
+      public String toString() {
+        return "C";
+      }
+    }
+
+    public static class D {
+      @Override
+      public String toString() {
+        return "D";
+      }
+    }
+
+    public static class E {
+      @Override
+      public String toString() {
+        return "E";
+      }
+    }
+
+    public interface ShouldNotBeMerged {
+      void doThing();
+    }
+
+    public static class ShouldNotBeMergedImpl implements ShouldNotBeMerged {
+      @Override
+      public void doThing() {
+        System.out.println("did thing");
+      }
+    }
+  }
+
+  public static class TestMain {
+
+    @NeverInline
+    private static void mock1() {
+      MockitoStub.mock(Helpers.A.class);
+      System.out.println(new Helpers.A());
+    }
+
+    @NeverInline
+    private static void mock2() {
+      Helpers.B b = MockitoStub.mock("");
+      if (b == null) {
+        System.out.println(new Helpers.B());
+      }
+    }
+
+    @NeverInline
+    private static void spy1() {
+      MockitoStub.spy(Helpers.C.class);
+      System.out.println(new Helpers.C());
+    }
+
+    @NeverInline
+    private static void spy2() {
+      Helpers.D d = MockitoStub.spy();
+      if (d == null) {
+        System.out.println(new Helpers.D());
+      }
+    }
+
+    @NeverInline
+    private static void spy3() {
+      MockitoStub.spy(new Helpers.E());
+      System.out.println(new Helpers.E());
+    }
+
+    @NeverInline
+    private static void mockInterface() {
+      Helpers.ShouldNotBeMerged iface = MockitoStub.mock(Helpers.ShouldNotBeMerged.class);
+      if (iface == null) {
+        new ShouldNotBeMergedImpl().doThing();
+      }
+    }
+
+    public static void main(String[] args) {
+      // Use different methods to ensure Enqueuer.traceInvokeStatic() triggers for each one.
+      mock1();
+      mock2();
+      spy1();
+      spy2();
+      spy3();
+      mockInterface();
+    }
+  }
+
+  private static final String MOCKITO_DESCRIPTOR = "Lorg/mockito/Mockito;";
+
+  private static byte[] rewriteTestMain() throws IOException {
+    return transformer(TestMain.class)
+        .replaceClassDescriptorInMethodInstructions(
+            descriptor(MockitoStub.class), MOCKITO_DESCRIPTOR)
+        .transform();
+  }
+
+  private static byte[] rewriteMockito() throws IOException {
+    return transformer(MockitoStub.class).setClassDescriptor(MOCKITO_DESCRIPTOR).transform();
+  }
+
+  @Test
+  public void testRuntime() throws Exception {
+    byte[] mockitoClassBytes = rewriteMockito();
+    testForRuntime(parameters)
+        .addProgramClassesAndInnerClasses(Helpers.class)
+        .addProgramClassFileData(rewriteTestMain())
+        .addClasspathClassFileData(mockitoClassBytes)
+        .addRunClasspathFiles(buildOnDexRuntime(parameters, mockitoClassBytes))
+        .run(parameters.getRuntime(), TestMain.class)
+        .assertSuccessWithOutputLines(EXPECTED_OUTPUT);
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    byte[] mockitoClassBytes = rewriteMockito();
+    testForR8(parameters.getBackend())
+        .setMinApi(parameters)
+        .addProgramClassesAndInnerClasses(Helpers.class)
+        .addProgramClassFileData(rewriteTestMain())
+        .addClasspathClassFileData(mockitoClassBytes)
+        .enableInliningAnnotations()
+        .addKeepMainRule(TestMain.class)
+        .compile()
+        .inspect(
+            inspector -> {
+              assertThat(inspector.clazz(Helpers.ShouldNotBeMerged.class), isInterface());
+              inspector.forAllClasses(
+                  clazz -> {
+                    String className = clazz.getOriginalTypeName();
+                    if (!className.endsWith("TestMain") && !className.endsWith("Impl")) {
+                      assertThat(clazz.getOriginalTypeName(), clazz, not(isFinal()));
+                    }
+                  });
+            })
+        .addRunClasspathClassFileData(mockitoClassBytes)
+        .run(parameters.getRuntime(), TestMain.class)
+        .assertSuccessWithOutputLines(EXPECTED_OUTPUT);
+  }
+}