Fix inliner downcast insertion for invoke-interface instructions

Bug: 199561570
Change-Id: I7160645b26bb9a3240ad1cc5fef6be905e05c227
diff --git a/src/main/java/com/android/tools/r8/graph/DexType.java b/src/main/java/com/android/tools/r8/graph/DexType.java
index a38770a..70b9b1d 100644
--- a/src/main/java/com/android/tools/r8/graph/DexType.java
+++ b/src/main/java/com/android/tools/r8/graph/DexType.java
@@ -7,6 +7,8 @@
 
 import com.android.tools.r8.dex.IndexedItemCollection;
 import com.android.tools.r8.errors.Unreachable;
+import com.android.tools.r8.ir.analysis.type.Nullability;
+import com.android.tools.r8.ir.analysis.type.TypeElement;
 import com.android.tools.r8.references.ClassReference;
 import com.android.tools.r8.references.Reference;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
@@ -52,6 +54,10 @@
     return Reference.classFromDescriptor(toDescriptorString());
   }
 
+  public TypeElement toTypeElement(AppView<?> appView) {
+    return TypeElement.fromDexType(this, Nullability.maybeNull(), appView);
+  }
+
   @Override
   public int compareTo(DexReference other) {
     if (other.isDexType()) {
diff --git a/src/main/java/com/android/tools/r8/ir/code/BasicBlockInstructionListIterator.java b/src/main/java/com/android/tools/r8/ir/code/BasicBlockInstructionListIterator.java
index 7bd9cf9..8c8359c 100644
--- a/src/main/java/com/android/tools/r8/ir/code/BasicBlockInstructionListIterator.java
+++ b/src/main/java/com/android/tools/r8/ir/code/BasicBlockInstructionListIterator.java
@@ -739,7 +739,7 @@
       IRCode inlinee,
       ListIterator<BasicBlock> blocksIterator,
       Set<BasicBlock> blocksToRemove,
-      DexType downcast) {
+      DexProgramClass downcast) {
     assert blocksToRemove != null;
     ProgramMethod callerContext = code.context();
     ProgramMethod calleeContext = inlinee.context();
@@ -783,9 +783,9 @@
       //  instruction if the program still type checks without the cast.
       Value receiver = invoke.inValues().get(0);
       TypeElement castTypeLattice =
-          TypeElement.fromDexType(downcast, receiver.getType().nullability(), appView);
+          TypeElement.fromDexType(downcast.getType(), receiver.getType().nullability(), appView);
       CheckCast castInstruction =
-          new CheckCast(code.createValue(castTypeLattice), receiver, downcast);
+          new CheckCast(code.createValue(castTypeLattice), receiver, downcast.getType());
       castInstruction.setPosition(invoke.getPosition());
 
       // Splice in the check cast operation.
diff --git a/src/main/java/com/android/tools/r8/ir/code/IRCodeInstructionListIterator.java b/src/main/java/com/android/tools/r8/ir/code/IRCodeInstructionListIterator.java
index 4e223924..f9bbcc8 100644
--- a/src/main/java/com/android/tools/r8/ir/code/IRCodeInstructionListIterator.java
+++ b/src/main/java/com/android/tools/r8/ir/code/IRCodeInstructionListIterator.java
@@ -9,6 +9,7 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DebugLocalInfo;
 import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexString;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ProgramMethod;
@@ -129,7 +130,7 @@
       IRCode inlinee,
       ListIterator<BasicBlock> blockIterator,
       Set<BasicBlock> blocksToRemove,
-      DexType downcast) {
+      DexProgramClass downcast) {
     throw new Unimplemented();
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/code/InstructionListIterator.java b/src/main/java/com/android/tools/r8/ir/code/InstructionListIterator.java
index 14a7d45..d08413d 100644
--- a/src/main/java/com/android/tools/r8/ir/code/InstructionListIterator.java
+++ b/src/main/java/com/android/tools/r8/ir/code/InstructionListIterator.java
@@ -9,6 +9,7 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DebugLocalInfo;
 import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexString;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ProgramMethod;
@@ -254,9 +255,9 @@
       IRCode inlinee,
       ListIterator<BasicBlock> blockIterator,
       Set<BasicBlock> blocksToRemove,
-      DexType downcast);
+      DexProgramClass downcast);
 
