EnqueuerMockitoAnalysis: Prevent optimization of spied subtypes

When we see "Mockito.spy(concreteObject)", we must allow subtypes of the
object type to be mocked.

Bug: 389166093
Change-Id: Ib7e5e75e4f9f68c42581274eda498a1164a3eb6e
diff --git a/src/main/java/com/android/tools/r8/shaking/EnqueuerMockitoAnalysis.java b/src/main/java/com/android/tools/r8/shaking/EnqueuerMockitoAnalysis.java
index 2c7af5e..8f322f7 100644
--- a/src/main/java/com/android/tools/r8/shaking/EnqueuerMockitoAnalysis.java
+++ b/src/main/java/com/android/tools/r8/shaking/EnqueuerMockitoAnalysis.java
@@ -15,6 +15,7 @@
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.MethodResolutionResult;
 import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.graph.SubtypingInfo;
 import com.android.tools.r8.graph.analysis.EnqueuerAnalysisCollection;
 import com.android.tools.r8.graph.analysis.FinishedEnqueuerAnalysis;
 import com.android.tools.r8.graph.analysis.IrBasedEnqueuerAnalysis;
@@ -24,14 +25,15 @@
 import com.android.tools.r8.ir.code.InvokeMethod;
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.shaking.KeepInfo.Joiner;
-import com.android.tools.r8.shaking.KeepInfoCollection.MutableKeepInfoCollection;
-import java.util.HashSet;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import java.util.ArrayDeque;
+import java.util.Map;
 import java.util.Set;
 
 /** Ensure classes passed to Mockito.mock() and Mockito.spy() are not marked as "final". */
 class EnqueuerMockitoAnalysis
     implements TraceInvokeEnqueuerAnalysis, IrBasedEnqueuerAnalysis, FinishedEnqueuerAnalysis {
-
   private final AppView<? extends AppInfoWithClassHierarchy> appView;
   private final Enqueuer enqueuer;
 
@@ -39,7 +41,8 @@
   private final DexString mockString;
   private final DexString spyString;
 
-  private final Set<DexProgramClass> mockedProgramClasses = new HashSet<>();
+  private final Set<DexProgramClass> mockedProgramClasses = Sets.newIdentityHashSet();
+  private final Map<DexProgramClass, ProgramMethod> spiedInstanceTypes = Maps.newIdentityHashMap();
 
   public EnqueuerMockitoAnalysis(
       AppView<? extends AppInfoWithClassHierarchy> appView, Enqueuer enqueuer) {
@@ -121,14 +124,17 @@
         return true;
       }
       mockedType = classType.toDexType(dexItemFactory);
+      DexProgramClass spiedClass = asProgramClassOrNull(appView.definitionFor(mockedType, context));
+      if (spiedClass != null) {
+        spiedInstanceTypes.putIfAbsent(spiedClass, context);
+      }
     }
 
-    keepMockedType(context, mockedType, enqueuer.getKeepInfo());
+    recordMockedType(context, mockedType);
     return true;
   }
 
