Fix incorrect mapping from invoke-special in UseRegistry

Bug: 166210854
Bug: 202419103
Change-Id: I2fd57bba997e710e864d76f4166368645f3c64f6
Fixes: 201984767
diff --git a/src/main/java/com/android/tools/r8/cf/code/CfInvoke.java b/src/main/java/com/android/tools/r8/cf/code/CfInvoke.java
index b6fd5e0..2e6564d 100644
--- a/src/main/java/com/android/tools/r8/cf/code/CfInvoke.java
+++ b/src/main/java/com/android/tools/r8/cf/code/CfInvoke.java
@@ -102,14 +102,7 @@
       NamingLens namingLens,
       LensCodeRewriterUtils rewriter,
       MethodVisitor visitor) {
-    Invoke.Type invokeType =
-        Invoke.Type.fromCfOpcode(
-            opcode,
-            method,
-            context,
-            dexItemFactory,
-            graphLens,
-            () -> context.getDefinition().getCode().asCfCode().getOriginalHolder());
+    Invoke.Type invokeType = Invoke.Type.fromCfOpcode(opcode, method, context, appView);
     MethodLookupResult lookup = graphLens.lookupMethod(method, context.getReference(), invokeType);
     DexMethod rewrittenMethod = lookup.getReference();
     String owner = namingLens.lookupInternalName(rewrittenMethod.holder);
@@ -221,13 +214,7 @@
           AppView<?> appView = builder.appView;
           ProgramMethod context = builder.getProgramMethod();
           canonicalMethod = method;
-          type =
-              Invoke.Type.fromInvokeSpecial(
-                  method,
-                  context,
-                  appView.dexItemFactory(),
-                  appView.graphLens(),
-                  code::getOriginalHolder);
+          type = Invoke.Type.fromInvokeSpecial(method, context, appView);
           break;
         }
       case Opcodes.INVOKESTATIC:
@@ -257,13 +244,7 @@
       builder.addMoveResult(state.push(method.getReturnType()).register);
     }
     assert type
-        == Invoke.Type.fromCfOpcode(
-            opcode,
-            method,
-            builder.getProgramMethod(),
-            builder.dexItemFactory(),
-            builder.appView.graphLens(),
-            code::getOriginalHolder);
+        == Invoke.Type.fromCfOpcode(opcode, method, builder.getProgramMethod(), builder.appView);
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/graph/AppView.java b/src/main/java/com/android/tools/r8/graph/AppView.java
index 11201f8..b8177bd 100644
--- a/src/main/java/com/android/tools/r8/graph/AppView.java
+++ b/src/main/java/com/android/tools/r8/graph/AppView.java
@@ -64,7 +64,8 @@
   private AppServices appServices;
   private final DontWarnConfiguration dontWarnConfiguration;
   private final WholeProgramOptimizations wholeProgramOptimizations;
-  private GraphLens graphLens;
+  private GraphLens codeLens = GraphLens.getIdentityLens();
+  private GraphLens graphLens = GraphLens.getIdentityLens();
   private InitClassLens initClassLens;
   private ProguardCompatibilityActions proguardCompatibilityActions;
   private RootSet rootSet;
@@ -119,7 +120,6 @@
     this.appInfo = appInfo;
     this.dontWarnConfiguration = DontWarnConfiguration.create(options().getProguardConfiguration());
     this.wholeProgramOptimizations = wholeProgramOptimizations;
-    this.graphLens = GraphLens.getIdentityLens();
     this.initClassLens = InitClassLens.getThrowingInstance();
     this.rewritePrefix = mapper;
 
@@ -260,7 +260,9 @@
   }
 
   public GraphLens clearCodeRewritings() {
-    return graphLens = graphLens.withCodeRewritingsApplied(dexItemFactory());
+    GraphLens newLens = graphLens.withCodeRewritingsApplied(dexItemFactory());
+    setGraphLens(newLens);
+    return newLens;
   }
 
   public AppServices appServices() {
@@ -431,6 +433,14 @@
     return defaultValue;
   }
 
+  public GraphLens codeLens() {
+    return codeLens;
+  }
+
+  private void setCodeLens(GraphLens codeLens) {
+    this.codeLens = codeLens;
+  }
+
   public GraphLens graphLens() {
     return graphLens;
   }
