Retain NPE messages in non-null-param-or-throw optimization

Bug: 157147597
Change-Id: Id45aa2e496f3953e1d9d6034a98d062d23602dfa
diff --git a/src/main/java/com/android/tools/r8/ir/code/BasicBlockInstructionListIterator.java b/src/main/java/com/android/tools/r8/ir/code/BasicBlockInstructionListIterator.java
index b4ceb0d..7acb409 100644
--- a/src/main/java/com/android/tools/r8/ir/code/BasicBlockInstructionListIterator.java
+++ b/src/main/java/com/android/tools/r8/ir/code/BasicBlockInstructionListIterator.java
@@ -198,7 +198,9 @@
     }
     current.moveDebugValues(newInstruction);
     newInstruction.setBlock(block);
-    newInstruction.setPosition(current.getPosition());
+    if (!newInstruction.hasPosition()) {
+      newInstruction.setPosition(current.getPosition());
+    }
     listIterator.remove();
     listIterator.add(newInstruction);
     current.clearBlock();
@@ -210,7 +212,13 @@
       IRCode code, InternalOptions options, long value, TypeElement type) {
     ConstNumber constNumberInstruction = code.createNumberConstant(value, type);
     // Note that we only keep position info for throwing instructions in release mode.
-    constNumberInstruction.setPosition(options.debug ? current.getPosition() : Position.none());
+    Position position;
+    if (options.debug) {
+      position = current != null ? current.getPosition() : block.getPosition();
+    } else {
+      position = Position.none();
+    }
+    constNumberInstruction.setPosition(position);
     add(constNumberInstruction);
     return constNumberInstruction.outValue();
   }
@@ -280,58 +288,90 @@
     if (current == null) {
       throw new IllegalStateException();
     }
-    BasicBlock block = current.getBlock();
+
+    Instruction toBeReplaced = current;
+
+    BasicBlock block = toBeReplaced.getBlock();
     assert !blocksToRemove.contains(block);
     assert affectedValues != null;
 
-    BasicBlock normalSuccessorBlock = split(code, blockIterator);
+    // Split the block before the instruction that should be replaced by `throw null`.
     previous();
 