-  /** See {@link #inlineInvoke(AppView, IRCode, IRCode, ListIterator, Set, DexType)}. */
+  /** See {@link #inlineInvoke(AppView, IRCode, IRCode, ListIterator, Set, DexProgramClass)}. */
   default BasicBlock inlineInvoke(AppView<?> appView, IRCode code, IRCode inlinee) {
     Set<BasicBlock> blocksToRemove = Sets.newIdentityHashSet();
     BasicBlock result = inlineInvoke(appView, code, inlinee, null, blocksToRemove, null);
diff --git a/src/main/java/com/android/tools/r8/ir/code/LinearFlowInstructionListIterator.java b/src/main/java/com/android/tools/r8/ir/code/LinearFlowInstructionListIterator.java
index 3e61866..b5aa980 100644
--- a/src/main/java/com/android/tools/r8/ir/code/LinearFlowInstructionListIterator.java
+++ b/src/main/java/com/android/tools/r8/ir/code/LinearFlowInstructionListIterator.java
@@ -8,6 +8,7 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DebugLocalInfo;
 import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexString;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ProgramMethod;
@@ -146,7 +147,7 @@
       IRCode inlinee,
       ListIterator<BasicBlock> blockIterator,
       Set<BasicBlock> blocksToRemove,
-      DexType downcast) {
+      DexProgramClass downcast) {
     return currentBlockIterator.inlineInvoke(
         appView, code, inlinee, blockIterator, blocksToRemove, downcast);
   }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java b/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java
index 8e66601..6f6fa12 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java
@@ -19,6 +19,7 @@
 import com.android.tools.r8.graph.ResolutionResult.SingleResolutionResult;
 import com.android.tools.r8.ir.analysis.ClassInitializationAnalysis;
 import com.android.tools.r8.ir.analysis.inlining.SimpleInliningConstraint;
+import com.android.tools.r8.ir.analysis.type.ClassTypeElement;
 import com.android.tools.r8.ir.code.BasicBlock;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.InstancePut;
@@ -692,7 +693,8 @@
   }
 
   @Override