@@ -439,6 +449,14 @@
   public boolean setGraphLens(GraphLens graphLens) {
     if (graphLens != this.graphLens) {
       this.graphLens = graphLens;
+
+      // TODO(b/202368283): Currently, we always set an applied lens or a clear code rewriting lens
+      //  when the graph lens has been fully applied to all code. Therefore, we implicitly update
+      //  the code lens when these lenses are set. Now that we have an explicit code lens, the clear
+      //  code rewriting lens is redundant and could be removed.
+      if (graphLens.isAppliedLens() || graphLens.isClearCodeRewritingLens()) {
+        setCodeLens(graphLens);
+      }
       return true;
     }
     return false;
diff --git a/src/main/java/com/android/tools/r8/graph/GraphLens.java b/src/main/java/com/android/tools/r8/graph/GraphLens.java
index 3b816d8..fdba79b 100644
--- a/src/main/java/com/android/tools/r8/graph/GraphLens.java
+++ b/src/main/java/com/android/tools/r8/graph/GraphLens.java
@@ -469,6 +469,10 @@
     return false;
   }
 
+  public boolean isClearCodeRewritingLens() {
+    return false;
+  }
+
   public abstract boolean isIdentityLens();
 
   public boolean isMemberRebindingLens() {
@@ -971,6 +975,11 @@
     }
 
     @Override