-    // Unlink all blocks that are dominated by successor.
-    {
-      DominatorTree dominatorTree = new DominatorTree(code, MAY_HAVE_UNREACHABLE_BLOCKS);
-      blocksToRemove.addAll(block.unlink(normalSuccessorBlock, dominatorTree, affectedValues));
+    BasicBlock throwBlock;
+    if (block.hasCatchHandlers() && !toBeReplaced.instructionTypeCanThrow()) {
+      // We need to insert the throw instruction in a block of its own, so split the current block
+      // into three blocks, where the intermediate block only contains a goto instruction.
+      throwBlock = split(code, blockIterator, true);
+      throwBlock.listIterator(code).split(code, blockIterator);
+    } else {
+      split(code, blockIterator, true);
+      throwBlock = block;
     }
 
-    // Insert constant null before the instruction.
+    // Position the instruction iterator before the goto instruction.
+    assert !hasNext();
     previous();
-    Value nullValue = insertConstNullInstruction(code, appView.options());
-    next();
+
+    // Unlink all blocks that are dominated by the unique normal successor of the throw block.
+    blocksToRemove.addAll(
+        throwBlock.unlink(
+            throwBlock.getUniqueNormalSuccessor(),
+            new DominatorTree(code, MAY_HAVE_UNREACHABLE_BLOCKS),
+            affectedValues));
+
+    InstructionListIterator throwBlockInstructionIterator =
+        throwBlock == block ? this : throwBlock.listIterator(code);
+
+    // Insert constant null before the goto instruction.
+    Value nullValue =
+        throwBlockInstructionIterator.insertConstNullInstruction(code, appView.options());
+
+    // Move past the inserted goto instruction.
+    throwBlockInstructionIterator.next();
+    assert !throwBlockInstructionIterator.hasNext();
 
     // Replace the instruction by throw.
     Throw throwInstruction = new Throw(nullValue);
-    for (Value inValue : current.inValues()) {
-      if (inValue.hasLocalInfo()) {
-        // Add this value as a debug value to avoid changing its live range.
-        throwInstruction.addDebugValue(inValue);
-      }
+    if (toBeReplaced.getPosition().isSome()) {
+      throwInstruction.setPosition(toBeReplaced.getPosition());
+    } else {
+      // The instruction that is being removed cannot throw, and thus it must be unreachable as we
+      // are replacing it by `throw null`, so we can safely use a none-position.
+      assert !toBeReplaced.instructionTypeCanThrow();
+      throwInstruction.setPosition(Position.syntheticNone());
     }
-    replaceCurrentInstruction(throwInstruction);
-    next();
-    remove();
+    throwBlockInstructionIterator.replaceCurrentInstruction(throwInstruction);
 
-    // Remove all catch handlers where the guard does not include NullPointerException.
     if (block.hasCatchHandlers()) {
-      CatchHandlers<BasicBlock> catchHandlers = block.getCatchHandlers();
-      catchHandlers.forEach(
-          (guard, target) -> {
-            if (blocksToRemove.contains(target)) {
-              // Already removed previously. This may happen if two catch handlers have the same
-              // target.
-              return;
-            }
-            if (!appView.appInfo().isSubtype(appView.dexItemFactory().npeType, guard)) {
-              // TODO(christofferqa): Consider updating previous dominator tree instead of
-              //   rebuilding it from scratch.
-              DominatorTree dominatorTree = new DominatorTree(code, MAY_HAVE_UNREACHABLE_BLOCKS);
-              blocksToRemove.addAll(block.unlink(target, dominatorTree, affectedValues));
-            }
-          });
+      if (block == throwBlock) {
+        // Remove all catch handlers where the guard does not include NullPointerException if the
+        // replaced instruction could throw.
+        CatchHandlers<BasicBlock> catchHandlers = block.getCatchHandlers();
+        catchHandlers.forEach(
+            (guard, target) -> {
+              if (blocksToRemove.contains(target)) {
+                // Already removed previously. This may happen if two catch handlers have the same
+                // target.
+                return;
+              }
+              if (!appView.appInfo().isSubtype(appView.dexItemFactory().npeType, guard)) {
+                // TODO(christofferqa): Consider updating previous dominator tree instead of
+                //   rebuilding it from scratch.
+                DominatorTree dominatorTree = new DominatorTree(code, MAY_HAVE_UNREACHABLE_BLOCKS);
+                blocksToRemove.addAll(block.unlink(target, dominatorTree, affectedValues));
+              }
+            });
+      } else {
+        // We replaced a dead, non-throwing instruction by a throwing instruction. Since this is
+        // dead code, we don't need to worry about the catch handlers of the `throwBlock`.
+      }
     }
   }
 
   @Override
-  public BasicBlock split(IRCode code, ListIterator<BasicBlock> blocksIterator) {
+  public BasicBlock split(
+      IRCode code, ListIterator<BasicBlock> blocksIterator, boolean keepCatchHandlers) {
     List<BasicBlock> blocks = code.blocks;
     assert blocksIterator == null || IteratorUtils.peekPrevious(blocksIterator) == block;
 
@@ -346,7 +386,6 @@
 
     // Prepare the new block, placing the exception handlers on the block with the throwing
     // instruction.
-    boolean keepCatchHandlers = hasPrevious() && peekPrevious().instructionTypeCanThrow();
     newBlock = block.createSplitBlock(blockNumber, keepCatchHandlers);
 
     // Add a goto instruction.
diff --git a/src/main/java/com/android/tools/r8/ir/code/IRCodeInstructionListIterator.java b/src/main/java/com/android/tools/r8/ir/code/IRCodeInstructionListIterator.java
index b9cc832..47f5769 100644
--- a/src/main/java/com/android/tools/r8/ir/code/IRCodeInstructionListIterator.java
+++ b/src/main/java/com/android/tools/r8/ir/code/IRCodeInstructionListIterator.java
@@ -68,7 +68,8 @@
   }
 
   @Override
-  public BasicBlock split(IRCode code, ListIterator<BasicBlock> blockIterator) {
+  public BasicBlock split(
+      IRCode code, ListIterator<BasicBlock> blockIterator, boolean keepCatchHandlers) {
     throw new Unimplemented();
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/code/InstructionListIterator.java b/src/main/java/com/android/tools/r8/ir/code/InstructionListIterator.java
index a546ee3..d83aa64 100644
--- a/src/main/java/com/android/tools/r8/ir/code/InstructionListIterator.java
+++ b/src/main/java/com/android/tools/r8/ir/code/InstructionListIterator.java
@@ -104,20 +104,25 @@
 
   /**
    * Split the block into two blocks at the point of the {@link ListIterator} cursor. The existing
-   * block will have all the instructions before the cursor, and the new block all the
-   * instructions after the cursor.
+   * block will have all the instructions before the cursor, and the new block all the instructions
+   * after the cursor.
    *
-   * If the current block has catch handlers these catch handlers will be attached to the block
+   * <p>If the current block has catch handlers these catch handlers will be attached to the block
    * containing the throwing instruction after the split.
    *
    * @param code the IR code for the block this iterator originates from.
    * @param blockIterator basic block iterator used to iterate the blocks. This must be positioned
-   * just after the block for which this is the instruction iterator. After this method returns it
-   * will be positioned just after the basic block returned. Calling {@link #remove} without
-   * further navigation will remove that block.
+   *     just after the block for which this is the instruction iterator. After this method returns
+   *     it will be positioned just after the basic block returned. Calling {@link #remove} without
+   *     further navigation will remove that block.
+   * @param keepCatchHandlers whether to keep catch handlers on the original block.
    * @return Returns the new block with the instructions after the cursor.
    */
-  BasicBlock split(IRCode code, ListIterator<BasicBlock> blockIterator);
+  BasicBlock split(IRCode code, ListIterator<BasicBlock> blockIterator, boolean keepCatchHandlers);
+
+  default BasicBlock split(IRCode code, ListIterator<BasicBlock> blockIterator) {
+    return split(code, blockIterator, hasPrevious() && peekPrevious().instructionTypeCanThrow());
+  }
 
   default BasicBlock split(IRCode code) {
     return split(code, null);
diff --git a/src/main/java/com/android/tools/r8/ir/code/LinearFlowInstructionListIterator.java b/src/main/java/com/android/tools/r8/ir/code/LinearFlowInstructionListIterator.java
index d73f8df..d851eb9 100644
--- a/src/main/java/com/android/tools/r8/ir/code/LinearFlowInstructionListIterator.java
+++ b/src/main/java/com/android/tools/r8/ir/code/LinearFlowInstructionListIterator.java
@@ -77,8 +77,9 @@
   }
 
   @Override
-  public BasicBlock split(IRCode code, ListIterator<BasicBlock> blockIterator) {
-    return currentBlockIterator.split(code, blockIterator);
+  public BasicBlock split(
+      IRCode code, ListIterator<BasicBlock> blockIterator, boolean keepCatchHandlers) {
+    return currentBlockIterator.split(code, blockIterator, keepCatchHandlers);
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/UninstantiatedTypeOptimization.java b/src/main/java/com/android/tools/r8/ir/optimize/UninstantiatedTypeOptimization.java
index 85f0274..d381afb 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/UninstantiatedTypeOptimization.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/UninstantiatedTypeOptimization.java
@@ -329,7 +329,7 @@
     AssumeDynamicTypeRemover assumeDynamicTypeRemover = new AssumeDynamicTypeRemover(appView, code);
     Set<BasicBlock> blocksToBeRemoved = Sets.newIdentityHashSet();
     ListIterator<BasicBlock> blockIterator = code.listIterator();
-    Set<Value> valuesToNarrow = Sets.newIdentityHashSet();
+    Set<Value> affectedValues = Sets.newIdentityHashSet();
     while (blockIterator.hasNext()) {
       BasicBlock block = blockIterator.next();
       if (blocksToBeRemoved.contains(block)) {
@@ -342,7 +342,7 @@
           Value couldBeNullValue = instruction.getNonNullInput();
           if (isThrowNullCandidate(couldBeNullValue, instruction, appView, code.context())) {
             instructionIterator.replaceCurrentInstructionWithThrowNull(
-                appView, code, blockIterator, blocksToBeRemoved, valuesToNarrow);
+                appView, code, blockIterator, blocksToBeRemoved, affectedValues);
             continue;
           }
         }
@@ -353,17 +353,17 @@
               instructionIterator,
               code,
               assumeDynamicTypeRemover,
-              blocksToBeRemoved,
-              valuesToNarrow);
+              affectedValues,
+              blocksToBeRemoved);
         }
       }
     }
     assumeDynamicTypeRemover.removeMarkedInstructions(blocksToBeRemoved).finish();
     code.removeBlocks(blocksToBeRemoved);
-    code.removeAllDeadAndTrivialPhis(valuesToNarrow);
+    code.removeAllDeadAndTrivialPhis(affectedValues);
     code.removeUnreachableBlocks();
-    if (!valuesToNarrow.isEmpty()) {
-      new TypeAnalysis(appView).narrowing(valuesToNarrow);
+    if (!affectedValues.isEmpty()) {
+      new TypeAnalysis(appView).narrowing(affectedValues);
     }
     assert code.isConsistentSSA();
   }
@@ -399,8 +399,8 @@
       InstructionListIterator instructionIterator,
       IRCode code,
       AssumeDynamicTypeRemover assumeDynamicTypeRemover,
-      Set<BasicBlock> blocksToBeRemoved,
-      Set<Value> affectedValues) {
+      Set<Value> affectedValues,
+      Set<BasicBlock> blocksToBeRemoved) {
     DexEncodedMethod target = invoke.lookupSingleTarget(appView, code.context());
     if (target == null) {
       return;
@@ -411,6 +411,7 @@
       for (int i = 0; i < invoke.arguments().size(); i++) {
         Value argument = invoke.arguments().get(i);
         if (argument.isAlwaysNull(appView) && facts.get(i)) {
+          instructionIterator.next();
           instructionIterator.replaceCurrentInstructionWithThrowNull(
               appView, code, blockIterator, blocksToBeRemoved, affectedValues);
           return;
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/ObjectsRequireNonNullTest.java b/src/test/java/com/android/tools/r8/ir/optimize/ObjectsRequireNonNullTest.java
index a1c69fa..83c5243 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/ObjectsRequireNonNullTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/ObjectsRequireNonNullTest.java
@@ -141,7 +141,9 @@
             .setMinApi(parameters.getApiLevel())
             .run(parameters.getRuntime(), MAIN)
             .assertSuccessWithOutput(JAVA_OUTPUT);
-    test(result, 0, 0);
+    // TODO(b/157427150): would be able to remove the call to requireNonNull() if we knew that it
+    //  throws an NullPointerException that does not have a message.
+    test(result, 0, 1);
   }
 
   static class ObjectsRequireNonNullTestMain {
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/ThrowNPEWithMessageIfParameterIsNullTest.java b/src/test/java/com/android/tools/r8/ir/optimize/ThrowNPEWithMessageIfParameterIsNullTest.java
new file mode 100644
index 0000000..924278e
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/ThrowNPEWithMessageIfParameterIsNullTest.java
@@ -0,0 +1,58 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class ThrowNPEWithMessageIfParameterIsNullTest extends TestBase {
+
+  private final TestParameters parameters;
+
+  @Parameterized.Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  public ThrowNPEWithMessageIfParameterIsNullTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(ThrowNPEWithMessageIfParameterIsNullTest.class)
+        .addKeepMainRule(TestClass.class)
+        .enableInliningAnnotations()
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .run(parameters.getRuntime(), TestClass.class)
+        .assertSuccessWithOutputLines("Hello world!");
+  }
+
+  static class TestClass {
+
+    public static void main(String[] args) {
+      try {
+        checkNotNull(null);
+      } catch (Exception e) {
+        System.out.println(e.getMessage());
+      }
+    }
+
+    @NeverInline
+    static void checkNotNull(Object o) {
+      if (o == null) {
+        throw new NullPointerException("Hello world!");
+      }
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/uninstantiatedtypes/InvokeMethodWithNonNullParamCheckTest.java b/src/test/java/com/android/tools/r8/ir/optimize/uninstantiatedtypes/InvokeMethodWithNonNullParamCheckTest.java
index 854ac14..828f4c0 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/uninstantiatedtypes/InvokeMethodWithNonNullParamCheckTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/uninstantiatedtypes/InvokeMethodWithNonNullParamCheckTest.java
@@ -7,7 +7,6 @@
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.core.IsNot.not;
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotEquals;
 
 import com.android.tools.r8.NeverInline;
 import com.android.tools.r8.TestBase;
@@ -107,15 +106,16 @@
                 .anyMatch(InstructionSubject::isThrow));
 
         if (shouldHaveThrow) {
-          // Check that there are no invoke instructions targeting the methods on `Static` and
-          // `Virtual`.
+          // TODO(b/157427150): Check that there are no invoke instructions targeting the methods on
+          //  `Static` and `Virtual`. This requires that we know that their methods throw
+          //  NullPointerExceptions without messages.
           Streams.stream(methodSubject.iterateInstructions())
               .filter(InstructionSubject::isInvoke)
               .forEach(
                   ins -> {
                     ClassSubject clazz = inspector.clazz(ins.getMethod().holder.toSourceString());
-                    assertNotEquals(clazz.getOriginalName(), Static.class.getTypeName());
-                    assertNotEquals(clazz.getOriginalName(), Virtual.class.getTypeName());
+                    // assertNotEquals(clazz.getOriginalName(), Static.class.getTypeName());
+                    // assertNotEquals(clazz.getOriginalName(), Virtual.class.getTypeName());
                   });
         }
 
diff --git a/src/test/java/com/android/tools/r8/ir/regalloc/RegisterMoveSchedulerTest.java b/src/test/java/com/android/tools/r8/ir/regalloc/RegisterMoveSchedulerTest.java
index 8ea5334..b4d2c3a 100644
--- a/src/test/java/com/android/tools/r8/ir/regalloc/RegisterMoveSchedulerTest.java
+++ b/src/test/java/com/android/tools/r8/ir/regalloc/RegisterMoveSchedulerTest.java
@@ -132,7 +132,8 @@
     }
 
     @Override
-    public BasicBlock split(IRCode code, ListIterator<BasicBlock> blockIterator) {
+    public BasicBlock split(
+        IRCode code, ListIterator<BasicBlock> blockIterator, boolean keepCatchHandlers) {
       throw new Unimplemented();
     }
 
diff --git a/src/test/java/com/android/tools/r8/shaking/array/DeadArrayLengthTest.java b/src/test/java/com/android/tools/r8/shaking/array/DeadArrayLengthTest.java
index 241abc5..14259ca 100644
--- a/src/test/java/com/android/tools/r8/shaking/array/DeadArrayLengthTest.java
+++ b/src/test/java/com/android/tools/r8/shaking/array/DeadArrayLengthTest.java
@@ -4,7 +4,6 @@
 package com.android.tools.r8.shaking.array;
 
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
-import static org.hamcrest.CoreMatchers.not;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assume.assumeTrue;
@@ -48,13 +47,8 @@
     assertEquals(0, countArrayLength(nonNull));
 
     MethodSubject nullable = main.uniqueMethodWithName("isNullable");
-    if (isR8) {
-      // Replaced with null-throwing code at the call site.
-      assertThat(nullable, not(isPresent()));
-    } else {
-      assertThat(nullable, isPresent());
-      assertEquals(1, countArrayLength(nullable));
-    }
+    assertThat(nullable, isPresent());
+    assertEquals(isR8 ? 0 : 1, countArrayLength(nullable));
 
     MethodSubject nullCheck = main.uniqueMethodWithName("afterNullCheck");
     assertThat(nullCheck, isPresent());