-  private void keepMockedType(
-      ProgramMethod context, DexType mockedType, MutableKeepInfoCollection keepInfo) {
+  private void recordMockedType(ProgramMethod context, DexType mockedType) {
     DexType curType = mockedType;
     while (curType != null) {
       DexProgramClass programClass = asProgramClassOrNull(appView.definitionFor(curType, context));
@@ -136,29 +142,57 @@
         return;
       }
 
-      if (curType.isIdenticalTo(mockedType)) {
-        // Make sure the type is not made final so that it can still be subclassed by Mockito.
-        keepInfo.joinClass(programClass, Joiner::disallowOptimization);
+      if (!mockedProgramClasses.add(programClass)) {
+        return;
       }
-
-      mockedProgramClasses.add(programClass);
       curType = programClass.getSuperType();
     }
   }
 
   @Override
   public void done(Enqueuer enqueuer) {
-    for (DexProgramClass programClass : mockedProgramClasses) {
-      // disallowOptimization --> prevent method from being marked final.
-      // allowCodeReplacement --> do not inline or optimize based on method body.
-      programClass.forEachProgramVirtualMethodMatching(
-          enqueuer::isMethodLive,
-          virtualMethod ->
-              enqueuer
-                  .getKeepInfo()
-                  .joinMethod(
-                      virtualMethod,
-                      joiner -> joiner.disallowOptimization().allowCodeReplacement()));
+    // When Mockity.spy(instance) is used, all subtypes of the given type must be mockable.
+    SubtypingInfo subtypingInfo = enqueuer.getSubtypingInfo();
+    ArrayDeque<DexType> subtypeDeque = new ArrayDeque<>();
+    Set<DexProgramClass> seen = Sets.newIdentityHashSet();
+    for (var entry : spiedInstanceTypes.entrySet()) {
+      DexProgramClass spiedClass = entry.getKey();
+      ProgramMethod context = entry.getValue();
+      if (!seen.add(spiedClass)) {
+        continue;
+      }
+      subtypeDeque.addAll(subtypingInfo.allImmediateSubtypes(spiedClass.getType()));
+      while (!subtypeDeque.isEmpty()) {
+        DexType subtype = subtypeDeque.removeLast();
+        DexProgramClass subClass = asProgramClassOrNull(appView.definitionFor(subtype, context));
+        if (subClass == null || !seen.add(subClass)) {
+          continue;
+        }
+        mockedProgramClasses.add(subClass);
+
+        subtypeDeque.addAll(subtypingInfo.allImmediateSubtypes(subtype));
+      }
     }
+
+    for (DexProgramClass mockedClass : mockedProgramClasses) {
+      if (enqueuer.isTypeLive(mockedClass)) {
+        ensureClassIsMockable(mockedClass);
+      }
+    }
+  }
+
+  private void ensureClassIsMockable(DexProgramClass programClass) {
+    // Ensures the type is not made final so that it can still be subclassed.
+    enqueuer.getKeepInfo().joinClass(programClass, Joiner::disallowOptimization);
+
+    // disallowOptimization --> prevent method from being marked final.
+    // allowCodeReplacement --> do not inline or optimize based on method body.
+    programClass.forEachProgramVirtualMethodMatching(
+        enqueuer::isMethodLive,
+        virtualMethod ->
+            enqueuer.getKeepInfo()
+                .joinMethod(
+                    virtualMethod,
+                    joiner -> joiner.disallowOptimization().allowCodeReplacement()));
   }
 }
diff --git a/src/test/java/com/android/tools/r8/shaking/reflection/MockitoTest.java b/src/test/java/com/android/tools/r8/shaking/reflection/MockitoTest.java
index 19ddd0a..40d5be0 100644
--- a/src/test/java/com/android/tools/r8/shaking/reflection/MockitoTest.java
+++ b/src/test/java/com/android/tools/r8/shaking/reflection/MockitoTest.java
@@ -15,6 +15,9 @@
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.shaking.reflection.MockitoTest.Helpers.ShouldNotBeMergedImpl;
+import com.android.tools.r8.shaking.reflection.MockitoTest.Helpers.SpyImpl1;
+import com.android.tools.r8.shaking.reflection.MockitoTest.Helpers.SpyImpl2;
+import com.android.tools.r8.shaking.reflection.MockitoTest.Helpers.SpyInterface;
 import com.android.tools.r8.utils.codeinspector.InstructionSubject;
 import java.io.IOException;
 import java.util.Arrays;
@@ -38,7 +41,7 @@
   }
 
   private static final List<String> EXPECTED_OUTPUT =
-      Arrays.asList("A", "B", "C", "D", "E", "did thing", "not inlined");
+      Arrays.asList("A", "B", "C", "D", "I spy 1", "did thing", "not inlined");
 
   public static class MockitoStub {
     public static <T> T mock(Class<T> classToMock) {
@@ -91,10 +94,20 @@
       }
     }
 
-    public static class E {
+    public interface SpyInterface {
+      void spiedMethod();
+    }
+    public static class SpyImpl1 implements SpyInterface {
       @Override
-      public String toString() {
-        return "E";
+      public void spiedMethod() {
+        System.out.println("I spy 1");
+      }
+    }
+
+    public static class SpyImpl2 implements SpyInterface {
+      @Override
+      public void spiedMethod() {
+        System.out.println("I spy 2");
       }
     }
 
@@ -149,9 +162,14 @@
     }
 
     @NeverInline
+    private static void spyHelper(SpyInterface inst) {
+      MockitoStub.spy(inst);
+      inst.spiedMethod();
+    }
+
+    @NeverInline
     private static void spy3() {
-      MockitoStub.spy(new Helpers.E());
-      System.out.println(new Helpers.E());
+      spyHelper(System.currentTimeMillis() > 0 ? new SpyImpl1() : new SpyImpl2());
     }
 
     @NeverInline
@@ -234,11 +252,16 @@
                       .method("void", "mockSubclass")
                       .streamInstructions()
                       .noneMatch(InstructionSubject::isStaticGet));
+              // Ensure mocked classes are not marked as "final".
               inspector.forAllClasses(
                   clazz -> {
                     String className = clazz.getOriginalTypeName();
                     if (!className.endsWith("TestMain") && !className.endsWith("Impl")) {
                       assertThat(clazz.getOriginalTypeName(), clazz, not(isFinal()));
+                      // No mocked methods should be finalized.
+                      clazz.forAllMethods(method -> {
+                        assertThat(method, not(isFinal()));
+                      });
                     }
                   });
             })