Rewrite instanceof to null check

Change-Id: Id15a77d0ba8fb1420f6e18ab1f8a25accd09dc5f
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/BranchSimplifier.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/BranchSimplifier.java
index 228eb0f..bca0b7d 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/BranchSimplifier.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/BranchSimplifier.java
@@ -12,6 +12,7 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexClass;
 import com.android.tools.r8.graph.DexEncodedField;
+import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.ir.analysis.equivalence.BasicBlockBehavioralSubsumption;
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
@@ -29,6 +30,7 @@
 import com.android.tools.r8.ir.code.InstructionIterator;
 import com.android.tools.r8.ir.code.InstructionListIterator;
 import com.android.tools.r8.ir.code.IntSwitch;
+import com.android.tools.r8.ir.code.InvokeStatic;
 import com.android.tools.r8.ir.code.NumericType;
 import com.android.tools.r8.ir.code.Phi;
 import com.android.tools.r8.ir.code.Position;
@@ -112,6 +114,9 @@
         if (rewriteIfWithConstZero(code, block)) {
           simplified = true;
         }
+        if (rewriteIfWithObjectsIsNullOrNonNull(code, block)) {
+          simplified = true;
+        }
 
         if (simplifyKnownBooleanCondition(code, block)) {
           simplified = true;
@@ -602,6 +607,29 @@
     return false;
   }
 
+  private boolean rewriteIfWithObjectsIsNullOrNonNull(IRCode code, BasicBlock block) {
+    If theIf = block.exit().asIf();
+    if (!theIf.isZeroTest() || !theIf.getType().isEqualsOrNotEquals()) {
+      return false;
+    }
+
+    Value value = theIf.lhs();
+    if (value.isDefinedByInstructionSatisfying(Instruction::isInvokeStatic)) {
+      InvokeStatic invoke = value.getDefinition().asInvokeStatic();
+      DexMethod invokedMethod = invoke.getInvokedMethod();
+      if (invokedMethod.isIdenticalTo(dexItemFactory.objectsMethods.isNull)) {
+        If ifz = new If(theIf.getType().inverted(), invoke.getFirstArgument());
+        block.replaceLastInstruction(ifz, code);
+        return true;
+      } else if (invokedMethod.isIdenticalTo(dexItemFactory.objectsMethods.nonNull)) {
+        If ifz = new If(theIf.getType(), invoke.getFirstArgument());
+        block.replaceLastInstruction(ifz, code);
+        return true;
+      }
+    }
+    return false;
+  }
+
   private boolean flipIfBranchesIfNeeded(IRCode code, BasicBlock block) {
     If theIf = block.exit().asIf();
     BasicBlock trueTarget = theIf.getTrueTarget();
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/TrivialCheckCastAndInstanceOfRemover.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/TrivialCheckCastAndInstanceOfRemover.java
index 9e18843..d3c57fc 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/TrivialCheckCastAndInstanceOfRemover.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/TrivialCheckCastAndInstanceOfRemover.java
@@ -154,16 +154,6 @@
     REMOVED_CAST_DO_PROPAGATE
   }
 
