Introduce registerInvokeSpecial() method on UseRegistry

This also removes CfInvoke#getInvokeType() such that there is now a single way of converting from CfInvoke to Invoke.Type.

Fixes: 202381923
Change-Id: I1ce56c466f62d57669e4fa412b0b4b747fca46e9
diff --git a/src/main/java/com/android/tools/r8/cf/code/CfInvoke.java b/src/main/java/com/android/tools/r8/cf/code/CfInvoke.java
index e9cd796..b6fd5e0 100644
--- a/src/main/java/com/android/tools/r8/cf/code/CfInvoke.java
+++ b/src/main/java/com/android/tools/r8/cf/code/CfInvoke.java
@@ -102,9 +102,15 @@
       NamingLens namingLens,
       LensCodeRewriterUtils rewriter,
       MethodVisitor visitor) {
-    MethodLookupResult lookup =
-        graphLens.lookupMethod(
-            method, context.getReference(), getInvokeType(context, dexItemFactory));
+    Invoke.Type invokeType =
+        Invoke.Type.fromCfOpcode(
+            opcode,
+            method,
+            context,
+            dexItemFactory,
+            graphLens,
+            () -> context.getDefinition().getCode().asCfCode().getOriginalHolder());
+    MethodLookupResult lookup = graphLens.lookupMethod(method, context.getReference(), invokeType);
     DexMethod rewrittenMethod = lookup.getReference();
     String owner = namingLens.lookupInternalName(rewrittenMethod.holder);
     String name = namingLens.lookupName(rewrittenMethod).toString();
@@ -120,50 +126,21 @@
   @Override
   void internalRegisterUse(
       UseRegistry<?> registry, DexClassAndMethod context, ListIterator<CfInstruction> iterator) {
-    Type invokeType = getInvokeType(context, registry.dexItemFactory());
-    switch (invokeType) {
-      case DIRECT:
-        registry.registerInvokeDirect(method);
-        break;
-      case INTERFACE:
+    switch (opcode) {
+      case Opcodes.INVOKEINTERFACE:
         registry.registerInvokeInterface(method);
         break;
-      case STATIC:
+      case Opcodes.INVOKESPECIAL:
+        registry.registerInvokeSpecial(method, itf);
+        break;
+      case Opcodes.INVOKESTATIC:
         registry.registerInvokeStatic(method, itf);
         break;
-      case SUPER:
-        registry.registerInvokeSuper(method);
-        break;
-      case VIRTUAL:
+      case Opcodes.INVOKEVIRTUAL:
         registry.registerInvokeVirtual(method);
         break;
       default:
-        throw new Unreachable("Unexpected invoke type " + invokeType);
-    }
-  }
-
-  // We should avoid interpreting a CF invoke using DEX semantics.
-  @Deprecated
-  private Invoke.Type getInvokeType(DexClassAndMethod context, DexItemFactory dexItemFactory) {
-    switch (opcode) {
-      case Opcodes.INVOKEINTERFACE:
-        return Type.INTERFACE;
-
-      case Opcodes.INVOKEVIRTUAL:
-        return Type.VIRTUAL;
-
-      case Opcodes.INVOKESPECIAL:
-        if (method.isInstanceInitializer(dexItemFactory)
-            || method.getHolderType() == context.getHolderType()) {
-          return Type.DIRECT;
-        }
-        return Type.SUPER;
-
-      case Opcodes.INVOKESTATIC:
-        return Type.STATIC;
-
-      default:
-        throw new Unreachable("unknown CfInvoke opcode " + opcode);
+        throw new Unreachable("Unknown CfInvoke opcode " + opcode);
     }
   }
 
@@ -218,8 +195,7 @@
         }
       case Opcodes.INVOKEVIRTUAL:
         {
-          canonicalMethod =
-              builder.appView.dexItemFactory().polymorphicMethods.canonicalize(method);
+          canonicalMethod = builder.dexItemFactory().polymorphicMethods.canonicalize(method);
           if (canonicalMethod == null) {
             type = Type.VIRTUAL;
             canonicalMethod = method;
@@ -242,10 +218,16 @@
           // direct superinterface of T."
           // Using invoke-super should therefore observe the correct semantics since we cannot
           // target less specific targets (up in the hierarchy).
+          AppView<?> appView = builder.appView;
+          ProgramMethod context = builder.getProgramMethod();
           canonicalMethod = method;
           type =
-              computeInvokeTypeForInvokeSpecial(
-                  builder.appView, method, builder.getProgramMethod(), code.getOriginalHolder());
+              Invoke.Type.fromInvokeSpecial(
+                  method,
+                  context,
+                  appView.dexItemFactory(),
+                  appView.graphLens(),
+                  code::getOriginalHolder);
           break;
         }
       case Opcodes.INVOKESTATIC:
@@ -257,7 +239,8 @@
       default:
         throw new Unreachable("unknown CfInvoke opcode " + opcode);
     }
-    int parameterCount = method.proto.parameters.size();
+
+    int parameterCount = method.getParameters().size();
     if (type != Type.STATIC) {
       parameterCount += 1;
     }
@@ -270,9 +253,17 @@
     }
     builder.addInvoke(
         type, canonicalMethod, callSiteProto, Arrays.asList(types), Arrays.asList(registers), itf);
