Leverage exact type of receiver when finding call targets.

Bug: 141580674
Change-Id: I96690d294c1fa7fb26702744ecf27011efb09b92
diff --git a/src/main/java/com/android/tools/r8/ir/code/Instruction.java b/src/main/java/com/android/tools/r8/ir/code/Instruction.java
index 1a17a8b..cd3d513 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Instruction.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Instruction.java
@@ -1112,6 +1112,10 @@
     return null;
   }
 
+  public boolean isInvokeMethodWithDynamicDispatch() {
+    return isInvokeInterface() || isInvokeVirtual();
+  }
+
   public boolean isInvokeMethod() {
     return false;
   }
diff --git a/src/main/java/com/android/tools/r8/ir/code/InvokeInterface.java b/src/main/java/com/android/tools/r8/ir/code/InvokeInterface.java
index a3fb554..fc57fdc 100644
--- a/src/main/java/com/android/tools/r8/ir/code/InvokeInterface.java
+++ b/src/main/java/com/android/tools/r8/ir/code/InvokeInterface.java
@@ -20,6 +20,7 @@
 import com.android.tools.r8.ir.optimize.InliningConstraints;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 
 public class InvokeInterface extends InvokeMethodWithReceiver {
@@ -104,7 +105,13 @@
   @Override
   public Collection<DexEncodedMethod> lookupTargets(
       AppView<? extends AppInfoWithSubtyping> appView, DexType invocationContext) {
+    // Leverage exact receiver type if available.
+    DexEncodedMethod singleTarget = lookupSingleTarget(appView, invocationContext);
+    if (singleTarget != null) {
+      return Collections.singletonList(singleTarget);
+    }
     DexMethod method = getInvokedMethod();
+    // TODO(b/141580674): we could filter out some targets based on refined receiver type.
     return appView
         .appInfo()
         .resolveMethodOnInterface(method.holder, method)
diff --git a/src/main/java/com/android/tools/r8/ir/code/InvokeVirtual.java b/src/main/java/com/android/tools/r8/ir/code/InvokeVirtual.java
index e9d5142..d7585c3 100644
--- a/src/main/java/com/android/tools/r8/ir/code/InvokeVirtual.java
+++ b/src/main/java/com/android/tools/r8/ir/code/InvokeVirtual.java
@@ -22,6 +22,7 @@
 import com.android.tools.r8.ir.optimize.InliningConstraints;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 import java.util.function.Predicate;
 
@@ -107,7 +108,13 @@
   @Override
   public Collection<DexEncodedMethod> lookupTargets(
       AppView<? extends AppInfoWithSubtyping> appView, DexType invocationContext) {
+    // Leverage exact receiver type if available.
+    DexEncodedMethod singleTarget = lookupSingleTarget(appView, invocationContext);
+    if (singleTarget != null) {
+      return Collections.singletonList(singleTarget);
+    }
     DexMethod method = getInvokedMethod();
+    // TODO(b/141580674): we could filter out some targets based on refined receiver type.
     return appView
         .appInfo()
         .resolveMethodOnClass(method.holder, method)
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java b/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java
index 8d3f18f..4826545 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java
@@ -18,6 +18,7 @@
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.Instruction;
 import com.android.tools.r8.ir.code.InstructionListIterator;
+import com.android.tools.r8.ir.code.InvokeMethod;
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.optimize.info.CallSiteOptimizationInfo;
 import com.android.tools.r8.ir.optimize.info.MutableCallSiteOptimizationInfo;
@@ -91,22 +92,25 @@
         continue;
       }
       if (instruction.isInvokeMethod()) {
-        // For virtual and interface calls, proceed on valid results only (since it's enforced).
-        if (instruction.isInvokeVirtual() || instruction.isInvokeInterface()) {
-          DexMethod invokedMethod = instruction.asInvokeMethod().getInvokedMethod();
+        InvokeMethod invoke = instruction.asInvokeMethod();
+        if (invoke.isInvokeMethodWithDynamicDispatch()) {
+          DexMethod invokedMethod = invoke.getInvokedMethod();
           ResolutionResult resolutionResult =
               appView.appInfo().resolveMethod(invokedMethod.holder, invokedMethod);
+          // For virtual and interface calls, proceed on valid results only (since it's enforced).
           if (!resolutionResult.isValidVirtualTarget(appView.options())) {
             continue;
           }
         }
-        Collection<DexEncodedMethod> targets =
-            instruction.asInvokeMethod().lookupTargets(appView, context.method.holder);
-        if (targets == null) {
+        Collection<DexEncodedMethod> targets = invoke.lookupTargets(appView, context.method.holder);
+        assert invoke.isInvokeMethodWithDynamicDispatch()
+            // For other invocation types, the size of targets should be at most one.
+            || targets == null || targets.size() <= 1;
+        if (targets == null || targets.isEmpty()) {
           continue;
         }
         for (DexEncodedMethod target : targets) {
-          recordArgumentsIfNecessary(context, target, instruction.inValues());
+          recordArgumentsIfNecessary(context, target, invoke.inValues());
         }
       }
       // TODO(b/129458850): if lambda desugaring happens before IR processing, seeing invoke-custom
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/callsites/dynamictype/InvokeInterfacePositiveTest.java b/src/test/java/com/android/tools/r8/ir/optimize/callsites/dynamictype/InvokeInterfacePositiveTest.java
index ce2fd2e..c720d09 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/callsites/dynamictype/InvokeInterfacePositiveTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/callsites/dynamictype/InvokeInterfacePositiveTest.java
@@ -113,8 +113,9 @@
   static class Main {
     public static void main(String... args) {
       I i = System.currentTimeMillis() > 0 ? new A() : new B();
-      i.m(new Sub1());       // calls A.m() with Sub1.
-      new B().m(new Sub2()); // calls B.m() with Sub2.
+      i.m(new Sub1());  // calls A.m() with Sub1.
+      i = new B();      // with the exact type:
+      i.m(new Sub2());  // calls B.m() with Sub2.
     }
   }
 }
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeVirtualPositiveTest.java b/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeVirtualPositiveTest.java
index 7d0838c..70aa862 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeVirtualPositiveTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeVirtualPositiveTest.java
@@ -46,7 +46,7 @@
         .enableInliningAnnotations()
         .setMinApi(parameters.getRuntime())
         .run(parameters.getRuntime(), MAIN)
-        .assertSuccessWithOutputLines("A", "B")
+        .assertSuccessWithOutputLines("A", "null")
         .inspect(this::inspect);
   }
 
@@ -64,8 +64,8 @@
 
     MethodSubject b_m = b.uniqueMethodWithName("m");
     assertThat(b_m, isPresent());
-    // Can optimize branches since `arg` is definitely not null.
-    assertTrue(b_m.streamInstructions().noneMatch(InstructionSubject::isIf));
+    // Should not optimize branches since the nullability of `arg` is unsure.
+    assertTrue(b_m.streamInstructions().anyMatch(InstructionSubject::isIf));
   }
 
   @NeverMerge
@@ -113,8 +113,8 @@
       A a = System.currentTimeMillis() > 0 ? new A() : new B();
       a.m(a);  // calls A.m() with non-null instance.
 
-      B b = new B();
-      b.m(b);  // calls B.m() with non-null instance
+      A b = new B();  // with the exact type:
+      b.m(null);      // calls B.m() with null.
     }
   }
 }