-  private enum InstanceOfResult {
-    UNKNOWN,
-    TRUE,
-    FALSE;
-
-    boolean isTrue() {
-      return this == TRUE;
-    }
-  }
-
   // Returns true if the given check-cast instruction was removed.
   private RemoveCheckCastInstructionIfTrivialResult removeCheckCastInstructionIfTrivial(
       AppView<AppInfoWithLiveness> appViewWithLiveness,
@@ -335,64 +325,83 @@
     TypeElement instanceOfType =
         TypeElement.fromDexType(instanceOf.type(), inType.nullability(), appView);
     Value aliasValue = inValue.getAliasedValue();
-
-    InstanceOfResult result = InstanceOfResult.UNKNOWN;
-    if (inType.isDefinitelyNull()) {
-      result = InstanceOfResult.FALSE;
-    } else if (inType.lessThanOrEqual(instanceOfType, appView) && !inType.isNullable()) {
-      result = InstanceOfResult.TRUE;
-    } else if (!aliasValue.isPhi()
-        && aliasValue.definition.isCreatingInstanceOrArray()
-        && instanceOfType.strictlyLessThan(inType, appView)) {
-      result = InstanceOfResult.FALSE;
-    } else if (appView.appInfo().hasLiveness()) {
-      if (instanceOf.type().isClassType()
-          && isNeverInstantiatedDirectlyOrIndirectly(instanceOf.type())) {
-        // The type of the instance-of instruction is a program class, and is never instantiated
-        // directly or indirectly. Thus, the in-value must be null, meaning that the instance-of
-        // instruction will always evaluate to false.
-        result = InstanceOfResult.FALSE;
+    if (inType.lessThanOrEqual(instanceOfType, appView)) {
+      if (inType.isDefinitelyNull()) {
+        return replaceInstanceOfByFalse(code, it);
       }
-
-      if (result == InstanceOfResult.UNKNOWN) {
-        if (inType.isClassType()
-            && isNeverInstantiatedDirectlyOrIndirectly(inType.asClassType().getClassType())) {
-          // The type of the in-value is a program class, and is never instantiated directly or
-          // indirectly. This, the in-value must be null, meaning that the instance-of instruction
-          // will always evaluate to false.
-          result = InstanceOfResult.FALSE;
-        }
+      if (inType.isDefinitelyNotNull()) {
+        return replaceInstanceOfByTrue(code, it);
       }
-
-      if (result == InstanceOfResult.UNKNOWN) {
-        Value aliasedValue =
-            inValue.getSpecificAliasedValue(
-                value ->
-                    value.isDefinedByInstructionSatisfying(
-                        Instruction::isAssumeWithDynamicTypeAssumption));
-        if (aliasedValue != null) {
-          Assume assumeInstruction = aliasedValue.getDefinition().asAssume();
-          DynamicType dynamicType = assumeInstruction.getDynamicType();
-          if (dynamicType.getNullability().isDefinitelyNull()) {
-            result = InstanceOfResult.FALSE;
-          } else if (dynamicType.isDynamicTypeWithUpperBound()
-              && dynamicType
-                  .asDynamicTypeWithUpperBound()
-                  .getDynamicUpperBoundType()
-                  .lessThanOrEqual(instanceOfType, appView)
-              && (!inType.isNullable() || dynamicType.getNullability().isDefinitelyNotNull())) {
-            result = InstanceOfResult.TRUE;
-          }
-        }
+      if (options.canUseJavaUtilObjectsNonNull()) {
+        return replaceInstanceOfByNonNull(it, instanceOf);
       }
     }
-    if (result != InstanceOfResult.UNKNOWN) {
-      it.replaceCurrentInstructionWithConstBoolean(code, result.isTrue());
-      return true;
+    if (aliasValue.isDefinedByInstructionSatisfying(Instruction::isCreatingInstanceOrArray)
+        && instanceOfType.strictlyLessThan(inType, appView)) {
+      return replaceInstanceOfByFalse(code, it);
+    }
+    if (instanceOf.type().isClassType()
+        && isNeverInstantiatedDirectlyOrIndirectly(instanceOf.type())) {
+      // The type of the instance-of instruction is a program class, and is never instantiated
+      // directly or indirectly. Thus, the in-value must be null, meaning that the instance-of
+      // instruction will always evaluate to false.
+      return replaceInstanceOfByFalse(code, it);
+    }
+
+    if (inType.isClassType()
+        && isNeverInstantiatedDirectlyOrIndirectly(inType.asClassType().getClassType())) {
+      // The type of the in-value is a program class, and is never instantiated directly or
+      // indirectly. This, the in-value must be null, meaning that the instance-of instruction
+      // will always evaluate to false.
+      return replaceInstanceOfByFalse(code, it);
+    }
+
+    Value aliasedValue =
+        inValue.getSpecificAliasedValue(
+            value ->
+                value.isDefinedByInstructionSatisfying(
+                    Instruction::isAssumeWithDynamicTypeAssumption));
+    if (aliasedValue != null) {
+      Assume assumeInstruction = aliasedValue.getDefinition().asAssume();
+      DynamicType dynamicType = assumeInstruction.getDynamicType();
+      if (dynamicType.getNullability().isDefinitelyNull()) {
+        return replaceInstanceOfByFalse(code, it);
+      } else if (dynamicType.isDynamicTypeWithUpperBound()
+          && dynamicType
+              .asDynamicTypeWithUpperBound()
+              .getDynamicUpperBoundType()
+              .lessThanOrEqual(instanceOfType, appView)
+          && (!inType.isNullable() || dynamicType.getNullability().isDefinitelyNotNull())) {
+        return replaceInstanceOfByTrue(code, it);
+      }
     }
     return false;
   }
 
+  private boolean replaceInstanceOfByFalse(
+      IRCode code, InstructionListIterator instructionIterator) {
+    instructionIterator.replaceCurrentInstructionWithConstBoolean(code, false);
+    return true;
+  }
+
+  private boolean replaceInstanceOfByTrue(
+      IRCode code, InstructionListIterator instructionIterator) {
+    instructionIterator.replaceCurrentInstructionWithConstBoolean(code, true);
+    return true;
+  }
+
+  private boolean replaceInstanceOfByNonNull(
+      InstructionListIterator instructionIterator, InstanceOf instanceOf) {
+    InvokeStatic replacement =
+        InvokeStatic.builder()
+            .setMethod(dexItemFactory.objectsMethods.nonNull)
+            .setSingleArgument(instanceOf.value())
+            .setOutValue(instanceOf.outValue())
+            .build();
+    instructionIterator.replaceCurrentInstruction(replacement);
+    return true;
+  }
+
   private boolean isNeverInstantiatedDirectlyOrIndirectly(DexType type) {
     assert appView.appInfo().hasLiveness();
     assert type.isClassType();
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/library/LibraryMethodSideEffectModelCollection.java b/src/main/java/com/android/tools/r8/ir/optimize/library/LibraryMethodSideEffectModelCollection.java
index 6b2fbb6..4d74eeb 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/library/LibraryMethodSideEffectModelCollection.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/library/LibraryMethodSideEffectModelCollection.java
@@ -97,6 +97,8 @@
         .add(dexItemFactory.recordMembers.constructor)
         .add(dexItemFactory.objectMembers.constructor)
         .add(dexItemFactory.objectMembers.getClass)
+        .add(dexItemFactory.objectsMethods.isNull)
+        .add(dexItemFactory.objectsMethods.nonNull)
         .add(dexItemFactory.shortMembers.shortValue)
         .add(dexItemFactory.shortMembers.toString)
         .add(dexItemFactory.shortMembers.valueOf)
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index 53e5bd0..e9034af 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -2742,6 +2742,10 @@
     return hasFeaturePresentFrom(AndroidApiLevel.N);
   }
 
+  public boolean canUseJavaUtilObjectsNonNull() {
+    return isGeneratingDex() && hasFeaturePresentFrom(AndroidApiLevel.N);
+  }
+
   public boolean canUseSuppressedExceptions() {
     // TODO(b/214239152): Suppressed exceptions are @hide from at least 4.0.1 / Android I / API 14.
     return hasFeaturePresentFrom(AndroidApiLevel.K);
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/instanceofremoval/InstanceOfToNullCheckRewritingTest.java b/src/test/java/com/android/tools/r8/ir/optimize/instanceofremoval/InstanceOfToNullCheckRewritingTest.java
index 8c1a93a..97b973a 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/instanceofremoval/InstanceOfToNullCheckRewritingTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/instanceofremoval/InstanceOfToNullCheckRewritingTest.java
@@ -4,13 +4,14 @@
 
 package com.android.tools.r8.ir.optimize.instanceofremoval;
 
-import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresentIf;
 import static org.hamcrest.MatcherAssert.assertThat;
 
 import com.android.tools.r8.NoVerticalClassMerging;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.AndroidApiLevel;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -36,7 +37,13 @@
         .enableNoVerticalClassMergingAnnotations()
         .setMinApi(parameters)
         .compile()
-        .inspect(inspector -> assertThat(inspector.clazz(I.class), isPresent()))
+        .inspect(
+            inspector ->
+                assertThat(
+                    inspector.clazz(I.class),
+                    isPresentIf(
+                        parameters.isCfRuntime()
+                            || parameters.getApiLevel().isLessThan(AndroidApiLevel.N))))
         .run(parameters.getRuntime(), Main.class)
         .assertSuccessWithOutputLines("true", "false");
   }
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/nonnull/IfObjectsNullOrNonNullTest.java b/src/test/java/com/android/tools/r8/ir/optimize/nonnull/IfObjectsNullOrNonNullTest.java
new file mode 100644
index 0000000..78373ee
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/nonnull/IfObjectsNullOrNonNullTest.java
@@ -0,0 +1,90 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.optimize.nonnull;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertTrue;
+
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import java.util.Objects;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class IfObjectsNullOrNonNullTest extends TestBase {
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  @Test
+  public void testJvm() throws Exception {
+    parameters.assumeJvmTestParameters();
+    testForJvm(parameters)
+        .addInnerClasses(getClass())
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("false", "true", "true", "false");
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .setMinApi(parameters)
+        .compile()
+        .inspect(
+            inspector -> {
+              MethodSubject mainMethodSubject = inspector.clazz(Main.class).mainMethod();
+              assertThat(mainMethodSubject, isPresent());
+              assertTrue(
+                  mainMethodSubject
+                      .streamInstructions()
+                      .noneMatch(InstructionSubject::isInvokeStatic));
+            })
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("false", "true", "true", "false");
+  }
+
+  static class Main {
+
+    public static void main(String[] args) {
+      boolean alwaysTrue = args.length == 0;
+      Object nonNullObject = alwaysTrue ? new Object() : null;
+      Object nullObject = alwaysTrue ? null : new Object();
+      if (Objects.isNull(nonNullObject)) {
+        System.out.println("true");
+      } else {
+        System.out.println("false");
+      }
+      if (Objects.isNull(nullObject)) {
+        System.out.println("true");
+      } else {
+        System.out.println("false");
+      }
+      if (Objects.nonNull(nonNullObject)) {
+        System.out.println("true");
+      } else {
+        System.out.println("false");
+      }
+      if (Objects.nonNull(nullObject)) {
+        System.out.println("true");
+      } else {
+        System.out.println("false");
+      }
+    }
+  }
+}