-    if (!method.proto.returnType.isVoidType()) {
-      builder.addMoveResult(state.push(method.proto.returnType).register);
+    if (!method.getReturnType().isVoidType()) {
+      builder.addMoveResult(state.push(method.getReturnType()).register);
     }
+    assert type
+        == Invoke.Type.fromCfOpcode(
+            opcode,
+            method,
+            builder.getProgramMethod(),
+            builder.dexItemFactory(),
+            builder.appView.graphLens(),
+            code::getOriginalHolder);
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
index 24c2a3a..fdee405 100644
--- a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
+++ b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
@@ -2064,20 +2064,22 @@
             "weakCompareAndSetRelease");
 
     public DexMethod canonicalize(DexMethod invokeProto) {
+      DexMethod result = null;
       if (invokeProto.holder == methodHandleType) {
         if (invokeProto.name == invokeMethodName || invokeProto.name == invokeExactMethodName) {
-          return createMethod(methodHandleType, signature, invokeProto.name);
+          result = createMethod(methodHandleType, signature, invokeProto.name);
         }
       } else if (invokeProto.holder == varHandleType) {
         if (varHandleMethods.contains(invokeProto.name)) {
-          return createMethod(varHandleType, signature, invokeProto.name);
+          result = createMethod(varHandleType, signature, invokeProto.name);
         } else if (varHandleSetMethods.contains(invokeProto.name)) {
-          return createMethod(varHandleType, setSignature, invokeProto.name);
+          result = createMethod(varHandleType, setSignature, invokeProto.name);
         } else if (varHandleCompareAndSetMethods.contains(invokeProto.name)) {
-          return createMethod(varHandleType, compareAndSetSignature, invokeProto.name);
+          result = createMethod(varHandleType, compareAndSetSignature, invokeProto.name);
         }
       }
-      return null;
+      assert (result != null) == isPolymorphicInvoke(invokeProto);
+      return result;
     }
 
     private Set<DexString> createStrings(String... strings) {
@@ -2088,6 +2090,18 @@
       }
       return map.keySet();
     }
