Retain messages for NPEs

Change-Id: I16bb94468cf6c6bda48929c02bba93a13f5515a2
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index 45befcf..f7755d4 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -26,7 +26,6 @@
 import com.android.tools.r8.ir.analysis.equivalence.BasicBlockBehavioralSubsumption;
 import com.android.tools.r8.ir.analysis.type.TypeAnalysis;
 import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
-import com.android.tools.r8.ir.analysis.type.TypeUtils;
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
 import com.android.tools.r8.ir.code.AlwaysMaterializingNop;
 import com.android.tools.r8.ir.code.ArrayLength;
@@ -273,9 +272,34 @@
           }
 
           Throw throwInstruction = valueIsNullTarget.exit().asThrow();
-          Value exceptionValue = throwInstruction.exception();
-          if (!exceptionValue.isConstZero()
-              && !TypeUtils.isNullPointerException(exceptionValue.getTypeLattice(), appView)) {
+          Value exceptionValue = throwInstruction.exception().getAliasedValue();
+          Value message;
+          if (exceptionValue.isConstZero()) {
+            message = null;
+          } else if (exceptionValue.isDefinedByInstructionSatisfying(Instruction::isNewInstance)) {
+            NewInstance newInstance = exceptionValue.definition.asNewInstance();
+            if (newInstance.clazz != dexItemFactory.npeType) {
+              continue;
+            }
+            if (newInstance.outValue().numberOfAllUsers() != 2) {
+              continue; // Could be mutated before it is thrown.
+            }
+            InvokeDirect constructorCall = newInstance.getUniqueConstructorInvoke(dexItemFactory);
+            if (constructorCall == null) {
+              continue;
+            }
+            DexMethod invokedMethod = constructorCall.getInvokedMethod();
+            if (invokedMethod == dexItemFactory.npeMethods.init) {
+              message = null;
+            } else if (invokedMethod == dexItemFactory.npeMethods.initWithMessage) {
+              if (!appView.options().canUseRequireNonNull()) {
+                continue;
+              }
+              message = constructorCall.getArgument(1);
+            } else {
+              continue;
+            }
+          } else {
             continue;
           }
 
@@ -290,12 +314,26 @@
             continue;
           }
 
+          if (message != null) {
+            Instruction definition = message.definition;
+            if (message.definition.getBlock() == valueIsNullTarget) {
+              it.previous();
+              Instruction entry;
+              do {
+                entry = valueIsNullTarget.getInstructions().removeFirst();
+                it.add(entry);
+              } while (entry != definition);
+              it.next();
+            }
+          }
+
           rewriteIfToRequireNonNull(
               block,
               it,
               ifInstruction,
               ifInstruction.targetFromCondition(1),
               valueIsNullTarget,
+              message,
               throwInstruction.getPosition());
           shouldRemoveUnreachableBlocks = true;
         }
@@ -2922,15 +2960,24 @@
       If theIf,
       BasicBlock target,
       BasicBlock deadTarget,
+      Value message,
       Position position) {
     deadTarget.unlinkSinglePredecessorSiblingsAllowed();
     assert theIf == block.exit();
     iterator.previous();
     Instruction instruction;
     if (appView.options().canUseRequireNonNull()) {
-      DexMethod requireNonNullMethod = appView.dexItemFactory().objectsMethods.requireNonNull;
-      instruction = new InvokeStatic(requireNonNullMethod, null, ImmutableList.of(theIf.lhs()));
+      if (message != null) {
+        DexMethod requireNonNullMethod =
+            appView.dexItemFactory().objectsMethods.requireNonNullWithMessage;
+        instruction =
+            new InvokeStatic(requireNonNullMethod, null, ImmutableList.of(theIf.lhs(), message));
+      } else {
+        DexMethod requireNonNullMethod = appView.dexItemFactory().objectsMethods.requireNonNull;
+        instruction = new InvokeStatic(requireNonNullMethod, null, ImmutableList.of(theIf.lhs()));
+      }
     } else {
+      assert message == null;
       DexMethod getClassMethod = appView.dexItemFactory().objectMembers.getClass;
       instruction = new InvokeVirtual(getClassMethod, null, ImmutableList.of(theIf.lhs()));
     }
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index 25b7387..37b38e9 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -1229,7 +1229,7 @@
   }
 
   public boolean canUseRequireNonNull() {
-    return isGeneratingClassFiles() || hasMinApi(AndroidApiLevel.K);
+    return isGeneratingDex() && hasMinApi(AndroidApiLevel.K);
   }
 
   public boolean canUseSuppressedExceptions() {
diff --git a/src/test/java/com/android/tools/r8/ToolHelper.java b/src/test/java/com/android/tools/r8/ToolHelper.java
index 039fb59..e184d97 100644
--- a/src/test/java/com/android/tools/r8/ToolHelper.java
+++ b/src/test/java/com/android/tools/r8/ToolHelper.java
@@ -249,6 +249,10 @@
         this.shortName = shortName;
       }
 