+    public boolean isClearCodeRewritingLens() {
+      return true;
+    }
+
+    @Override
     public RewrittenPrototypeDescription lookupPrototypeChangesForMethodDefinition(
         DexMethod method) {
       return getIdentityLens().lookupPrototypeChangesForMethodDefinition(method);
diff --git a/src/main/java/com/android/tools/r8/graph/UseRegistry.java b/src/main/java/com/android/tools/r8/graph/UseRegistry.java
index e062b1a..85241b8 100644
--- a/src/main/java/com/android/tools/r8/graph/UseRegistry.java
+++ b/src/main/java/com/android/tools/r8/graph/UseRegistry.java
@@ -62,12 +62,8 @@
   }
 
   public void registerInvokeSpecial(DexMethod method) {
-    // TODO(b/201984767, b/202381923): This needs to supply the right original context to produce
-    //  correct invoke types for invoke-special instructions.
     DexClassAndMethod context = getMethodContext();
-    Invoke.Type type =
-        Invoke.Type.fromInvokeSpecial(
-            method, context, dexItemFactory(), appView.graphLens(), context::getHolderType);
+    Invoke.Type type = Invoke.Type.fromInvokeSpecial(method, context, appView);
     if (type.isDirect()) {
       registerInvokeDirect(method);
     } else {
diff --git a/src/main/java/com/android/tools/r8/graph/classmerging/VerticallyMergedClasses.java b/src/main/java/com/android/tools/r8/graph/classmerging/VerticallyMergedClasses.java
index 3b59401..ca19375 100644
--- a/src/main/java/com/android/tools/r8/graph/classmerging/VerticallyMergedClasses.java
+++ b/src/main/java/com/android/tools/r8/graph/classmerging/VerticallyMergedClasses.java
@@ -17,13 +17,19 @@
 public class VerticallyMergedClasses implements MergedClasses {
 
   private final BidirectionalManyToOneMap<DexType, DexType> mergedClasses;
+  private final BidirectionalManyToOneMap<DexType, DexType> mergedInterfaces;
 
-  public VerticallyMergedClasses(BidirectionalManyToOneMap<DexType, DexType> mergedClasses) {
+  public VerticallyMergedClasses(
+      BidirectionalManyToOneMap<DexType, DexType> mergedClasses,
+      BidirectionalManyToOneMap<DexType, DexType> mergedInterfaces) {
     this.mergedClasses = mergedClasses;
+    this.mergedInterfaces = mergedInterfaces;
   }
 
   public static VerticallyMergedClasses empty() {
-    return new VerticallyMergedClasses(new EmptyBidirectionalOneToOneMap<>());
+    EmptyBidirectionalOneToOneMap<DexType, DexType> emptyMap =
+        new EmptyBidirectionalOneToOneMap<>();
+    return new VerticallyMergedClasses(emptyMap, emptyMap);
   }
 
   @Override
@@ -52,6 +58,10 @@
     return mergedClasses.containsKey(type);
   }
 
+  public boolean hasInterfaceBeenMergedIntoSubtype(DexType type) {
+    return mergedInterfaces.containsKey(type);
+  }
+
   public boolean isEmpty() {
     return mergedClasses.isEmpty();
   }
diff --git a/src/main/java/com/android/tools/r8/ir/code/Invoke.java b/src/main/java/com/android/tools/r8/ir/code/Invoke.java
index 0dd562c..e840ecc 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Invoke.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Invoke.java
@@ -32,7 +32,6 @@
 import com.android.tools.r8.utils.BooleanUtils;
 import java.util.List;
 import java.util.Set;
-import java.util.function.Supplier;
 import org.objectweb.asm.Opcodes;
 
 public abstract class Invoke extends Instruction {
@@ -59,22 +58,16 @@
     }
 
     public static Type fromCfOpcode(
-        int opcode,
-        DexMethod invokedMethod,
-        DexClassAndMethod context,
-        DexItemFactory dexItemFactory,
-        GraphLens graphLens,
-        Supplier<DexType> originalContextProvider) {
+        int opcode, DexMethod invokedMethod, DexClassAndMethod context, AppView<?> appView) {
       switch (opcode) {
         case Opcodes.INVOKEINTERFACE:
           return Type.INTERFACE;
         case Opcodes.INVOKESPECIAL:
-          return fromInvokeSpecial(
-              invokedMethod, context, dexItemFactory, graphLens, originalContextProvider);
+          return fromInvokeSpecial(invokedMethod, context, appView);
         case Opcodes.INVOKESTATIC:
           return Type.STATIC;
         case Opcodes.INVOKEVIRTUAL:
-          return dexItemFactory.polymorphicMethods.isPolymorphicInvoke(invokedMethod)
+          return appView.dexItemFactory().polymorphicMethods.isPolymorphicInvoke(invokedMethod)
               ? Type.POLYMORPHIC
               : Type.VIRTUAL;
         default:
@@ -83,17 +76,16 @@
     }
 
     public static Type fromInvokeSpecial(
-        DexMethod invokedMethod,
-        DexClassAndMethod context,
-        DexItemFactory dexItemFactory,
-        GraphLens graphLens,
-        Supplier<DexType> originalContextProvider) {
-      if (invokedMethod.isInstanceInitializer(dexItemFactory)) {
+        DexMethod invokedMethod, DexClassAndMethod context, AppView<?> appView) {
+      if (invokedMethod.isInstanceInitializer(appView.dexItemFactory())) {
         return Type.DIRECT;
       }
 
-      DexType originalContext = originalContextProvider.get();
-      if (invokedMethod.getHolderType() != originalContext) {
+      GraphLens graphLens = appView.graphLens();
+      GraphLens codeLens = appView.codeLens();
+      DexMethod originalContext =
+          graphLens.getOriginalMethodSignature(context.getReference(), codeLens);
+      if (invokedMethod.getHolderType() != originalContext.getHolderType()) {
         return Type.SUPER;
       }
 
@@ -104,7 +96,26 @@
         return Type.SUPER;
       }
 
-      if (context.getHolder().isInterface()) {
+      // If the definition was moved to the current context from a super class due to vertical class
+      // merging, then this used to be an invoke-super.
+      DexType originalHolderOfDefinition =
+          graphLens.getOriginalMethodSignature(definition.getReference(), codeLens).getHolderType();
+      if (originalHolderOfDefinition != originalContext.getHolderType()) {
+        if (appView.hasVerticallyMergedClasses()
+            && appView
+                .verticallyMergedClasses()
+                .hasBeenMergedIntoSubtype(originalHolderOfDefinition)) {
+          return Type.SUPER;
+        }
+      }
+
+      boolean originalContextIsInterface =
+          context.getHolder().isInterface()
+              || (appView.hasVerticallyMergedClasses()
+                  && appView
+                      .verticallyMergedClasses()
+                      .hasInterfaceBeenMergedIntoSubtype(originalContext.getHolderType()));
+      if (originalContextIsInterface) {
         // On interfaces invoke-special should be mapped to invoke-super if the invoke-special
         // instruction is used to target a default interface method.
         if (definition.belongsToVirtualPool()) {
@@ -113,9 +124,7 @@
       } else {
         // Due to desugaring of invoke-special instructions that target virtual methods, this should
         // never target a virtual method.
-        // TODO(b/201984767): Reenable this assert. The assert does not always hold when this method
-        //  is called from the UseRegistry and there is a non-empty graph lens.
-        // assert definition.isPrivate() || lookupResult.getType().isVirtual();
+        assert definition.isPrivate() || lookupResult.getType().isVirtual();
       }
 
       return Type.DIRECT;
diff --git a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
index 8d38af0..02c89bf 100644
--- a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
@@ -231,6 +231,9 @@
   private final MutableBidirectionalManyToOneMap<DexType, DexType> mergedClasses =
       BidirectionalManyToOneHashMap.newIdentityHashMap();
 
+  private final MutableBidirectionalManyToOneMap<DexType, DexType> mergedInterfaces =
+      BidirectionalManyToOneHashMap.newIdentityHashMap();
+
   // Set of types that must not be merged into their subtype.
   private final Set<DexType> pinnedTypes = Sets.newIdentityHashSet();
 
@@ -701,7 +704,8 @@
     }
     timing.end();
 
-    VerticallyMergedClasses verticallyMergedClasses = new VerticallyMergedClasses(mergedClasses);
+    VerticallyMergedClasses verticallyMergedClasses =
+        new VerticallyMergedClasses(mergedClasses, mergedInterfaces);
     appView.setVerticallyMergedClasses(verticallyMergedClasses);
     if (verticallyMergedClasses.isEmpty()) {
       return null;
@@ -1177,6 +1181,9 @@
       source.clearStaticFields();
       // Step 5: Record merging.
       mergedClasses.put(source.type, target.type);
+      if (source.isInterface()) {
+        mergedInterfaces.put(source.type, target.type);
+      }
       assert !abortMerge;
       assert GenericSignatureCorrectnessHelper.createForVerification(
               appView, GenericSignatureContextBuilder.createForSingleClass(appView, target))
diff --git a/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialForInvokeVirtualTest.java b/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialForInvokeVirtualTest.java
index 2b0324c..41c7da1 100644
--- a/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialForInvokeVirtualTest.java
+++ b/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialForInvokeVirtualTest.java
@@ -53,8 +53,7 @@
         .addKeepMainRule(Main.class)
         .setMinApi(parameters.getApiLevel())
         .run(parameters.getRuntime(), Main.class)
-        // TODO(b/166210854): Fails but should not.
-        .assertFailure();
+        .assertSuccessWithOutput(EXPECTED);
   }
 
   private byte[] getClassBWithTransformedInvoked() throws IOException {
diff --git a/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialForNonDeclaredInvokeVirtualTest.java b/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialForNonDeclaredInvokeVirtualTest.java
index b9156fb..05e6c38 100644
--- a/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialForNonDeclaredInvokeVirtualTest.java
+++ b/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialForNonDeclaredInvokeVirtualTest.java
@@ -53,8 +53,7 @@
         .addKeepMainRule(Main.class)
         .setMinApi(parameters.getApiLevel())
         .run(parameters.getRuntime(), Main.class)
-        // TODO(b/166210854): Fails but should not.
-        .assertFailure();
+        .assertSuccessWithOutput(EXPECTED);
   }
 
   private byte[] getClassCWithTransformedInvoked() throws IOException {
diff --git a/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialInterfaceWithBridgeTest.java b/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialInterfaceWithBridgeTest.java
index 31ec2c6..477532b 100644
--- a/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialInterfaceWithBridgeTest.java
+++ b/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialInterfaceWithBridgeTest.java
@@ -50,8 +50,7 @@
         .addKeepMainRule(Main.class)
         .setMinApi(parameters.getApiLevel())
         .run(parameters.getRuntime(), Main.class)
-        // TODO(b/166210854): Fails but should not.
-        .assertFailure();
+        .assertSuccessWithOutputLines("Hello World!");
   }
 
   private byte[] getClassWithTransformedInvoked() throws IOException {
diff --git a/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialToMissingMethodDeclaredInSuperClassTest.java b/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialToMissingMethodDeclaredInSuperClassTest.java
index fca3a67..411bf18 100644
--- a/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialToMissingMethodDeclaredInSuperClassTest.java
+++ b/src/test/java/com/android/tools/r8/graph/invokespecial/InvokeSpecialToMissingMethodDeclaredInSuperClassTest.java
@@ -48,7 +48,7 @@
         .addKeepMainRule(Main.class)
         .setMinApi(parameters.getApiLevel())
         .run(parameters.getRuntime(), Main.class)
-        .assertFailureWithErrorThatThrows(NoSuchMethodError.class);
+        .assertSuccessWithOutputLines("A.foo()");
   }
 
   private byte[] getClassWithTransformedInvoked() throws IOException {