+
+    public boolean isPolymorphicInvoke(DexMethod invokeProto) {
+      if (invokeProto.holder == methodHandleType) {
+        return invokeProto.name == invokeMethodName || invokeProto.name == invokeExactMethodName;
+      }
+      if (invokeProto.holder == varHandleType) {
+        return varHandleMethods.contains(invokeProto.name)
+            || varHandleSetMethods.contains(invokeProto.name)
+            || varHandleCompareAndSetMethods.contains(invokeProto.name);
+      }
+      return false;
+    }
   }
 
   public class ProxyMethods {
diff --git a/src/main/java/com/android/tools/r8/graph/UseRegistry.java b/src/main/java/com/android/tools/r8/graph/UseRegistry.java
index 396fa40..39f69b8 100644
--- a/src/main/java/com/android/tools/r8/graph/UseRegistry.java
+++ b/src/main/java/com/android/tools/r8/graph/UseRegistry.java
@@ -3,7 +3,10 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.graph;
 
+import static com.android.tools.r8.graph.GraphLens.getIdentityLens;
+
 import com.android.tools.r8.code.CfOrDexInstruction;
+import com.android.tools.r8.ir.code.Invoke;
 import com.android.tools.r8.utils.TraversalContinuation;
 import java.util.ListIterator;
 
@@ -41,6 +44,11 @@
     return context;
   }
 
+  public final DexClassAndMethod getMethodContext() {
+    assert context.isMethod();
+    return context.asMethod();
+  }
+
   public TraversalContinuation getTraversalContinuation() {
     return continuation;
   }
@@ -51,6 +59,25 @@
 
   public abstract void registerInvokeDirect(DexMethod method);
 
+  public void registerInvokeSpecial(DexMethod method, boolean itf) {
+    registerInvokeSpecial(method);
+  }
+
+  public void registerInvokeSpecial(DexMethod method) {
+    // TODO(b/201984767, b/202381923): This needs to supply the right graph lens and original
+    //  context to produce correct invoke types for invoke-special instructions.
+    DexClassAndMethod context = getMethodContext();
+    Invoke.Type type =
+        Invoke.Type.fromInvokeSpecial(
+            method, context, dexItemFactory(), getIdentityLens(), context::getHolderType);
+    if (type.isDirect()) {
+      registerInvokeDirect(method);
+    } else {
+      assert type.isSuper();
+      registerInvokeSuper(method);
+    }
+  }
+
   public abstract void registerInvokeStatic(DexMethod method);
 
   public abstract void registerInvokeInterface(DexMethod method);
diff --git a/src/main/java/com/android/tools/r8/ir/code/Invoke.java b/src/main/java/com/android/tools/r8/ir/code/Invoke.java
index 27f9d4b..0dd562c 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Invoke.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Invoke.java
@@ -16,18 +16,23 @@
 import com.android.tools.r8.dex.Constants;
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexClassAndMethod;
+import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexItem;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexMethodHandle.MethodHandleType;
 import com.android.tools.r8.graph.DexProto;
 import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.GraphLens;
+import com.android.tools.r8.graph.GraphLens.MethodLookupResult;
 import com.android.tools.r8.ir.analysis.type.Nullability;
 import com.android.tools.r8.ir.analysis.type.TypeElement;
 import com.android.tools.r8.ir.conversion.DexBuilder;
 import com.android.tools.r8.utils.BooleanUtils;
 import java.util.List;
 import java.util.Set;
+import java.util.function.Supplier;
 import org.objectweb.asm.Opcodes;
 
 public abstract class Invoke extends Instruction {
@@ -53,12 +58,77 @@
       this.dexOpcodeRange = dexOpcodeRange;
     }
 
