Strengthen insertion of AssumeDynamicType

Fixes: b/327130357
Fixes: b/289002809
Change-Id: I05f660943491855f59cc42c5fe6ec9160a4e6e87
diff --git a/src/main/java/com/android/tools/r8/ir/code/If.java b/src/main/java/com/android/tools/r8/ir/code/If.java
index 4a3db76..fac1512 100644
--- a/src/main/java/com/android/tools/r8/ir/code/If.java
+++ b/src/main/java/com/android/tools/r8/ir/code/If.java
@@ -48,6 +48,11 @@
     return visitor.visit(this);
   }
 
+  public boolean isInstanceOfTest() {
+    return isZeroTest()
+        && lhs().getAliasedValue().isDefinedByInstructionSatisfying(Instruction::isInstanceOf);
+  }
+
   public boolean isNullTest() {
     return isZeroTest() && lhs().getType().isReferenceType();
   }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/TrivialCheckCastAndInstanceOfRemover.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/TrivialCheckCastAndInstanceOfRemover.java
index 8215404..4f42153 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/TrivialCheckCastAndInstanceOfRemover.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/TrivialCheckCastAndInstanceOfRemover.java
@@ -14,10 +14,10 @@
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.ir.analysis.type.DynamicType;
-import com.android.tools.r8.ir.analysis.type.TypeAnalysis;
 import com.android.tools.r8.ir.analysis.type.TypeElement;
 import com.android.tools.r8.ir.analysis.type.TypeUtils;
 import com.android.tools.r8.ir.code.Assume;
+import com.android.tools.r8.ir.code.BasicBlock;
 import com.android.tools.r8.ir.code.CheckCast;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.InstanceOf;
@@ -78,38 +78,45 @@
     // CheckCast at line 4 unless we update v7 with the most precise information by narrowing the
     // affected values of v5. We therefore have to run the type analysis after each CheckCast
     // removal.
-    TypeAnalysis typeAnalysis =
-        new TypeAnalysis(appView, code).setKeepRedundantBlocksAfterAssumeRemoval(true);
-    AffectedValues affectedValues = new AffectedValues();
-    InstructionListIterator it = code.instructionListIterator();
     boolean needToRemoveTrivialPhis = false;
