Implement CfInvoke -> InvokePolymorphic conversion

Refactor and extend MethodHandleTestRunner to also test CF frontend to
ensure coverage of InvokePolymorphic.

Also implement ConstMethodType.registerUse() to fix minification and
treeshaking of LDC(MethodType) in DEX backend.

Change-Id: I815e2d9871ed4ebbb91e57e0b70eaebaf2f47020
diff --git a/src/main/java/com/android/tools/r8/cf/code/CfInvoke.java b/src/main/java/com/android/tools/r8/cf/code/CfInvoke.java
index f63a52b..32c7022 100644
--- a/src/main/java/com/android/tools/r8/cf/code/CfInvoke.java
+++ b/src/main/java/com/android/tools/r8/cf/code/CfInvoke.java
@@ -8,6 +8,7 @@
 import com.android.tools.r8.dex.Constants;
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexProto;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.UseRegistry;
 import com.android.tools.r8.ir.code.Invoke;
@@ -93,26 +94,45 @@
   public void buildIR(IRBuilder builder, CfState state, CfSourceCode code)
       throws ApiLevelException {
     Invoke.Type type;
+    DexMethod canonicalMethod;
+    DexProto callSiteProto = null;
     switch (opcode) {
       case Opcodes.INVOKEINTERFACE:
-        type = Type.INTERFACE;
-        break;
-      case Opcodes.INVOKEVIRTUAL:
-        // TODO(mathiasr): Handle InvokePolymorphic
-        type = Type.VIRTUAL;
-        break;
-      case Opcodes.INVOKESPECIAL:
-        if (method.name.toString().equals(Constants.INSTANCE_INITIALIZER_NAME)) {
-          type = Type.DIRECT;
-        } else if (builder.getMethod().holder == method.holder) {
-          type = Type.DIRECT;
-        } else {
-          type = Type.SUPER;
+        {
+          canonicalMethod = method;
+          type = Type.INTERFACE;
+          break;
         }
-        break;
+      case Opcodes.INVOKEVIRTUAL:
+        {
+          canonicalMethod = builder.getFactory().polymorphicMethods.canonicalize(method);
+          if (canonicalMethod == null) {
+            type = Type.VIRTUAL;
+            canonicalMethod = method;
+          } else {
+            type = Type.POLYMORPHIC;
+            callSiteProto = method.proto;
+          }
+          break;
+        }
+      case Opcodes.INVOKESPECIAL:
+        {
+          canonicalMethod = method;
+          if (method.name.toString().equals(Constants.INSTANCE_INITIALIZER_NAME)) {
+            type = Type.DIRECT;
+          } else if (builder.getMethod().holder == method.holder) {
+            type = Type.DIRECT;
+          } else {
+            type = Type.SUPER;
+          }
+          break;
+        }
       case Opcodes.INVOKESTATIC:
-        type = Type.STATIC;
-        break;
+        {
+          canonicalMethod = method;
+          type = Type.STATIC;
+          break;
+        }
       default:
         throw new Unreachable("unknown CfInvoke opcode " + opcode);
     }
@@ -127,8 +147,8 @@
       types[i] = slot.type;
       registers[i] = slot.register;
     }
-    builder.addInvoke(type, method, null, Arrays.asList(types), Arrays.asList(registers), itf);
-    // TODO(mathiasr): For InvokePolymorphic, use correct return type
+    builder.addInvoke(
+        type, canonicalMethod, callSiteProto, Arrays.asList(types), Arrays.asList(registers), itf);
     if (!method.proto.returnType.isVoidType()) {
       builder.addMoveResult(state.push(method.proto.returnType).register);
     }
diff --git a/src/main/java/com/android/tools/r8/code/ConstMethodType.java b/src/main/java/com/android/tools/r8/code/ConstMethodType.java
index aba9e08..6a1bef5 100644
--- a/src/main/java/com/android/tools/r8/code/ConstMethodType.java
+++ b/src/main/java/com/android/tools/r8/code/ConstMethodType.java
@@ -8,6 +8,7 @@
 import com.android.tools.r8.graph.DexProto;
 import com.android.tools.r8.graph.ObjectToOffsetMapping;
 import com.android.tools.r8.graph.OffsetToObjectMapping;
+import com.android.tools.r8.graph.UseRegistry;
 import com.android.tools.r8.ir.conversion.IRBuilder;
 import com.android.tools.r8.naming.ClassNameMapper;
 import java.nio.ShortBuffer;
@@ -56,6 +57,11 @@
   }
 
   @Override
