EnqueuerMockitoAnalysis: Prevent optimization of spied subtypes
When we see "Mockito.spy(concreteObject)", we must allow subtypes of the
object type to be mocked.
Bug: 389166093
Change-Id: Ib7e5e75e4f9f68c42581274eda498a1164a3eb6e
diff --git a/src/main/java/com/android/tools/r8/shaking/EnqueuerMockitoAnalysis.java b/src/main/java/com/android/tools/r8/shaking/EnqueuerMockitoAnalysis.java
index 2c7af5e..8f322f7 100644
--- a/src/main/java/com/android/tools/r8/shaking/EnqueuerMockitoAnalysis.java
+++ b/src/main/java/com/android/tools/r8/shaking/EnqueuerMockitoAnalysis.java
@@ -15,6 +15,7 @@
import com.android.tools.r8.graph.DexType;
import com.android.tools.r8.graph.MethodResolutionResult;
import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.graph.SubtypingInfo;
import com.android.tools.r8.graph.analysis.EnqueuerAnalysisCollection;
import com.android.tools.r8.graph.analysis.FinishedEnqueuerAnalysis;
import com.android.tools.r8.graph.analysis.IrBasedEnqueuerAnalysis;
@@ -24,14 +25,15 @@
import com.android.tools.r8.ir.code.InvokeMethod;
import com.android.tools.r8.ir.code.Value;
import com.android.tools.r8.shaking.KeepInfo.Joiner;
-import com.android.tools.r8.shaking.KeepInfoCollection.MutableKeepInfoCollection;
-import java.util.HashSet;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import java.util.ArrayDeque;
+import java.util.Map;
import java.util.Set;
/** Ensure classes passed to Mockito.mock() and Mockito.spy() are not marked as "final". */
class EnqueuerMockitoAnalysis
implements TraceInvokeEnqueuerAnalysis, IrBasedEnqueuerAnalysis, FinishedEnqueuerAnalysis {
-
private final AppView<? extends AppInfoWithClassHierarchy> appView;
private final Enqueuer enqueuer;
@@ -39,7 +41,8 @@
private final DexString mockString;
private final DexString spyString;
- private final Set<DexProgramClass> mockedProgramClasses = new HashSet<>();
+ private final Set<DexProgramClass> mockedProgramClasses = Sets.newIdentityHashSet();
+ private final Map<DexProgramClass, ProgramMethod> spiedInstanceTypes = Maps.newIdentityHashMap();
public EnqueuerMockitoAnalysis(
AppView<? extends AppInfoWithClassHierarchy> appView, Enqueuer enqueuer) {
@@ -121,14 +124,17 @@
return true;
}
mockedType = classType.toDexType(dexItemFactory);
+ DexProgramClass spiedClass = asProgramClassOrNull(appView.definitionFor(mockedType, context));
+ if (spiedClass != null) {
+ spiedInstanceTypes.putIfAbsent(spiedClass, context);
+ }
}
- keepMockedType(context, mockedType, enqueuer.getKeepInfo());
+ recordMockedType(context, mockedType);
return true;
}
- private void keepMockedType(
- ProgramMethod context, DexType mockedType, MutableKeepInfoCollection keepInfo) {
+ private void recordMockedType(ProgramMethod context, DexType mockedType) {
DexType curType = mockedType;
while (curType != null) {
DexProgramClass programClass = asProgramClassOrNull(appView.definitionFor(curType, context));
@@ -136,29 +142,57 @@
return;
}
- if (curType.isIdenticalTo(mockedType)) {
- // Make sure the type is not made final so that it can still be subclassed by Mockito.
- keepInfo.joinClass(programClass, Joiner::disallowOptimization);
+ if (!mockedProgramClasses.add(programClass)) {
+ return;
}
-
- mockedProgramClasses.add(programClass);
curType = programClass.getSuperType();
}
}
@Override
public void done(Enqueuer enqueuer) {
- for (DexProgramClass programClass : mockedProgramClasses) {
- // disallowOptimization --> prevent method from being marked final.
- // allowCodeReplacement --> do not inline or optimize based on method body.
- programClass.forEachProgramVirtualMethodMatching(
- enqueuer::isMethodLive,
- virtualMethod ->
- enqueuer
- .getKeepInfo()
- .joinMethod(
- virtualMethod,
- joiner -> joiner.disallowOptimization().allowCodeReplacement()));
+ // When Mockity.spy(instance) is used, all subtypes of the given type must be mockable.
+ SubtypingInfo subtypingInfo = enqueuer.getSubtypingInfo();
+ ArrayDeque<DexType> subtypeDeque = new ArrayDeque<>();
+ Set<DexProgramClass> seen = Sets.newIdentityHashSet();
+ for (var entry : spiedInstanceTypes.entrySet()) {
+ DexProgramClass spiedClass = entry.getKey();
+ ProgramMethod context = entry.getValue();
+ if (!seen.add(spiedClass)) {
+ continue;
+ }
+ subtypeDeque.addAll(subtypingInfo.allImmediateSubtypes(spiedClass.getType()));
+ while (!subtypeDeque.isEmpty()) {
+ DexType subtype = subtypeDeque.removeLast();
+ DexProgramClass subClass = asProgramClassOrNull(appView.definitionFor(subtype, context));
+ if (subClass == null || !seen.add(subClass)) {
+ continue;
+ }
+ mockedProgramClasses.add(subClass);
+
+ subtypeDeque.addAll(subtypingInfo.allImmediateSubtypes(subtype));
+ }
}
+
+ for (DexProgramClass mockedClass : mockedProgramClasses) {
+ if (enqueuer.isTypeLive(mockedClass)) {
+ ensureClassIsMockable(mockedClass);
+ }
+ }
+ }
+
+ private void ensureClassIsMockable(DexProgramClass programClass) {
+ // Ensures the type is not made final so that it can still be subclassed.
+ enqueuer.getKeepInfo().joinClass(programClass, Joiner::disallowOptimization);
+
+ // disallowOptimization --> prevent method from being marked final.
+ // allowCodeReplacement --> do not inline or optimize based on method body.
+ programClass.forEachProgramVirtualMethodMatching(
+ enqueuer::isMethodLive,
+ virtualMethod ->
+ enqueuer.getKeepInfo()
+ .joinMethod(
+ virtualMethod,
+ joiner -> joiner.disallowOptimization().allowCodeReplacement()));
}
}
diff --git a/src/test/java/com/android/tools/r8/shaking/reflection/MockitoTest.java b/src/test/java/com/android/tools/r8/shaking/reflection/MockitoTest.java
index 19ddd0a..40d5be0 100644
--- a/src/test/java/com/android/tools/r8/shaking/reflection/MockitoTest.java
+++ b/src/test/java/com/android/tools/r8/shaking/reflection/MockitoTest.java
@@ -15,6 +15,9 @@
import com.android.tools.r8.TestParameters;
import com.android.tools.r8.TestParametersCollection;
import com.android.tools.r8.shaking.reflection.MockitoTest.Helpers.ShouldNotBeMergedImpl;
+import com.android.tools.r8.shaking.reflection.MockitoTest.Helpers.SpyImpl1;
+import com.android.tools.r8.shaking.reflection.MockitoTest.Helpers.SpyImpl2;
+import com.android.tools.r8.shaking.reflection.MockitoTest.Helpers.SpyInterface;
import com.android.tools.r8.utils.codeinspector.InstructionSubject;
import java.io.IOException;
import java.util.Arrays;
@@ -38,7 +41,7 @@
}
private static final List<String> EXPECTED_OUTPUT =
- Arrays.asList("A", "B", "C", "D", "E", "did thing", "not inlined");
+ Arrays.asList("A", "B", "C", "D", "I spy 1", "did thing", "not inlined");
public static class MockitoStub {
public static <T> T mock(Class<T> classToMock) {
@@ -91,10 +94,20 @@
}
}
- public static class E {
+ public interface SpyInterface {
+ void spiedMethod();
+ }
+ public static class SpyImpl1 implements SpyInterface {
@Override
- public String toString() {
- return "E";
+ public void spiedMethod() {
+ System.out.println("I spy 1");
+ }
+ }
+
+ public static class SpyImpl2 implements SpyInterface {
+ @Override
+ public void spiedMethod() {
+ System.out.println("I spy 2");
}
}
@@ -149,9 +162,14 @@
}
@NeverInline
+ private static void spyHelper(SpyInterface inst) {
+ MockitoStub.spy(inst);
+ inst.spiedMethod();
+ }
+
+ @NeverInline
private static void spy3() {
- MockitoStub.spy(new Helpers.E());
- System.out.println(new Helpers.E());
+ spyHelper(System.currentTimeMillis() > 0 ? new SpyImpl1() : new SpyImpl2());
}
@NeverInline
@@ -234,11 +252,16 @@
.method("void", "mockSubclass")
.streamInstructions()
.noneMatch(InstructionSubject::isStaticGet));
+ // Ensure mocked classes are not marked as "final".
inspector.forAllClasses(
clazz -> {
String className = clazz.getOriginalTypeName();
if (!className.endsWith("TestMain") && !className.endsWith("Impl")) {
assertThat(clazz.getOriginalTypeName(), clazz, not(isFinal()));
+ // No mocked methods should be finalized.
+ clazz.forAllMethods(method -> {
+ assertThat(method, not(isFinal()));
+ });
}
});
})