Add support for Class.forName instrumentation

Add testing of returned stack elements.

Bug: b/393265921
Change-Id: I128d12071af66ded33e1d1561b15212f06390385
diff --git a/src/assistant/java/com/android/tools/r8/assistant/runtime/ReflectiveOperationReceiver.java b/src/assistant/java/com/android/tools/r8/assistant/runtime/ReflectiveOperationReceiver.java
index 62def20..67229c3 100644
--- a/src/assistant/java/com/android/tools/r8/assistant/runtime/ReflectiveOperationReceiver.java
+++ b/src/assistant/java/com/android/tools/r8/assistant/runtime/ReflectiveOperationReceiver.java
@@ -13,6 +13,8 @@
     return false;
   }
 
+  void onClassForName(Stack stack, String className);
+
   void onClassNewInstance(Stack stack, Class<?> clazz);
 
   void onClassGetDeclaredMethod(Stack stack, Class<?> clazz, String method, Class<?>... parameters);
diff --git a/src/assistant/java/com/android/tools/r8/assistant/runtime/ReflectiveOracle.java b/src/assistant/java/com/android/tools/r8/assistant/runtime/ReflectiveOracle.java
index e653b72..df14c36 100644
--- a/src/assistant/java/com/android/tools/r8/assistant/runtime/ReflectiveOracle.java
+++ b/src/assistant/java/com/android/tools/r8/assistant/runtime/ReflectiveOracle.java
@@ -72,6 +72,10 @@
     getInstance().onClassGetDeclaredMethod(Stack.createStack(), clazz, name, parameters);
   }
 
+  public static void onClassForName(String className) {
+    getInstance().onClassForName(Stack.createStack(), className);
+  }
+
   @KeepForApi
   public static class ReflectiveOperationLogger implements ReflectiveOperationReceiver {
     @Override
@@ -86,6 +90,11 @@
     }
 
     @Override
+    public void onClassForName(Stack stack, String className) {
+      System.out.println("Reflectively called Class.forName on " + className);
+    }
+
+    @Override
     public boolean requiresStackInformation() {
       return true;
     }
diff --git a/src/main/java/com/android/tools/r8/assistant/ReflectiveInstrumentation.java b/src/main/java/com/android/tools/r8/assistant/ReflectiveInstrumentation.java
index d7f9326..62edbb1 100644
--- a/src/main/java/com/android/tools/r8/assistant/ReflectiveInstrumentation.java
+++ b/src/main/java/com/android/tools/r8/assistant/ReflectiveInstrumentation.java
@@ -16,8 +16,8 @@
 import com.android.tools.r8.ir.code.IRCodeInstructionListIterator;
 import com.android.tools.r8.ir.code.Instruction;
 import com.android.tools.r8.ir.code.InvokeDirect;
+import com.android.tools.r8.ir.code.InvokeMethod;
 import com.android.tools.r8.ir.code.InvokeStatic;
-import com.android.tools.r8.ir.code.InvokeVirtual;
 import com.android.tools.r8.ir.code.NewInstance;
 import com.android.tools.r8.ir.code.Return;
 import com.android.tools.r8.ir.code.Value;
@@ -87,19 +87,16 @@
                   blockIterator.next().listIterator();
               while (instructionIterator.hasNext()) {
                 Instruction instruction = instructionIterator.next();
-                if (!instruction.isInvokeVirtual()) {
+                if (!instruction.isInvokeVirtual() && !instruction.isInvokeStatic()) {
                   continue;
                 }
-                InvokeVirtual invokeVirtual = instruction.asInvokeVirtual();
-                DexMethod invokedMethod = invokeVirtual.getInvokedMethod();
+                InvokeMethod invoke = instruction.asInvokeMethod();
+                DexMethod invokedMethod = invoke.getInvokedMethod();
+
                 DexMethod toInstrumentCallTo = instrumentedMethodsAndTargets.get(invokedMethod);
                 if (toInstrumentCallTo != null) {
                   insertCallToMethod(
-                      toInstrumentCallTo,
-                      irCode,
-                      blockIterator,
-                      instructionIterator,
-                      invokeVirtual);
+                      toInstrumentCallTo, irCode, blockIterator, instructionIterator, invoke);
                   changed = true;
                 }
               }
@@ -117,13 +114,23 @@
         dexItemFactory.classMethods.newInstance,
         getMethodReferenceWithClassParameter("onClassNewInstance"),
         dexItemFactory.classMethods.getDeclaredMethod,
