Remove Objects#requireNonNull for definitely null reference.

Bug: 124246610
Change-Id: Ib593f20c01cc64905da4c32e8baf937446237c7b
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index e47b1b9..7b398f3 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -1611,62 +1611,69 @@
     AppInfoWithLiveness appInfoWithLiveness = appView.appInfo().withLiveness();
     Set<Value> needToWidenValues = Sets.newIdentityHashSet();
     Set<Value> needToNarrowValues = Sets.newIdentityHashSet();
-    InstructionIterator iterator = code.instructionIterator();
-    while (iterator.hasNext()) {
-      Instruction current = iterator.next();
-      if (current.isInvokeMethod()) {
-        InvokeMethod invoke = current.asInvokeMethod();
-        Value outValue = invoke.outValue();
-        // TODO(b/124246610): extend to other variants that receive error messages or supplier.
-        if (invoke.getInvokedMethod() == dexItemFactory.objectsMethods.requireNonNull) {
-          Value obj = invoke.arguments().get(0);
-          if ((outValue == null && obj.hasLocalInfo())
-              || (outValue != null && !obj.hasSameOrNoLocal(outValue))) {
-            continue;
-          }
-          Nullability nullability = obj.getTypeLattice().nullability();
-          if (nullability.isDefinitelyNotNull()) {
-            if (outValue != null) {
-              outValue.replaceUsers(obj);
-              needToNarrowValues.addAll(outValue.affectedValues());
+    Set<BasicBlock> blocksToBeRemoved = Sets.newIdentityHashSet();
+    ListIterator<BasicBlock> blockIterator = code.listIterator();
+    while (blockIterator.hasNext()) {
+      BasicBlock block = blockIterator.next();
+      if (blocksToBeRemoved.contains(block)) {
+        continue;
+      }
+      InstructionListIterator iterator = block.listIterator();
+      while (iterator.hasNext()) {
+        Instruction current = iterator.next();
+        if (current.isInvokeMethod()) {
+          InvokeMethod invoke = current.asInvokeMethod();
+          Value outValue = invoke.outValue();
+          // TODO(b/124246610): extend to other variants that receive error messages or supplier.
+          if (invoke.getInvokedMethod() == dexItemFactory.objectsMethods.requireNonNull) {
+            Value obj = invoke.arguments().get(0);
+            if ((outValue == null && obj.hasLocalInfo())
+                || (outValue != null && !obj.hasSameOrNoLocal(outValue))) {
+              continue;
             }
-            iterator.removeOrReplaceByDebugLocalRead();
-          } else if (nullability.isDefinitelyNull()) {
-            // TODO(b/124246610): throw NPE.
-            // Refactor UninstantiatedTypeOptimization#replaceCurrentInstructionWithThrowNull
-            // and move it to iterator.
-          }
-        } else if (outValue != null && !outValue.hasLocalInfo()) {
-          if (appView
-              .dexItemFactory()
-              .libraryMethodsReturningReceiver
-              .contains(invoke.getInvokedMethod())) {
-            if (checkArgumentType(invoke, 0)) {
-              outValue.replaceUsers(invoke.arguments().get(0));
-              invoke.setOutValue(null);
+            Nullability nullability = obj.getTypeLattice().nullability();
+            if (nullability.isDefinitelyNotNull()) {
+              if (outValue != null) {
+                outValue.replaceUsers(obj);
+                needToNarrowValues.addAll(outValue.affectedValues());
+              }
+              iterator.removeOrReplaceByDebugLocalRead();
+            } else if (obj.isAlwaysNull(appView) && appView.appInfo().hasSubtyping()) {
+              iterator.replaceCurrentInstructionWithThrowNull(
+                  appView.withSubtyping(), code, blockIterator, blocksToBeRemoved);
             }
-          } else if (appInfoWithLiveness != null) {
-            DexEncodedMethod target =
-                invoke.lookupSingleTarget(appInfoWithLiveness, code.method.method.holder);
-            if (target != null) {
-              DexMethod invokedMethod = target.method;
-              // Check if the invoked method is known to return one of its arguments.
-              DexEncodedMethod definition = appView.definitionFor(invokedMethod);
-              if (definition != null && definition.getOptimizationInfo().returnsArgument()) {
-                int argumentIndex = definition.getOptimizationInfo().getReturnedArgument();
-                // Replace the out value of the invoke with the argument and ignore the out value.
-                if (argumentIndex >= 0 && checkArgumentType(invoke, argumentIndex)) {
-                  Value argument = invoke.arguments().get(argumentIndex);
-                  assert outValue.verifyCompatible(argument.outType());
-                  if (argument
-                      .getTypeLattice()
-                      .lessThanOrEqual(outValue.getTypeLattice(), appView)) {
-                    needToNarrowValues.addAll(outValue.affectedValues());
-                  } else {
-                    needToWidenValues.addAll(outValue.affectedValues());
+          } else if (outValue != null && !outValue.hasLocalInfo()) {
+            if (appView
+                .dexItemFactory()
+                .libraryMethodsReturningReceiver
+                .contains(invoke.getInvokedMethod())) {
+              if (checkArgumentType(invoke, 0)) {
+                outValue.replaceUsers(invoke.arguments().get(0));
+                invoke.setOutValue(null);
+              }
+            } else if (appInfoWithLiveness != null) {
+              DexEncodedMethod target =
+                  invoke.lookupSingleTarget(appInfoWithLiveness, code.method.method.holder);
+              if (target != null) {
+                DexMethod invokedMethod = target.method;
+                // Check if the invoked method is known to return one of its arguments.
+                DexEncodedMethod definition = appView.definitionFor(invokedMethod);
+                if (definition != null && definition.getOptimizationInfo().returnsArgument()) {
+                  int argumentIndex = definition.getOptimizationInfo().getReturnedArgument();
+                  // Replace the out value of the invoke with the argument and ignore the out value.
+                  if (argumentIndex >= 0 && checkArgumentType(invoke, argumentIndex)) {
+                    Value argument = invoke.arguments().get(argumentIndex);
+                    assert outValue.verifyCompatible(argument.outType());
+                    if (argument
+                        .getTypeLattice()
+                        .lessThanOrEqual(outValue.getTypeLattice(), appView)) {
+                      needToNarrowValues.addAll(outValue.affectedValues());
+                    } else {
+                      needToWidenValues.addAll(outValue.affectedValues());
+                    }
+                    outValue.replaceUsers(argument);
+                    invoke.setOutValue(null);
                   }
-                  outValue.replaceUsers(argument);
-                  invoke.setOutValue(null);
                 }
               }
             }
@@ -1674,6 +1681,11 @@
         }
       }
     }
+    if (!blocksToBeRemoved.isEmpty()) {
+      code.removeBlocks(blocksToBeRemoved);
+      code.removeAllTrivialPhis();
+      assert code.getUnreachableBlocks().isEmpty();
+    }
     if (!needToWidenValues.isEmpty() || !needToNarrowValues.isEmpty()) {
       TypeAnalysis analysis = new TypeAnalysis(appView, code.method);
       // If out value of invoke < argument (e.g., losing non-null info), widen users type.
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/ObjectsRequireNonNullTest.java b/src/test/java/com/android/tools/r8/ir/optimize/ObjectsRequireNonNullTest.java
index f8ef59e..3d80826 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/ObjectsRequireNonNullTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/ObjectsRequireNonNullTest.java
@@ -4,27 +4,47 @@
 package com.android.tools.r8.ir.optimize;
 
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertThat;
+import static org.junit.Assume.assumeTrue;
 
+import com.android.tools.r8.D8TestRunResult;
 import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NeverPropagateValue;
+import com.android.tools.r8.R8TestRunResult;
 import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.TestRunResult;
 import com.android.tools.r8.ToolHelper.DexVm.Version;
-import com.android.tools.r8.VmTestRunner;
-import com.android.tools.r8.VmTestRunner.IgnoreIfVmOlderThan;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.utils.StringUtils;
 import com.android.tools.r8.utils.codeinspector.ClassSubject;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
 import com.android.tools.r8.utils.codeinspector.MethodSubject;
 import com.google.common.collect.Streams;
 import java.util.Objects;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 class ObjectsRequireNonNullTestMain {
 
+  static class Uninitialized {
+    void noWayToCall() {
+      System.out.println("Uninitialized, hence no way to call this.");
+    }
+  }
+
+  @NeverPropagateValue
+  @NeverInline
+  static void consumeUninitialized(Uninitialized arg) {
+    Uninitialized nonNullArg = Objects.requireNonNull(arg);
+    // Dead code.
+    nonNullArg.noWayToCall();
+  }
+
   static class Foo {
     @NeverInline
     void bar() {
@@ -60,22 +80,50 @@
     } catch (NullPointerException npe) {
       System.out.println("Expected NPE");
     }
+
+    try {
+      consumeUninitialized(null);
+    } catch (NullPointerException npe) {
+      System.out.println("Expected NPE");
+    }
   }
 }
 
-@RunWith(VmTestRunner.class)
+@RunWith(Parameterized.class)
 public class ObjectsRequireNonNullTest extends TestBase {
   private static final String JAVA_OUTPUT = StringUtils.lines(
       "Foo::toString",
       "Foo::bar",
       "Foo::bar",
+      "Expected NPE",
       "Expected NPE"
   );
   private static final Class<?> MAIN = ObjectsRequireNonNullTestMain.class;
 
+  @Parameterized.Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters()
+        .withCfRuntimes()
+        // Objects#requireNonNull will be desugared VMs older than API level K.
+        .withDexRuntimesStartingFromExcluding(Version.V4_4_4)
+        .build();
+  }
+
+  private final TestParameters parameters;
+
+  public ObjectsRequireNonNullTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
   @Test
   public void testJvmOutput() throws Exception {
-    testForJvm().addTestClasspath().run(MAIN).assertSuccessWithOutput(JAVA_OUTPUT);
+    assumeTrue(
+        "Only run JVM reference once (for CF backend)",
+        parameters.getBackend() == Backend.CF);
+    testForJvm()
+        .addTestClasspath()
+        .run(parameters.getRuntime(), MAIN)
+        .assertSuccessWithOutput(JAVA_OUTPUT);
   }
 
   private static boolean isObjectsRequireNonNull(DexMethod method) {
@@ -92,48 +140,64 @@
     })).count();
   }
 
-  private void test(TestRunResult result, int expectedCount) throws Exception {
+  private void test(
+      TestRunResult result,
+      int expectedCountInMain,
+      int expectedCountInConsumer) throws Exception {
     CodeInspector codeInspector = result.inspector();
     ClassSubject mainClass = codeInspector.clazz(MAIN);
     MethodSubject mainMethod = mainClass.mainMethod();
     assertThat(mainMethod, isPresent());
-    long count = countObjectsRequireNonNull(mainMethod);
-    assertEquals(expectedCount, count);
+    assertEquals(expectedCountInMain, countObjectsRequireNonNull(mainMethod));
 
     MethodSubject unknownArg = mainClass.uniqueMethodWithName("unknownArg");
     assertThat(unknownArg, isPresent());
     // Due to the nullable argument, requireNonNull should remain.
     assertEquals(1, countObjectsRequireNonNull(unknownArg));
+
+    MethodSubject uninit = mainClass.uniqueMethodWithName("consumeUninitialized");
+    assertThat(uninit, isPresent());
+    assertEquals(expectedCountInConsumer, countObjectsRequireNonNull(uninit));
+    if (expectedCountInConsumer == 0) {
+      assertEquals(
+          0, Streams.stream(uninit.iterateInstructions(InstructionSubject::isInvoke)).count());
+      assertEquals(
+          1, Streams.stream(uninit.iterateInstructions(InstructionSubject::isThrow)).count());
+    }
   }
 
   @Test
-  @IgnoreIfVmOlderThan(Version.V4_4_4)
   public void testD8() throws Exception {
-    TestRunResult result = testForD8()
+    assumeTrue("Only run D8 for Dex backend", parameters.getBackend() == Backend.DEX);
+    D8TestRunResult result = testForD8()
         .debug()
         .addProgramClassesAndInnerClasses(MAIN)
-        .run(MAIN)
+        .setMinApi(parameters.getRuntime())
+        .run(parameters.getRuntime(), MAIN)
         .assertSuccessWithOutput(JAVA_OUTPUT);
-    test(result, 2);
+    test(result, 2, 1);
 
     result = testForD8()
         .release()
         .addProgramClassesAndInnerClasses(MAIN)
-        .run(MAIN)
+        .setMinApi(parameters.getRuntime())
+        .run(parameters.getRuntime(), MAIN)
         .assertSuccessWithOutput(JAVA_OUTPUT);
-    test(result, 0);
+    test(result, 0, 1);
   }
 
   @Test
-  @IgnoreIfVmOlderThan(Version.V4_4_4)
   public void testR8() throws Exception {
-    // CF disables move result optimization.
-    TestRunResult result = testForR8(Backend.DEX)
+    assumeTrue("CF disables move result optimization", parameters.getBackend() == Backend.DEX);
+    R8TestRunResult result = testForR8(parameters.getBackend())
         .addProgramClassesAndInnerClasses(MAIN)
         .enableInliningAnnotations()
+        .enableMemberValuePropagationAnnotations()
         .addKeepMainRule(MAIN)
-        .run(MAIN)
+        .noMinification()
+        .setMinApi(parameters.getRuntime())
+        .run(parameters.getRuntime(), MAIN)
         .assertSuccessWithOutput(JAVA_OUTPUT);
-    test(result, 0);
+    test(result, 0, 0);
   }
 }