+      public boolean isDalvik() {
+        return isOlderThanOrEqual(Version.V4_4_4);
+      }
+
       public boolean isLatest() {
         return this == DEFAULT;
       }
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/ifs/IfThrowNullPointerExceptionTest.java b/src/test/java/com/android/tools/r8/ir/optimize/ifs/IfThrowNullPointerExceptionTest.java
index 2264a29..e679e80 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/ifs/IfThrowNullPointerExceptionTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/ifs/IfThrowNullPointerExceptionTest.java
@@ -7,6 +7,7 @@
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assume.assumeTrue;
 
@@ -17,10 +18,11 @@
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.Instruction;
 import com.android.tools.r8.utils.AndroidApiLevel;
+import com.android.tools.r8.utils.BooleanUtils;
+import com.android.tools.r8.utils.StringUtils;
 import com.android.tools.r8.utils.codeinspector.ClassSubject;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
 import com.android.tools.r8.utils.codeinspector.MethodSubject;
-import com.google.common.collect.ImmutableList;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -41,6 +43,15 @@
   }
 
   @Test
+  public void testJVM() throws Exception {
+    assumeTrue(parameters.isCfRuntime());
+    testForJvm()
+        .addTestClasspath()
+        .run(parameters.getRuntime(), TestClass.class)
+        .assertSuccessWithOutput(getExpectedStdout());
+  }
+
+  @Test
   public void testD8() throws Exception {
     assumeTrue(parameters.isDexRuntime());
     testForD8()
@@ -50,7 +61,7 @@
         .compile()
         .inspect(this::inspect)
         .run(parameters.getRuntime(), TestClass.class)
-        .assertSuccessWithOutputLines("Caught NPE", "Caught NPE");
+        .assertSuccessWithOutput(getExpectedStdout());
   }
 
   @Test
@@ -62,55 +73,104 @@
         .compile()
         .inspect(this::inspect)
         .run(parameters.getRuntime(), TestClass.class)
-        .assertSuccessWithOutputLines("Caught NPE", "Caught NPE");
+        .assertSuccessWithOutput(getExpectedStdout());
   }
 
   private void inspect(CodeInspector inspector) {
     ClassSubject classSubject = inspector.clazz(TestClass.class);
     assertThat(classSubject, isPresent());
+    inspectMethod(inspector, classSubject, "testThrowNPE", false, true);
+    inspectMethod(inspector, classSubject, "testThrowNPEWithMessage", true, canUseRequireNonNull());
+    inspectMethod(inspector, classSubject, "testThrowNull", false, true);
+  }
 