-  public DexType getReceiverTypeIfKnown(InvokeMethod invoke) {
-    return null; // Maybe improve later.
+  public ClassTypeElement getReceiverTypeOrDefault(
+      InvokeMethod invoke, ClassTypeElement defaultValue) {
+    return defaultValue;
   }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/ForcedInliningOracle.java b/src/main/java/com/android/tools/r8/ir/optimize/ForcedInliningOracle.java
index c25755c..6387a98 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/ForcedInliningOracle.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/ForcedInliningOracle.java
@@ -5,10 +5,10 @@
 package com.android.tools.r8.ir.optimize;
 
 import com.android.tools.r8.graph.AppView;
-import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.graph.ResolutionResult.SingleResolutionResult;
 import com.android.tools.r8.ir.analysis.ClassInitializationAnalysis;
+import com.android.tools.r8.ir.analysis.type.ClassTypeElement;
 import com.android.tools.r8.ir.code.BasicBlock;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.InvokeDirect;
@@ -135,10 +135,14 @@
   public void markInlined(InlineeWithReason inlinee) {}
 
   @Override
-  public DexType getReceiverTypeIfKnown(InvokeMethod invoke) {
+  public ClassTypeElement getReceiverTypeOrDefault(
+      InvokeMethod invoke, ClassTypeElement defaultValue) {
     assert invoke.isInvokeMethodWithReceiver();
     Inliner.InliningInfo info = invokesToInline.get(invoke.asInvokeMethodWithReceiver());
     assert info != null;
-    return info.receiverType;
+    if (info.receiverClass != null) {
+      return info.receiverClass.getType().toTypeElement(appView).asClassType();
+    }
+    return defaultValue;
   }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java b/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
index 41a298a..ff31a7a 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
@@ -24,6 +24,7 @@
 import com.android.tools.r8.graph.ResolutionResult.SingleResolutionResult;
 import com.android.tools.r8.ir.analysis.ClassInitializationAnalysis;
 import com.android.tools.r8.ir.analysis.proto.ProtoInliningReasonStrategy;
+import com.android.tools.r8.ir.analysis.type.ClassTypeElement;
 import com.android.tools.r8.ir.analysis.type.Nullability;
 import com.android.tools.r8.ir.analysis.type.TypeAnalysis;
 import com.android.tools.r8.ir.analysis.type.TypeElement;
@@ -838,11 +839,11 @@
 
   public static class InliningInfo {
     public final ProgramMethod target;
-    public final DexType receiverType; // null, if unknown
+    public final DexProgramClass receiverClass; // null, if unknown
 
-    public InliningInfo(ProgramMethod target, DexType receiverType) {
+    public InliningInfo(ProgramMethod target, DexProgramClass receiverClass) {
       this.target = target;
-      this.receiverType = receiverType;
+      this.receiverClass = receiverClass;
     }
   }
 
@@ -1007,14 +1008,11 @@
             continue;
           }
 
-          DexType downcastTypeOrNull = getDowncastTypeIfNeeded(strategy, invoke, singleTarget);
-          if (downcastTypeOrNull != null) {
-            DexClass downcastClass = appView.definitionFor(downcastTypeOrNull, context);
-            if (downcastClass == null
-                || AccessControl.isClassAccessible(downcastClass, context, appView)
-                    .isPossiblyFalse()) {
-              continue;
-            }
+          DexProgramClass downcastClass = getDowncastTypeIfNeeded(strategy, invoke, singleTarget);
+          if (downcastClass != null
+              && AccessControl.isClassAccessible(downcastClass, context, appView)
+                  .isPossiblyFalse()) {
+            continue;
           }
 
           if (!inlineeStack.isEmpty()
@@ -1064,7 +1062,7 @@
           iterator.previous();
           strategy.markInlined(inlinee);
           iterator.inlineInvoke(
-              appView, code, inlinee.code, blockIterator, blocksToRemove, downcastTypeOrNull);
+              appView, code, inlinee.code, blockIterator, blocksToRemove, downcastClass);
 
           if (inlinee.reason == Reason.SINGLE_CALLER) {
             feedback.markInlinedIntoSingleCallSite(singleTargetMethod);
@@ -1152,19 +1150,22 @@
     return false;
   }
 
-  private DexType getDowncastTypeIfNeeded(
+  private DexProgramClass getDowncastTypeIfNeeded(
       InliningStrategy strategy, InvokeMethod invoke, ProgramMethod target) {
     if (invoke.isInvokeMethodWithReceiver()) {
-      // If the invoke has a receiver but the actual type of the receiver is different
-      // from the computed target holder, inlining requires a downcast of the receiver.
-      DexType receiverType = strategy.getReceiverTypeIfKnown(invoke);
-      if (receiverType == null) {
-        // In case we don't know exact type of the receiver we use declared
-        // method holder as a fallback.
-        receiverType = invoke.getInvokedMethod().holder;
+      // If the invoke has a receiver but the actual type of the receiver is different from the
+      // computed target holder, inlining requires a downcast of the receiver. In case we don't know
+      // the exact type of the receiver we use the static type of the receiver.
+      Value receiver = invoke.asInvokeMethodWithReceiver().getReceiver();
+      if (!receiver.getType().isClassType()) {
+        return target.getHolder();
       }
-      if (!appView.appInfo().isSubtype(receiverType, target.getHolderType())) {
-        return target.getHolderType();
+
+      ClassTypeElement receiverType =
+          strategy.getReceiverTypeOrDefault(invoke, receiver.getType().asClassType());
+      ClassTypeElement targetType = target.getHolderType().toTypeElement(appView).asClassType();
+      if (!receiverType.lessThanOrEqualUpToNullability(targetType, appView)) {
+        return target.getHolder();
       }
     }
     return null;
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/InliningStrategy.java b/src/main/java/com/android/tools/r8/ir/optimize/InliningStrategy.java
index 1dc076f..b6d33ea 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/InliningStrategy.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/InliningStrategy.java
@@ -4,8 +4,8 @@
 
 package com.android.tools.r8.ir.optimize;
 
-import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.analysis.type.ClassTypeElement;
 import com.android.tools.r8.ir.code.BasicBlock;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.InvokeDirect;
@@ -49,5 +49,5 @@
 
   void ensureMethodProcessed(ProgramMethod target, IRCode inlinee, OptimizationFeedback feedback);
 
-  DexType getReceiverTypeIfKnown(InvokeMethod invoke);
+  ClassTypeElement getReceiverTypeOrDefault(InvokeMethod invoke, ClassTypeElement defaultValue);
 }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/classinliner/InlineCandidateProcessor.java b/src/main/java/com/android/tools/r8/ir/optimize/classinliner/InlineCandidateProcessor.java
index 18ead7e..8cdeaca 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/classinliner/InlineCandidateProcessor.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/classinliner/InlineCandidateProcessor.java
@@ -441,7 +441,7 @@
               throw new IllegalClassInlinerStateException();
             }
 
-            directMethodCalls.put(invoke, new InliningInfo(singleTarget, eligibleClass.type));
+            directMethodCalls.put(invoke, new InliningInfo(singleTarget, eligibleClass));
             break;
           }
         }
@@ -900,7 +900,7 @@
               .getParent();
     }
 
-    return new InliningInfo(singleTarget, eligibleClass.type);
+    return new InliningInfo(singleTarget, eligibleClass);
   }
 
   // An invoke is eligible for inlining in the following cases:
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/inliner/InterfaceInvokeWithObjectReceiverInliningTest.java b/src/test/java/com/android/tools/r8/ir/optimize/inliner/InterfaceInvokeWithObjectReceiverInliningTest.java
index 255e1fe..63b21bc 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/inliner/InterfaceInvokeWithObjectReceiverInliningTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/inliner/InterfaceInvokeWithObjectReceiverInliningTest.java
@@ -4,8 +4,6 @@
 
 package com.android.tools.r8.ir.optimize.inliner;
 
-import static org.hamcrest.CoreMatchers.anyOf;
-import static org.hamcrest.CoreMatchers.containsString;
 import static org.junit.Assume.assumeFalse;
 
 import com.android.tools.r8.NeverInline;
@@ -13,7 +11,6 @@
 import com.android.tools.r8.R8TestBuilder;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
-import com.android.tools.r8.ToolHelper.DexVm.Version;
 import com.android.tools.r8.transformers.ClassFileTransformer.MethodPredicate;
 import com.android.tools.r8.utils.BooleanUtils;
 import java.io.IOException;
@@ -73,17 +70,9 @@
         .run(parameters.getRuntime(), Main.class)
         // TODO(b/199561570): Should succeed with 0.
         .applyIf(
-            (enableInlining
-                    && (!enableVerticalClassMerging
-                        || parameters.isDexRuntimeVersion(Version.V5_1_1)
-                        || parameters.isDexRuntimeVersion(Version.V6_0_1)))
-                || (!enableInlining && !enableVerticalClassMerging),
+            enableInlining || !enableVerticalClassMerging,
             runResult -> runResult.assertSuccessWithOutputLines("0"),
-            runResult ->
-                runResult.assertFailureWithErrorThatMatches(
-                    anyOf(
-                        containsString(NoSuchFieldError.class.getTypeName()),
-                        containsString(VerifyError.class.getTypeName()))));
+            runResult -> runResult.assertFailureWithErrorThatThrows(VerifyError.class));
   }
 
   private static byte[] getTransformedMain() throws IOException {
diff --git a/src/test/java/com/android/tools/r8/ir/regalloc/RegisterMoveSchedulerTest.java b/src/test/java/com/android/tools/r8/ir/regalloc/RegisterMoveSchedulerTest.java
index 105008c..f3db67c 100644
--- a/src/test/java/com/android/tools/r8/ir/regalloc/RegisterMoveSchedulerTest.java
+++ b/src/test/java/com/android/tools/r8/ir/regalloc/RegisterMoveSchedulerTest.java
@@ -12,6 +12,7 @@
 import com.android.tools.r8.graph.DebugLocalInfo;
 import com.android.tools.r8.graph.DexApplication;
 import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexString;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ProgramMethod;
@@ -203,7 +204,7 @@
         IRCode inlinee,
         ListIterator<BasicBlock> blockIterator,
         Set<BasicBlock> blocksToRemove,
-        DexType downcast) {
+        DexProgramClass downcast) {
       throw new Unimplemented();
     }
   }