-    while (it.hasNext()) {
-      Instruction current = it.next();
-      if (current.isCheckCast()) {
-        boolean hasPhiUsers = current.outValue().hasPhiUsers();
-        RemoveCheckCastInstructionIfTrivialResult removeResult =
-            removeCheckCastInstructionIfTrivial(
-                appViewWithLiveness,
-                current.asCheckCast(),
-                it,
+    for (BasicBlock block : code.getBlocks()) {
+      InstructionListIterator it = block.listIterator(code);
+      while (it.hasNext()) {
+        Instruction current = it.next();
+        if (current.isCheckCast()) {
+          boolean hasPhiUsers = current.outValue().hasPhiUsers();
+          AffectedValues affectedValues = new AffectedValues();
+          RemoveCheckCastInstructionIfTrivialResult removeResult =
+              removeCheckCastInstructionIfTrivial(
+                  appViewWithLiveness,
+                  current.asCheckCast(),
+                  it,
+                  code,
+                  code.context(),
+                  affectedValues,
+                  methodProcessor,
+                  methodProcessingContext);
+          if (removeResult != RemoveCheckCastInstructionIfTrivialResult.NO_REMOVALS) {
+            assert removeResult == RemoveCheckCastInstructionIfTrivialResult.REMOVED_CAST_DO_NARROW;
+            hasChanged = true;
+            needToRemoveTrivialPhis |= hasPhiUsers;
+            int blockSizeBeforeAssumeRemoval = block.size();
+            Instruction previous = it.peekPrevious();
+            affectedValues.narrowingWithAssumeRemoval(
+                appView,
                 code,
-                code.context(),
-                affectedValues,
-                methodProcessor,
-                methodProcessingContext);
-        if (removeResult != RemoveCheckCastInstructionIfTrivialResult.NO_REMOVALS) {
-          assert removeResult == RemoveCheckCastInstructionIfTrivialResult.REMOVED_CAST_DO_NARROW;
-          hasChanged = true;
-          needToRemoveTrivialPhis |= hasPhiUsers;
-          typeAnalysis.narrowing(affectedValues);
-          affectedValues.clear();
-        }
-      } else if (current.isInstanceOf()) {
-        boolean hasPhiUsers = current.outValue().hasPhiUsers();
-        if (removeInstanceOfInstructionIfTrivial(
-            appViewWithLiveness, current.asInstanceOf(), it, code)) {
-          hasChanged = true;
-          needToRemoveTrivialPhis |= hasPhiUsers;
+                typeAnalysis -> typeAnalysis.setKeepRedundantBlocksAfterAssumeRemoval(true));
+            if (block.size() != blockSizeBeforeAssumeRemoval) {
+              it = previous != null ? block.listIterator(code, previous) : block.listIterator(code);
+            }
+          }
+        } else if (current.isInstanceOf()) {
+          boolean hasPhiUsers = current.outValue().hasPhiUsers();
+          if (removeInstanceOfInstructionIfTrivial(
+              appViewWithLiveness, current.asInstanceOf(), it, code)) {
+            hasChanged = true;
+            needToRemoveTrivialPhis |= hasPhiUsers;
+          }
         }
       }
     }
@@ -120,9 +127,13 @@
     // Removing check-cast may result in a trivial phi:
     // v3 <- phi(v1, v1)
     if (needToRemoveTrivialPhis) {
+      AffectedValues affectedValues = new AffectedValues();
       code.removeAllDeadAndTrivialPhis(affectedValues);
+      affectedValues.narrowingWithAssumeRemoval(
+          appView,
+          code,
+          typeAnalysis -> typeAnalysis.setKeepRedundantBlocksAfterAssumeRemoval(true));
     }
-    typeAnalysis.narrowingWithAssumeRemoval(affectedValues);
     if (hasChanged) {
       code.removeRedundantBlocks();
     }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/AffectedValues.java b/src/main/java/com/android/tools/r8/ir/optimize/AffectedValues.java
index 217bd5a..8ee6379 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/AffectedValues.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/AffectedValues.java
@@ -4,6 +4,8 @@
 
 package com.android.tools.r8.ir.optimize;
 
+import static com.android.tools.r8.utils.ConsumerUtils.emptyConsumer;
+
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.ir.analysis.type.TypeAnalysis;
 import com.android.tools.r8.ir.code.BasicBlock;
@@ -14,6 +16,7 @@
 import java.util.Collection;
 import java.util.Iterator;
 import java.util.Set;
+import java.util.function.Consumer;
 import java.util.function.Predicate;
 
 public class AffectedValues implements Set<Value> {
@@ -35,8 +38,15 @@
   }
 
   public void narrowingWithAssumeRemoval(AppView<?> appView, IRCode code) {
+    narrowingWithAssumeRemoval(appView, code, emptyConsumer());
+  }
+
+  public void narrowingWithAssumeRemoval(
+      AppView<?> appView, IRCode code, Consumer<TypeAnalysis> typeAnalysisConsumer) {
     if (hasNext()) {
-      new TypeAnalysis(appView, code).narrowingWithAssumeRemoval(this);
+      TypeAnalysis typeAnalysis = new TypeAnalysis(appView, code);
+      typeAnalysisConsumer.accept(typeAnalysis);
+      typeAnalysis.narrowingWithAssumeRemoval(this);
     }
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/AssumeInserter.java b/src/main/java/com/android/tools/r8/ir/optimize/AssumeInserter.java
index f151273..a08f82d 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/AssumeInserter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/AssumeInserter.java
@@ -25,6 +25,7 @@
 import com.android.tools.r8.ir.code.FieldInstruction;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.If;
+import com.android.tools.r8.ir.code.InstanceOf;
 import com.android.tools.r8.ir.code.Instruction;
 import com.android.tools.r8.ir.code.InstructionIterator;
 import com.android.tools.r8.ir.code.InstructionListIterator;
@@ -40,6 +41,7 @@
 import com.android.tools.r8.utils.Timing;
 import com.android.tools.r8.utils.TriFunction;
 import com.android.tools.r8.utils.TriPredicate;
+import com.google.common.collect.Iterables;
 import com.google.common.collect.Sets;
 import it.unimi.dsi.fastutil.ints.IntArrayList;
 import it.unimi.dsi.fastutil.ints.IntList;
@@ -192,24 +194,38 @@
     }
 
     If ifInstruction = block.exit().asIf();
-    if (ifInstruction != null && ifInstruction.isNonTrivialNullTest()) {
+    if (ifInstruction != null) {
       Value lhs = ifInstruction.lhs();
-      if (assumedValuesBuilder.isMaybeNullAndNotNullType(lhs)
-          && isNullableReferenceTypeWithOtherNonDebugUsers(lhs, ifInstruction)
-          && ifInstruction.targetFromNonNullObject().getPredecessors().size() == 1) {
-        assumedValuesBuilder.addNonNullValueWithUnknownDominance(ifInstruction, lhs);
+      if (ifInstruction.isInstanceOfTest()) {
+        InstanceOf instanceOf = lhs.getAliasedValue().getDefinition().asInstanceOf();
+        Value value = instanceOf.value();
+        DynamicTypeWithUpperBound dynamicType =
+            DynamicType.create(
+                appView, instanceOf.type().toTypeElement(appView, value.getType().nullability()));
+        DynamicTypeWithUpperBound staticType = DynamicType.create(appView, value.getType());
+        if (dynamicType.strictlyLessThan(staticType, appView)
+            && hasOtherNonDebugUsers(value, instanceOf)) {
+          assumedValuesBuilder.addAssumedValueWithUnknownDominance(
+              ifInstruction, value, dynamicType);
+        }
+      } else if (ifInstruction.isNonTrivialNullTest()) {
+        if (assumedValuesBuilder.isMaybeNullAndNotNullType(lhs)
+            && isNullableReferenceTypeWithOtherNonDebugUsers(lhs, ifInstruction)
+            && ifInstruction.targetFromNonNullObject().getPredecessors().size() == 1) {
+          assumedValuesBuilder.addNonNullValueWithUnknownDominance(ifInstruction, lhs);
+        }
       }
     }
   }
 
   private boolean computeAssumedValuesForInvokeMethod(
       IRCode code, InvokeMethod invoke, AssumedValues.Builder assumedValuesBuilder) {
-    if (!invoke.hasOutValue() && invoke.getInvokedMethod().proto.parameters.isEmpty()) {
+    if (invoke.hasUnusedOutValue() && invoke.arguments().isEmpty()) {
       return false;
     }
 
     DexMethod invokedMethod = invoke.getInvokedMethod();
-    if (invokedMethod.holder.isArrayType()
+    if (invokedMethod.getHolderType().isArrayType()
         && invokedMethod.match(appView.dexItemFactory().objectMembers.clone)) {
       return computeAssumedValuesFromArrayClone(invoke, assumedValuesBuilder);
     }
@@ -241,7 +257,10 @@
       return false;
     }
 
-    DexClassAndMethod singleTarget = invoke.lookupSingleTarget(appView, code.context());
+    DexClassAndMethod singleTarget =
+        resolutionResult
+            .lookupDispatchTarget(appView, invoke, code.context())
+            .getSingleDispatchTarget();
     if (invoke.hasUsedOutValue() && invoke.getOutType().isReferenceType()) {
       AssumeInfo assumeInfo =
           AssumeInfoLookup.lookupAssumeInfo(appView, resolutionResult, singleTarget);
@@ -265,18 +284,51 @@
               invoke, optimizationInfo.getDynamicType(), assumedValuesBuilder);
     }
 
-    // Case (3), parameters that are not null after the invocation.
+    // Case (3), parameters that are not null or have a more specific type after the invocation.
     BitSet nonNullParamOnNormalExits = optimizationInfo.getNonNullParamOnNormalExits();
-    if (nonNullParamOnNormalExits != null) {
-      for (int argumentIndex = 0; argumentIndex < invoke.arguments().size(); argumentIndex++) {
-        boolean isArgumentNonNullOnNormalExits =
-            (invoke.isInvokeMethodWithReceiver() && argumentIndex == 0)
-                || nonNullParamOnNormalExits.get(argumentIndex);
-        if (isArgumentNonNullOnNormalExits) {
-          Value argument = invoke.getArgument(argumentIndex);
-          if (assumedValuesBuilder.isMaybeNullAndNotNullType(argument)
-              && isNullableReferenceTypeWithOtherNonDebugUsers(argument, invoke)) {
-            assumedValuesBuilder.addNonNullValueWithUnknownDominance(invoke, argument);
+    for (int argumentIndex = 0; argumentIndex < invoke.arguments().size(); argumentIndex++) {
+      Value argument = invoke.getArgument(argumentIndex);
+
+      // Nullability.
+      boolean isArgumentNonNullOnNormalExits =
+          (invoke.isInvokeMethodWithReceiver() && argumentIndex == 0)
+              || (nonNullParamOnNormalExits != null
+                  && nonNullParamOnNormalExits.get(argumentIndex));
+      if (isArgumentNonNullOnNormalExits) {
+        if (assumedValuesBuilder.isMaybeNullAndNotNullType(argument)
+            && isNullableReferenceTypeWithOtherNonDebugUsers(argument, invoke)) {
+          assumedValuesBuilder.addNonNullValueWithUnknownDominance(invoke, argument);
+          needsAssumeInstruction = true;
+        }
+      }
+
+      // Type information.
+      if (argumentIndex == optimizationInfo.getReturnedArgument()
+          && hasOtherNonDebugUsers(argument, invoke)) {
+        DynamicTypeWithUpperBound dynamicType =
+            optimizationInfo.getDynamicType().isDynamicTypeWithUpperBound()
+                ? optimizationInfo.getDynamicType().asDynamicTypeWithUpperBound()
+                : DynamicType.unknown();
+        if (singleTarget != null
+            && singleTarget.getDefinition().isInstance()
+            && optimizationInfo.getReturnedArgument() == 0) {
+          DynamicTypeWithUpperBound receiverType =
+              DynamicTypeWithUpperBound.create(
+                  appView,
+                  singleTarget.getHolderType().toTypeElement(appView, definitelyNotNull()));
+          if (receiverType.strictlyLessThan(dynamicType, appView)) {
+            dynamicType = receiverType;
+          }
+        }
+        if (!dynamicType.isUnknown()) {
+          DynamicTypeWithUpperBound staticType = DynamicType.create(appView, argument.getType());
+          if (dynamicType
+              .withNullability(staticType.getNullability())
+              .strictlyLessThan(staticType, appView)) {
+            Nullability meetNullability =
+                dynamicType.getNullability().meet(staticType.getNullability());
+            assumedValuesBuilder.addAssumedValueWithUnknownDominance(
+                invoke, argument, dynamicType.withNullability(meetNullability));
             needsAssumeInstruction = true;
           }
         }
@@ -632,7 +684,12 @@
 
   private BasicBlock getInsertionBlock(Instruction instruction) {
     if (instruction.isIf()) {
-      return instruction.asIf().targetFromNonNullObject();
+      If theIf = instruction.asIf();
+      if (theIf.isInstanceOfTest()) {
+        return theIf.targetFromTrue();
+      } else {
+        return theIf.targetFromNonNullObject();
+      }
     }
     BasicBlock block = instruction.getBlock();
     if (block.hasCatchHandlers()) {
@@ -675,17 +732,11 @@
 
   private static boolean isNullableReferenceTypeWithOtherNonDebugUsers(
       Value value, Instruction ignore) {
-    if (isNullableReferenceType(value)) {
-      if (value.hasPhiUsers()) {
-        return true;
-      }
-      for (Instruction user : value.uniqueUsers()) {
-        if (user != ignore) {
-          return true;
-        }
-      }
-    }
-    return false;
+    return isNullableReferenceType(value) && hasOtherNonDebugUsers(value, ignore);
+  }
+
+  private static boolean hasOtherNonDebugUsers(Value value, Instruction ignore) {
+    return value.hasPhiUsers() || Iterables.any(value.uniqueUsers(), user -> user != ignore);
   }
 
   static class AssumedValueInfo {
@@ -884,6 +935,15 @@
             assumedValueInfo -> assumedValueInfo.setDynamicType(dynamicType));
       }
 
+      void addAssumedValueWithUnknownDominance(
+          Instruction instruction, Value assumedValue, DynamicTypeWithUpperBound dynamicType) {
+        updateAssumedValueInfo(
+            instruction,
+            assumedValue,
+            AssumedDominance.unknown(),
+            assumedValueInfo -> assumedValueInfo.setDynamicType(dynamicType));
+      }
+
       void addNonNullValueKnownToDominateAllUsers(Instruction instruction, Value nonNullValue) {
         updateAssumedValueInfo(
             instruction, nonNullValue, AssumedDominance.everything(), AssumedValueInfo::setNotNull);
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/info/MutableMethodOptimizationInfo.java b/src/main/java/com/android/tools/r8/ir/optimize/info/MutableMethodOptimizationInfo.java
index 4cba532..f2a5336 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/info/MutableMethodOptimizationInfo.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/info/MutableMethodOptimizationInfo.java
@@ -444,7 +444,6 @@
 
   @Override
   public int getReturnedArgument() {
-    assert returnsArgument();
     return returnedArgument;
   }
 
@@ -714,8 +713,7 @@
     // Nullability could be less precise, though. For example, suppose a value is known to be
     // non-null after a safe invocation, hence recorded with the non-null variant. If that call is
     // inlined and the method is reprocessed, such non-null assumption cannot be made again.
-    // TODO(b/327130357): Reenable assert.
-    // assert verifyDynamicType(appView, newDynamicType, staticReturnType);
+    assert verifyDynamicType(appView, newDynamicType, staticReturnType);
     setDynamicType(newDynamicType);
   }
 
@@ -726,7 +724,6 @@
     return this;
   }
 
-  @SuppressWarnings("UnusedMethod")
   private boolean verifyDynamicType(
       AppView<?> appView, DynamicType newDynamicType, TypeElement staticReturnType) {
     if (appView.enableWholeProgramOptimizations()) {
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/typechecks/DynamicTypeAfterInstanceOfTest.java b/src/test/java/com/android/tools/r8/ir/optimize/typechecks/DynamicTypeAfterInstanceOfTest.java
new file mode 100644
index 0000000..2e9542f
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/typechecks/DynamicTypeAfterInstanceOfTest.java
@@ -0,0 +1,80 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.optimize.typechecks;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isAbsent;
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import com.android.tools.r8.NoVerticalClassMerging;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class DynamicTypeAfterInstanceOfTest extends TestBase {
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .enableNoVerticalClassMergingAnnotations()
+        .setMinApi(parameters)
+        .compile()
+        // The call to a.m() should be inlined as a result of the dynamic type information.
+        .inspect(
+            inspector -> {
+              ClassSubject aClassSubject = inspector.clazz(A.class);
+              assertThat(aClassSubject, isPresent());
+              assertThat(aClassSubject.uniqueMethodWithOriginalName("m"), isAbsent());
+
+              ClassSubject bClassSubject = inspector.clazz(B.class);
+              assertThat(bClassSubject, isPresent());
+              assertThat(bClassSubject.uniqueMethodWithOriginalName("m"), isAbsent());
+            })
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("B");
+  }
+
+  static class Main {
+
+    public static void main(String[] args) {
+      A a = System.currentTimeMillis() > 0 ? new B() : new A();
+      if (a instanceof B) {
+        a.m();
+      }
+    }
+  }
+
+  @NoVerticalClassMerging
+  static class A {
+
+    public void m() {
+      System.out.println("A");
+    }
+  }
+
+  static class B extends A {
+
+    @Override
+    public void m() {
+      System.out.println("B");
+    }
+  }
+}