-    for (String methodName : ImmutableList.of("testThrowNPE", "testThrowNull")) {
-      MethodSubject methodSubject = classSubject.uniqueMethodWithName(methodName);
-      assertThat(methodSubject, isPresent());
+  private void inspectMethod(
+      CodeInspector inspector,
+      ClassSubject classSubject,
+      String methodName,
+      boolean isNPEWithMessage,
+      boolean shouldBeOptimized) {
+    MethodSubject methodSubject = classSubject.uniqueMethodWithName(methodName);
+    assertThat(methodSubject, isPresent());
 
-      IRCode code = methodSubject.buildIR();
+    IRCode code = methodSubject.buildIR();
+    if (shouldBeOptimized) {
       assertEquals(1, code.blocks.size());
 
       BasicBlock entryBlock = code.entryBlock();
-      assertEquals(3, entryBlock.getInstructions().size());
+      assertEquals(
+          3 + BooleanUtils.intValue(isNPEWithMessage), entryBlock.getInstructions().size());
       assertTrue(entryBlock.getInstructions().getFirst().isArgument());
       assertTrue(entryBlock.getInstructions().getLast().isReturn());
 
-      Instruction nullCheckInstruction = entryBlock.getInstructions().get(1);
-      if (parameters.isDexRuntime() && parameters.getApiLevel().isLessThan(AndroidApiLevel.K)) {
+      Instruction nullCheckInstruction =
+          entryBlock.getInstructions().get(1 + BooleanUtils.intValue(isNPEWithMessage));
+      if (canUseRequireNonNull()) {
+        assertTrue(nullCheckInstruction.isInvokeStatic());
+        if (isNPEWithMessage) {
+          assertEquals(
+              inspector.getFactory().objectsMethods.requireNonNullWithMessage,
+              nullCheckInstruction.asInvokeStatic().getInvokedMethod());
+        } else {
+          assertEquals(
+              inspector.getFactory().objectsMethods.requireNonNull,
+              nullCheckInstruction.asInvokeStatic().getInvokedMethod());
+        }
+      } else {
+        assertFalse(isNPEWithMessage);
         assertTrue(nullCheckInstruction.isInvokeVirtual());
         assertEquals(
-            "java.lang.Class java.lang.Object.getClass()",
-            nullCheckInstruction.asInvokeVirtual().getInvokedMethod().toSourceString());
-      } else {
-        assertTrue(nullCheckInstruction.isInvokeStatic());
-        assertEquals(
-            "java.lang.Object java.util.Objects.requireNonNull(java.lang.Object)",
-            nullCheckInstruction.asInvokeStatic().getInvokedMethod().toSourceString());
+            inspector.getFactory().objectMembers.getClass,
+            nullCheckInstruction.asInvokeVirtual().getInvokedMethod());
       }
+    } else {
+      assertEquals(3, code.blocks.size());
     }
   }
 
+  private String getExpectedStdout() {
+    if (parameters.isCfRuntime() || canUseRequireNonNull() || isDalvik()) {
+      return StringUtils.lines("Caught NPE: null", "Caught NPE: x was null", "Caught NPE: null");
+    }
+    return StringUtils.lines(
+        "Caught NPE: Attempt to invoke virtual method 'java.lang.Class java.lang.Object.getClass()'"
+            + " on a null object reference",
+        "Caught NPE: x was null",
+        "Caught NPE: Attempt to invoke virtual method 'java.lang.Class java.lang.Object.getClass()'"
+            + " on a null object reference");
+  }
+
+  private boolean canUseRequireNonNull() {
+    return parameters.isDexRuntime()
+        && parameters.getApiLevel().isGreaterThanOrEqualTo(AndroidApiLevel.K);
+  }
+
+  private boolean isDalvik() {
+    return parameters.isDexRuntime()
+        && parameters.getRuntime().asDex().getVm().getVersion().isDalvik();
+  }
+
   static class TestClass {
 
     public static void main(String[] args) {
       testThrowNPE(new Object());
+      testThrowNPEWithMessage(new Object());
       testThrowNull(new Object());
 
       try {
         testThrowNPE(null);
       } catch (NullPointerException e) {
-        System.out.println("Caught NPE");
+        System.out.println("Caught NPE: " + e.getMessage());
+      }
+      try {
+        testThrowNPEWithMessage(null);
+      } catch (NullPointerException e) {
+        System.out.println("Caught NPE: " + e.getMessage());
       }
       try {
         testThrowNull(null);
       } catch (NullPointerException e) {
-        System.out.println("Caught NPE");
+        System.out.println("Caught NPE: " + e.getMessage());
       }
     }
 
@@ -120,6 +180,12 @@
       }
     }
 
+    static void testThrowNPEWithMessage(Object x) {
+      if (x == null) {
+        throw new NullPointerException("x was null");
+      }
+    }
+
     static void testThrowNull(Object x) {
       if (x == null) {
         throw null;
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/inliner/InlineInvokeWithNullableReceiverTest.java b/src/test/java/com/android/tools/r8/ir/optimize/inliner/InlineInvokeWithNullableReceiverTest.java
index d8b26ba..dce7b2d 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/inliner/InlineInvokeWithNullableReceiverTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/inliner/InlineInvokeWithNullableReceiverTest.java
@@ -68,8 +68,7 @@
     assertThat(methodSubject, isPresent());
 
     // A `throw` instruction should have been synthesized into main().
-    if (parameters.isCfRuntime()
-        || parameters.getApiLevel().isGreaterThanOrEqualTo(AndroidApiLevel.K)) {
+    if (canUseRequireNonNull()) {
       assertTrue(methodSubject.streamInstructions().anyMatch(InstructionSubject::isInvokeStatic));
     } else {
       assertTrue(
@@ -92,6 +91,11 @@
     assertThat(otherClassSubject.uniqueMethodWithName("m"), not(isPresent()));
   }
 
+  private boolean canUseRequireNonNull() {
+    return parameters.isDexRuntime()
+        && parameters.getApiLevel().isGreaterThanOrEqualTo(AndroidApiLevel.K);
+  }
+
   static class TestClass {
 
     public static void main(String[] args) {
diff --git a/src/test/java/com/android/tools/r8/retrace/InlineWithoutNullCheckTest.java b/src/test/java/com/android/tools/r8/retrace/InlineWithoutNullCheckTest.java
index 1796dbf..3cabc0e 100644
--- a/src/test/java/com/android/tools/r8/retrace/InlineWithoutNullCheckTest.java
+++ b/src/test/java/com/android/tools/r8/retrace/InlineWithoutNullCheckTest.java
@@ -203,8 +203,8 @@
   }
 
   private boolean canUseRequireNonNull() {
-    return parameters.isCfRuntime()
-        || parameters.getApiLevel().isGreaterThanOrEqualTo(AndroidApiLevel.K);
+    return parameters.isDexRuntime()
+        && parameters.getApiLevel().isGreaterThanOrEqualTo(AndroidApiLevel.K);
   }
 
   static class TestClassForInlineMethod {