+  public void registerUse(UseRegistry registry) {
+    registry.registerProto(getMethodType());
+  }
+
+  @Override
   public void write(ShortBuffer dest, ObjectToOffsetMapping mapping) {
     int index = BBBB.getOffset(mapping);
     if (index != (index & 0xffff)) {
diff --git a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
index 6b99724..f73da18 100644
--- a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
+++ b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
@@ -125,6 +125,8 @@
   public final DexString getMethodName = createString("getMethod");
   public final DexString getDeclaredMethodName = createString("getDeclaredMethod");
   public final DexString assertionsDisabled = createString("$assertionsDisabled");
+  public final DexString invokeMethodName = createString("invoke");
+  public final DexString invokeExactMethodName = createString("invokeExact");
 
   public final DexString stringDescriptor = createString("Ljava/lang/String;");
   public final DexString stringArrayDescriptor = createString("[Ljava/lang/String;");
@@ -209,6 +211,7 @@
   public final AtomicFieldUpdaterMethods atomicFieldUpdaterMethods =
       new AtomicFieldUpdaterMethods();
   public final Kotlin kotlin;
+  public final PolymorphicMethods polymorphicMethods = new PolymorphicMethods();
 
   // Dex system annotations.
   // See https://source.android.com/devices/tech/dalvik/dex-format.html#system-annotation
@@ -471,6 +474,75 @@
     }
   }
 