-        getMethodReferenceWithClassMethodNameAndParameters("onClassGetDeclaredMethod"));
+        getMethodReferenceWithClassMethodNameAndParameters("onClassGetDeclaredMethod"),
+        dexItemFactory.classMethods.forName,
+        getMethodReferenceWithStringParameter("onClassForName"));
   }
 
   private DexMethod getMethodReferenceWithClassParameter(String name) {
+    return getMethodReferenceWithSingleParameter(name, dexItemFactory.classType);
+  }
+
+  private DexMethod getMethodReferenceWithStringParameter(String name) {
+    return getMethodReferenceWithSingleParameter(name, dexItemFactory.stringType);
+  }
+
+  private DexMethod getMethodReferenceWithSingleParameter(String name, DexType type) {
     return dexItemFactory.createMethod(
         reflectiveReferences.reflectiveOracleType,
-        dexItemFactory.createProto(dexItemFactory.voidType, dexItemFactory.classType),
+        dexItemFactory.createProto(dexItemFactory.voidType, type),
         name);
   }
 
@@ -143,13 +150,13 @@
       IRCode code,
       BasicBlockIterator blockIterator,
       BasicBlockInstructionListIterator instructionIterator,
-      InvokeVirtual invoke) {
+      InvokeMethod invoke) {
     InvokeStatic invokeStatic =
         InvokeStatic.builder()
             .setMethod(method)
             .setArguments(invoke.inValues())
             // Same position so that the stack trace has the correct line number.
-            .setPosition(invoke)
+            .setPosition(invoke.getPosition())
             .build();
     instructionIterator.previous();
     instructionIterator.addPossiblyThrowingInstructionsToPossiblyThrowingBlock(
diff --git a/src/test/java/com/android/tools/r8/assistant/R8AssistentReflectiveInstrumentationTest.java b/src/test/java/com/android/tools/r8/assistant/R8AssistentReflectiveInstrumentationTest.java
index 15872c0..230b7d6 100644
--- a/src/test/java/com/android/tools/r8/assistant/R8AssistentReflectiveInstrumentationTest.java
+++ b/src/test/java/com/android/tools/r8/assistant/R8AssistentReflectiveInstrumentationTest.java
@@ -30,6 +30,7 @@
 
 @RunWith(Parameterized.class)
 public class R8AssistentReflectiveInstrumentationTest extends TestBase {
+
   @Parameter(0)
   public TestParameters parameters;
 
@@ -64,12 +65,13 @@
         .addProgramClasses(TestClass.class, Foo.class, Bar.class)
         .setMinApi(parameters)
         .compile()
-        .inspectOriginalDex(inspector -> inspectStaticCallsInReflectOn(0, inspector))
-        .inspect(inspector -> inspectStaticCallsInReflectOn(2, inspector))
+        .inspectOriginalDex(inspector -> inspectStaticCallsInReflectOn(1, inspector))
+        .inspect(inspector -> inspectStaticCallsInReflectOn(4, inspector))
         .run(parameters.getRuntime(), TestClass.class)
         .assertSuccessWithOutputLines(
             "Reflectively created new instance of " + Bar.class.getName(),
-            "Reflectively got declared method callMe on " + Bar.class.getName());
+            "Reflectively got declared method callMe on " + Bar.class.getName(),
+            "Reflectively called Class.forName on " + Bar.class.getName());
   }
 
   @Test
@@ -80,11 +82,96 @@
         .setCustomReflectiveOperationReceiver(InstrumentationClass.class)
         .setMinApi(parameters)
         .compile()
-        .inspectOriginalDex(inspector -> inspectStaticCallsInReflectOn(0, inspector))
-        .inspect(inspector -> inspectStaticCallsInReflectOn(2, inspector))
+        .inspectOriginalDex(inspector -> inspectStaticCallsInReflectOn(1, inspector))
+        .inspect(inspector -> inspectStaticCallsInReflectOn(4, inspector))
         .run(parameters.getRuntime(), TestClass.class)
         .assertSuccessWithOutputLines(
-            "Custom receiver " + Bar.class.getName(), "Custom receiver method callMe");
+            "Custom receiver " + Bar.class.getName(),
+            "Custom receiver method callMe",
+            "Custom receiver classForName " + Bar.class.getName());
+  }
+
+  @Test
+  public void testStack() throws Exception {
+    testForAssistant()
+        .addProgramClasses(TestClass.class, Foo.class, Bar.class)
+        .addInstrumentationClasses(TestReflectiveOperationReceiverStackHandler.class)
+        .setCustomReflectiveOperationReceiver(
+            descriptor(TestReflectiveOperationReceiverStackHandler.class))
+        .setMinApi(parameters)
+        .compile()
+        .inspectOriginalDex(inspector -> inspectStaticCallsInReflectOn(1, inspector))
+        .inspect(inspector -> inspectStaticCallsInReflectOn(4, inspector))
+        .run(parameters.getRuntime(), TestClass.class)
+        .assertSuccessWithOutputLines("correct", "correct", "correct");
+  }
+
+  // Injected into the app by the R8Assistant.
+  public static class TestReflectiveOperationReceiverStackHandler
+      implements ReflectiveOperationReceiver {
+
+    int lineNumberOfNewInstance = -1;
+
+    @Override
+    public void onClassForName(Stack stack, String className) {
+      if (!className.equals(Bar.class.getName())) {
+        throw new RuntimeException("Wrong class name passed");
+      }
+      int lineNumberOfTopOfStack = getLineNumberOfTopOfStack(stack);
+      if (lineNumberOfTopOfStack != lineNumberOfNewInstance + 2) {
+        throw new RuntimeException("Wrong line number on top of stack " + lineNumberOfTopOfStack);
+      }
+      ensureCorrectStack(stack);
+    }
+
+    private void ensureCorrectStack(Stack stack) {
+      StackTraceElement[] stackTraceElements = stack.getStackTraceElements();
+      if (stackTraceElements.length != 2) {
+        // Only main and reflectOn should be in the stack
+        throw new RuntimeException("Wrong stack hight of " + stackTraceElements.length);
+      }
+      String topOfStack = stack.getStackTraceElements()[0].toString();
+      String secondToTopOfStack = stack.getStackTraceElements()[1].toString();
+      String sourceFile = "R8AssistentReflectiveInstrumentationTest";
+      if (!topOfStack.contains("reflectOn(" + sourceFile)) {
+        throw new RuntimeException("reflectOn must be top of stack, got " + topOfStack);
+      }
+      if (!secondToTopOfStack.contains("main(" + sourceFile)) {
+        throw new RuntimeException("main must be second to top of stack");
+      }
+      System.out.println("correct");
+    }
+
+    private int getLineNumberOfTopOfStack(Stack stack) {
+      return stack.getStackTraceElements()[0].getLineNumber();
+    }
+
+    @Override
+    public void onClassNewInstance(Stack stack, Class<?> clazz) {
+      if (!clazz.equals(Bar.class)) {
+        throw new RuntimeException("Wrong class passed");
+      }
+      ensureCorrectStack(stack);
+      lineNumberOfNewInstance = getLineNumberOfTopOfStack(stack);
+    }
+
+    @Override
+    public void onClassGetDeclaredMethod(
+        Stack stack, Class<?> clazz, String method, Class<?>... parameters) {
+      if (!clazz.equals(Bar.class) || !method.equals("callMe")) {
+        throw new RuntimeException("Wrong method passed");
+      }
+      int lineNumberOfTopOfStack = getLineNumberOfTopOfStack(stack);
+      if (lineNumberOfTopOfStack != lineNumberOfNewInstance + 1) {
+        throw new RuntimeException("Wrong line number on top of stack " + lineNumberOfTopOfStack);
+      }
+      ensureCorrectStack(stack);
+    }
+
+    @Override
+    public boolean requiresStackInformation() {
+      return true;
+    }
   }
 
   private static void inspectStaticCallsInReflectOn(int count, CodeInspector inspector) {
@@ -100,6 +187,11 @@
   public static class InstrumentationClass implements ReflectiveOperationReceiver {
 
     @Override
+    public void onClassForName(Stack stack, String className) {
+      System.out.println("Custom receiver classForName " + className);
+    }
+
+    @Override
     public void onClassNewInstance(Stack stack, Class<?> clazz) {
       System.out.println("Custom receiver " + clazz.getName());
     }
@@ -120,8 +212,11 @@
       try {
         clazz.newInstance();
         clazz.getDeclaredMethod("callMe");
-
-      } catch (InstantiationException | IllegalAccessException | NoSuchMethodException e) {
+        Class.forName(clazz.getName());
+      } catch (InstantiationException
+          | IllegalAccessException
+          | NoSuchMethodException
+          | ClassNotFoundException e) {
         throw new RuntimeException(e);
       }
     }