+    public static Type fromCfOpcode(
+        int opcode,
+        DexMethod invokedMethod,
+        DexClassAndMethod context,
+        DexItemFactory dexItemFactory,
+        GraphLens graphLens,
+        Supplier<DexType> originalContextProvider) {
+      switch (opcode) {
+        case Opcodes.INVOKEINTERFACE:
+          return Type.INTERFACE;
+        case Opcodes.INVOKESPECIAL:
+          return fromInvokeSpecial(
+              invokedMethod, context, dexItemFactory, graphLens, originalContextProvider);
+        case Opcodes.INVOKESTATIC:
+          return Type.STATIC;
+        case Opcodes.INVOKEVIRTUAL:
+          return dexItemFactory.polymorphicMethods.isPolymorphicInvoke(invokedMethod)
+              ? Type.POLYMORPHIC
+              : Type.VIRTUAL;
+        default:
+          throw new Unreachable("unknown CfInvoke opcode " + opcode);
+      }
+    }
+
+    public static Type fromInvokeSpecial(
+        DexMethod invokedMethod,
+        DexClassAndMethod context,
+        DexItemFactory dexItemFactory,
+        GraphLens graphLens,
+        Supplier<DexType> originalContextProvider) {
+      if (invokedMethod.isInstanceInitializer(dexItemFactory)) {
+        return Type.DIRECT;
+      }
+
+      DexType originalContext = originalContextProvider.get();
+      if (invokedMethod.getHolderType() != originalContext) {
+        return Type.SUPER;
+      }
+
+      MethodLookupResult lookupResult =
+          graphLens.lookupMethod(invokedMethod, context.getReference(), Type.DIRECT);
+      DexEncodedMethod definition = context.getHolder().lookupMethod(lookupResult.getReference());
+      if (definition == null) {
+        return Type.SUPER;
+      }
+
+      if (context.getHolder().isInterface()) {
+        // On interfaces invoke-special should be mapped to invoke-super if the invoke-special
+        // instruction is used to target a default interface method.
+        if (definition.belongsToVirtualPool()) {
+          return Type.SUPER;
+        }
+      } else {
+        // Due to desugaring of invoke-special instructions that target virtual methods, this should
+        // never target a virtual method.
+        // TODO(b/201984767): Reenable this assert. The assert does not always hold when this method
+        //  is called from the UseRegistry and there is a non-empty graph lens.
+        // assert definition.isPrivate() || lookupResult.getType().isVirtual();
+      }
+
+      return Type.DIRECT;
+    }
+
     public int getCfOpcode() {
       switch (this) {
         case DIRECT:
           return Opcodes.INVOKESPECIAL;
         case INTERFACE:
           return Opcodes.INVOKEINTERFACE;
+        case POLYMORPHIC:
+          return Opcodes.INVOKEVIRTUAL;
         case STATIC:
           return Opcodes.INVOKESTATIC;
         case SUPER:
@@ -67,7 +137,6 @@
           return Opcodes.INVOKEVIRTUAL;
         case NEW_ARRAY:
         case MULTI_NEW_ARRAY:
-        case POLYMORPHIC:
         default:
           throw new Unreachable();
       }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRBuilder.java b/src/main/java/com/android/tools/r8/ir/conversion/IRBuilder.java
index d265b7b..01f1739 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRBuilder.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRBuilder.java
@@ -27,6 +27,7 @@
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexItem;
+import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexMethodHandle;
 import com.android.tools.r8.graph.DexProto;
@@ -484,6 +485,10 @@
     this.basicBlockNumberGenerator = new NumberGenerator();
   }
 
+  public DexItemFactory dexItemFactory() {
+    return appView.dexItemFactory();
+  }
+
   public DexEncodedMethod getMethod() {
     return method.getDefinition();
   }
diff --git a/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialToMissingMethodDeclaredInSuperInterfaceTest.java b/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialToMissingMethodDeclaredInSuperInterfaceTest.java
index 2a9e1c0..8cf0a34 100644
--- a/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialToMissingMethodDeclaredInSuperInterfaceTest.java
+++ b/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialToMissingMethodDeclaredInSuperInterfaceTest.java
@@ -48,11 +48,7 @@
         .addKeepMainRule(Main.class)
         .setMinApi(parameters.getApiLevel())
         .run(parameters.getRuntime(), Main.class)
-        // TODO(b/202381923): invoke-special is not mapped correctly.
-        .applyIf(
-            parameters.canUseDefaultAndStaticInterfaceMethods(),
-            runResult -> runResult.assertFailureWithErrorThatThrows(NoSuchMethodError.class),
-            runResult -> runResult.assertSuccessWithOutputLines("A.foo()"));
+        .assertSuccessWithOutputLines("A.foo()");
   }
 
   private byte[] getClassWithTransformedInvoked() throws IOException {