+  public class PolymorphicMethods {
+
+    private final DexProto signature = createProto(objectType, objectArrayType);
+    private final DexProto setSignature = createProto(voidType, objectArrayType);
+    private final DexProto compareAndSetSignature = createProto(booleanType, objectArrayType);
+
+    private final Set<DexString> varHandleMethods =
+        createStrings(
+            "compareAndExchange",
+            "compareAndExchangeAcquire",
+            "compareAndExchangeRelease",
+            "get",
+            "getAcquire",
+            "getAndAdd",
+            "getAndAddAcquire",
+            "getAndAddRelease",
+            "getAndBitwiseAnd",
+            "getAndBitwiseAndAcquire",
+            "getAndBitwiseAndRelease",
+            "getAndBitwiseOr",
+            "getAndBitwiseOrAcquire",
+            "getAndBitwiseOrRelease",
+            "getAndBitwiseXor",
+            "getAndBitwiseXorAcquire",
+            "getAndBitwiseXorRelease",
+            "getAndSet",
+            "getAndSetAcquire",
+            "getAndSetRelease",
+            "getOpaque",
+            "getVolatile");
+
+    private final Set<DexString> varHandleSetMethods =
+        createStrings("set", "setOpaque", "setRelease", "setVolatile");
+
+    private final Set<DexString> varHandleCompareAndSetMethods =
+        createStrings(
+            "compareAndSet",
+            "weakCompareAndSet",
+            "weakCompareAndSetAcquire",
+            "weakCompareAndSetPlain",
+            "weakCompareAndSetRelease");
+
+    public DexMethod canonicalize(DexMethod invokeProto) {
+      if (invokeProto.holder == methodHandleType) {
+        if (invokeProto.name == invokeMethodName || invokeProto.name == invokeExactMethodName) {
+          return createMethod(methodHandleType, signature, invokeProto.name);
+        }
+      } else if (invokeProto.holder == varHandleType) {
+        if (varHandleMethods.contains(invokeProto.name)) {
+          return createMethod(varHandleType, signature, invokeProto.name);
+        } else if (varHandleSetMethods.contains(invokeProto.name)) {
+          return createMethod(varHandleType, setSignature, invokeProto.name);
+        } else if (varHandleCompareAndSetMethods.contains(invokeProto.name)) {
+          return createMethod(varHandleType, compareAndSetSignature, invokeProto.name);
+        }
+      }
+      return null;
+    }
+
+    private Set<DexString> createStrings(String... strings) {
+      IdentityHashMap<DexString, DexString> map = new IdentityHashMap<>();
+      for (String string : strings) {
+        DexString dexString = createString(string);
+        map.put(dexString, dexString);
+      }
+      return map.keySet();
+    }
+  }
+
   private static <T extends DexItem> T canonicalize(ConcurrentHashMap<T, T> map, T item) {
     assert item != null;
     assert !DexItemFactory.isInternalSentinel(item);
diff --git a/src/test/java/com/android/tools/r8/cf/MethodHandleTestRunner.java b/src/test/java/com/android/tools/r8/cf/MethodHandleTestRunner.java
index ae3b95e..df57902 100644
--- a/src/test/java/com/android/tools/r8/cf/MethodHandleTestRunner.java
+++ b/src/test/java/com/android/tools/r8/cf/MethodHandleTestRunner.java
@@ -9,9 +9,9 @@
 import com.android.tools.r8.CompilationMode;
 import com.android.tools.r8.DexIndexedConsumer;
 import com.android.tools.r8.ProgramConsumer;
-import com.android.tools.r8.R8;
 import com.android.tools.r8.R8Command;
 import com.android.tools.r8.R8Command.Builder;
+import com.android.tools.r8.TestBase;
 import com.android.tools.r8.ToolHelper;
 import com.android.tools.r8.ToolHelper.DexVm;
 import com.android.tools.r8.ToolHelper.ProcessResult;
@@ -19,45 +19,75 @@
 import com.android.tools.r8.utils.AndroidApiLevel;
 import com.android.tools.r8.utils.DescriptorUtils;
 import java.nio.file.Path;
+import java.util.ArrayList;
 import java.util.Arrays;
-import org.junit.Rule;
+import java.util.List;
+import org.junit.Assume;
 import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
 
-public class MethodHandleTestRunner {
+@RunWith(Parameterized.class)
+public class MethodHandleTestRunner extends TestBase {
   static final Class<?> CLASS = MethodHandleTest.class;
 
-  private boolean ldc = false;
-  private boolean minify = false;
+  enum LookupType {
+    DYNAMIC,
+    CONSTANT,
+  }
 
-  @Rule
-  public TemporaryFolder temp = ToolHelper.getTemporaryFolderForTest();
+  enum MinifyMode {
+    NONE,
+    MINIFY,
+  }
 
-  @Test
-  public void testMethodHandlesLookup() throws Exception {
-    // Run test with dynamic method lookups, i.e. using MethodHandles.lookup().find*()
-    ldc = false;
-    test();
+  enum Frontend {
+    JAR,
+    CF,
+  }
+
+  private CompilationMode compilationMode;
+  private LookupType lookupType;
+  private Frontend frontend;
+  private ProcessResult runInput;
+  private MinifyMode minifyMode;
+
+  @Parameters(name = "{0}_{1}_{2}_{3}")
+  public static List<String[]> data() {
+    List<String[]> res = new ArrayList<>();
+    for (LookupType lookupType : LookupType.values()) {
+      for (Frontend frontend : Frontend.values()) {
+        for (MinifyMode minifyMode : MinifyMode.values()) {
+          if (lookupType == LookupType.DYNAMIC && minifyMode == MinifyMode.MINIFY) {
+            // Skip because we don't keep the members looked up dynamically.
+            continue;
+          }
+          for (CompilationMode compilationMode : CompilationMode.values()) {
+            res.add(
+                new String[] {
+                  lookupType.name(), frontend.name(), minifyMode.name(), compilationMode.name()
+                });
+          }
+        }
+      }
+    }
+    return res;
+  }
+
+  public MethodHandleTestRunner(
+      String lookupType, String frontend, String minifyMode, String compilationMode) {
+    this.lookupType = LookupType.valueOf(lookupType);
+    this.frontend = Frontend.valueOf(frontend);
+    this.minifyMode = MinifyMode.valueOf(minifyMode);
+    this.compilationMode = CompilationMode.valueOf(compilationMode);
   }
 
   @Test
-  public void testLdcMethodHandle() throws Exception {
-    // Run test with LDC methods, i.e. without java.lang.invoke.MethodHandles
-    ldc = true;
-    test();
-  }
-
-  @Test
-  public void testMinify() throws Exception {
-    // Run test with LDC methods, i.e. without java.lang.invoke.MethodHandles
-    ldc = true;
-    ProcessResult runInput = runInput();
-    assertEquals(0, runInput.exitCode);
-    Path outCf = temp.getRoot().toPath().resolve("cf.jar");
-    build(new ClassFileConsumer.ArchiveConsumer(outCf), true);
-    ProcessResult runCf =
-        ToolHelper.runJava(outCf, CLASS.getCanonicalName(), ldc ? "error" : "exception");
-    assertEquals(runInput.toString(), runCf.toString());
+  public void test() throws Exception {
+    runInput();
+    runCf();
+    runDex();
   }
 
   private final Class[] inputClasses = {
@@ -70,25 +100,41 @@
     MethodHandleTest.F.class,
   };
 
-  private void test() throws Exception {
-    ProcessResult runInput = runInput();
-    Path outCf = temp.getRoot().toPath().resolve("cf.jar");
-    build(new ClassFileConsumer.ArchiveConsumer(outCf), false);
-    Path outDex = temp.getRoot().toPath().resolve("dex.zip");
-    build(new DexIndexedConsumer.ArchiveConsumer(outDex), false);
-
-    ProcessResult runCf =
-        ToolHelper.runJava(outCf, CLASS.getCanonicalName(), ldc ? "error" : "exception");
-    assertEquals(runInput.toString(), runCf.toString());
-    // TODO(mathiasr): Once we include a P runtime, change this to "P and above".
-    if (ToolHelper.getDexVm() != DexVm.ART_DEFAULT) {
-      return;
+  private void runInput() throws Exception {
+    Path out = temp.getRoot().toPath().resolve("input.jar");
+    ClassFileConsumer.ArchiveConsumer archiveConsumer = new ClassFileConsumer.ArchiveConsumer(out);
+    for (Class<?> c : inputClasses) {
+      archiveConsumer.accept(
+          getClassAsBytes(c), DescriptorUtils.javaTypeToDescriptor(c.getName()), null);
     }
+    archiveConsumer.finished(null);
+    String expected = lookupType == LookupType.CONSTANT ? "error" : "exception";
+    runInput = ToolHelper.runJava(out, CLASS.getName(), expected);
+    if (runInput.exitCode != 0) {
+      System.out.println(runInput);
+    }
+    assertEquals(0, runInput.exitCode);
+  }
+
+  private void runCf() throws Exception {
+    Path outCf = temp.getRoot().toPath().resolve("cf.jar");
+    build(new ClassFileConsumer.ArchiveConsumer(outCf));
+    String expected = lookupType == LookupType.CONSTANT ? "error" : "exception";
+    ProcessResult runCf = ToolHelper.runJava(outCf, CLASS.getCanonicalName(), expected);
+    assertEquals(runInput.toString(), runCf.toString());
+  }
+
+  private void runDex() throws Exception {
+    // TODO(mathiasr): Once we include a P runtime, change this to "P and above".
+    Assume.assumeTrue(ToolHelper.getDexVm() == DexVm.ART_DEFAULT);
+    Path outDex = temp.getRoot().toPath().resolve("dex.zip");
+    build(new DexIndexedConsumer.ArchiveConsumer(outDex));
+    String expected = lookupType == LookupType.CONSTANT ? "pass" : "exception";
     ProcessResult runDex =
         ToolHelper.runArtRaw(
             outDex.toString(),
             CLASS.getCanonicalName(),
-            cmd -> cmd.appendProgramArgument(ldc ? "pass" : "exception"));
+            cmd -> cmd.appendProgramArgument(expected));
     // Only compare stdout and exitCode since dex2oat prints to stderr.
     if (runInput.exitCode != runDex.exitCode) {
       System.out.println(runDex.stderr);
@@ -97,51 +143,36 @@
     assertEquals(runInput.exitCode, runDex.exitCode);
   }
 
-  private void build(ProgramConsumer programConsumer, boolean minify) throws Exception {
+  private void build(ProgramConsumer programConsumer) throws Exception {
     // MethodHandle.invoke() only supported from Android O
     // ConstMethodHandle only supported from Android P
     AndroidApiLevel apiLevel = AndroidApiLevel.P;
-    Builder cfBuilder =
+    Builder command =
         R8Command.builder()
-            .setMode(CompilationMode.DEBUG)
+            .setMode(compilationMode)
             .addLibraryFiles(ToolHelper.getAndroidJar(apiLevel))
             .setProgramConsumer(programConsumer);
     if (!(programConsumer instanceof ClassFileConsumer)) {
-      cfBuilder.setMinApiLevel(apiLevel.getLevel());
+      command.setMinApiLevel(apiLevel.getLevel());
     }
     for (Class<?> c : inputClasses) {
       byte[] classAsBytes = getClassAsBytes(c);
-      cfBuilder.addClassProgramData(classAsBytes, Origin.unknown());
+      command.addClassProgramData(classAsBytes, Origin.unknown());
     }
-    if (minify) {
-      cfBuilder.addProguardConfiguration(
+    if (minifyMode == MinifyMode.MINIFY) {
+      command.addProguardConfiguration(
           Arrays.asList(
               "-keep public class com.android.tools.r8.cf.MethodHandleTest {",
               "  public static void main(...);",
               "}"),
           Origin.unknown());
     }
-    R8.run(cfBuilder.build());
-  }
-
-  private ProcessResult runInput() throws Exception {
-    Path out = temp.getRoot().toPath().resolve("input.jar");
-    ClassFileConsumer.ArchiveConsumer archiveConsumer = new ClassFileConsumer.ArchiveConsumer(out);
-    for (Class<?> c : inputClasses) {
-      archiveConsumer.accept(
-          getClassAsBytes(c), DescriptorUtils.javaTypeToDescriptor(c.getName()), null);
-    }
-    archiveConsumer.finished(null);
-    ProcessResult runInput = ToolHelper.runJava(out, CLASS.getName(), ldc ? "error" : "exception");
-    if (runInput.exitCode != 0) {
-      System.out.println(runInput);
-    }
-    assertEquals(0, runInput.exitCode);
-    return runInput;
+    ToolHelper.runR8(
+        command.build(), options -> options.enableCfFrontend = frontend == Frontend.CF);
   }
 
   private byte[] getClassAsBytes(Class<?> clazz) throws Exception {
-    if (ldc) {
+    if (lookupType == LookupType.CONSTANT) {
       if (clazz == MethodHandleTest.D.class) {
         return MethodHandleDump.dumpD();
       } else if (clazz == MethodHandleTest.class) {