Rewrite if-null-throw-null to requireNonNull in release

Change-Id: Iee03b00c475af00c91be1014ecdf5895aef6fe8d
Bug: 150591081
diff --git a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
index c9491e3..6c9e84b 100644
--- a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
+++ b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
@@ -595,6 +595,8 @@
   public Map<DexMethod, Predicate<InvokeMethod>> libraryMethodsWithoutSideEffects =
       Streams.<Pair<DexMethod, Predicate<InvokeMethod>>>concat(
               Stream.of(new Pair<>(enumMethods.constructor, alwaysTrue())),
+              Stream.of(new Pair<>(npeMethods.init, alwaysTrue())),
+              Stream.of(new Pair<>(npeMethods.initWithMessage, alwaysTrue())),
               Stream.of(new Pair<>(objectMembers.constructor, alwaysTrue())),
               Stream.of(new Pair<>(objectMembers.getClass, alwaysTrue())),
               mapToPredicate(classMethods.getNames, alwaysTrue()),
@@ -625,12 +627,20 @@
 
   public Set<DexType> libraryTypesAssumedToBePresent =
       ImmutableSet.<DexType>builder()
-          .add(objectType, callableType, stringBufferType, stringBuilderType, stringType, enumType)
+          .add(
+              callableType,
+              enumType,
+              npeType,
+              objectType,
+              stringBufferType,
+              stringBuilderType,
+              stringType)
           .addAll(primitiveToBoxed.values())
           .build();
 
   public Set<DexType> libraryClassesWithoutStaticInitialization =
-      ImmutableSet.of(boxedBooleanType, enumType, objectType, stringBufferType, stringBuilderType);
+      ImmutableSet.of(
+          boxedBooleanType, enumType, npeType, objectType, stringBufferType, stringBuilderType);
 
   private boolean skipNameValidationForTesting = false;
 
@@ -951,16 +961,12 @@
 
   public class NullPointerExceptionMethods {
 
-    public final DexMethod init;
+    public final DexMethod init =
+        createMethod(npeType, createProto(voidType), constructorMethodName);
+    public final DexMethod initWithMessage =
+        createMethod(npeType, createProto(voidType, stringType), constructorMethodName);
 
-    private NullPointerExceptionMethods() {
-      init =
-          createMethod(
-              npeDescriptor,
-              constructorMethodName,
-              voidDescriptor,
-              DexString.EMPTY_ARRAY);
-    }
+    private NullPointerExceptionMethods() {}
   }
 
   /**
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/type/TypeUtils.java b/src/main/java/com/android/tools/r8/ir/analysis/type/TypeUtils.java
new file mode 100644
index 0000000..912b379
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/analysis/type/TypeUtils.java
@@ -0,0 +1,15 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.analysis.type;
+
+import com.android.tools.r8.graph.AppView;
+
+public class TypeUtils {
+
+  public static boolean isNullPointerException(TypeLatticeElement type, AppView<?> appView) {
+    return type.isClassType()
+        && type.asClassTypeLatticeElement().getClassType() == appView.dexItemFactory().npeType;
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/code/Instruction.java b/src/main/java/com/android/tools/r8/ir/code/Instruction.java
index 9c50aab..8cf2f40 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Instruction.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Instruction.java
@@ -533,6 +533,29 @@
     return false;
   }
 
+  public boolean isBlockLocalInstructionWithoutSideEffects(AppView<?> appView, DexType context) {
+    return definesBlockLocalValue() && !instructionMayHaveSideEffects(appView, context);
+  }
+
+  private boolean definesBlockLocalValue() {
+    return !definesValueWithNonLocalUsages();
+  }
+
+  private boolean definesValueWithNonLocalUsages() {
+    if (hasOutValue()) {
+      Value outValue = outValue();
+      if (outValue.numberOfPhiUsers() > 0) {
+        return true;
+      }
+      for (Instruction user : outValue.uniqueUsers()) {
+        if (user.getBlock() != getBlock()) {
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+
   public boolean instructionTypeCanBeCanonicalized() {
     return false;
   }
diff --git a/src/main/java/com/android/tools/r8/ir/code/InvokeDirect.java b/src/main/java/com/android/tools/r8/ir/code/InvokeDirect.java
index db939e8..ec0c68f 100644
--- a/src/main/java/com/android/tools/r8/ir/code/InvokeDirect.java
+++ b/src/main/java/com/android/tools/r8/ir/code/InvokeDirect.java
@@ -159,10 +159,6 @@
 
   @Override
   public boolean instructionMayHaveSideEffects(AppView<?> appView, DexType context) {
-    if (!appView.enableWholeProgramOptimizations()) {
-      return true;
-    }
-
     if (appView.options().debug) {
       return true;
     }
@@ -181,42 +177,41 @@
     }
 
     // Find the target and check if the invoke may have side effects.
-    if (appView.appInfo().hasLiveness()) {
-      AppView<AppInfoWithLiveness> appViewWithLiveness = appView.withLiveness();
-      DexEncodedMethod target = lookupSingleTarget(appViewWithLiveness, context);
-      if (target == null) {
-        return true;
-      }
-
-      // Verify that the target method is accessible in the current context.
-      if (!isMemberVisibleFromOriginalContext(
-          appView, context, target.method.holder, target.accessFlags)) {
-        return true;
-      }
-
-      // Verify that the target method does not have side-effects.
-      DexClass clazz = appView.definitionFor(target.method.holder);
-      if (clazz == null) {
-        assert false : "Expected to be able to find the enclosing class of a method definition";
-        return true;
-      }
-
-      if (appViewWithLiveness.appInfo().noSideEffects.containsKey(target.method)) {
-        return false;
-      }
-
-      MethodOptimizationInfo optimizationInfo = target.getOptimizationInfo();
-      if (target.isInstanceInitializer()) {
-        InstanceInitializerInfo initializerInfo = optimizationInfo.getInstanceInitializerInfo();
-        if (!initializerInfo.mayHaveOtherSideEffectsThanInstanceFieldAssignments()) {
-          return false;
-        }
-      }
-
-      return optimizationInfo.mayHaveSideEffects();
+    if (!appView.appInfo().hasLiveness()) {
+      return true;
     }
 
-    return true;
+    AppView<AppInfoWithLiveness> appViewWithLiveness = appView.withLiveness();
+    DexEncodedMethod target = lookupSingleTarget(appViewWithLiveness, context);
+    if (target == null) {
+      return true;
+    }
+
+    // Verify that the target method is accessible in the current context.
+    if (!isMemberVisibleFromOriginalContext(
+        appView, context, target.method.holder, target.accessFlags)) {
+      return true;
+    }
+
+    // Verify that the target method does not have side-effects.
+    DexClass clazz = appView.definitionFor(target.method.holder);
+    if (clazz == null) {
+      assert false : "Expected to be able to find the enclosing class of a method definition";
+      return true;
+    }
+
+    if (appViewWithLiveness.appInfo().noSideEffects.containsKey(target.method)) {
+      return false;
+    }
+
+    MethodOptimizationInfo optimizationInfo = target.getOptimizationInfo();
+    if (target.isInstanceInitializer()) {
+      InstanceInitializerInfo initializerInfo = optimizationInfo.getInstanceInitializerInfo();
+      if (!initializerInfo.mayHaveOtherSideEffectsThanInstanceFieldAssignments()) {
+        return false;
+      }
+    }
+    return optimizationInfo.mayHaveSideEffects();
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/ir/code/NewInstance.java b/src/main/java/com/android/tools/r8/ir/code/NewInstance.java
index de95e53..23ac981 100644
--- a/src/main/java/com/android/tools/r8/ir/code/NewInstance.java
+++ b/src/main/java/com/android/tools/r8/ir/code/NewInstance.java
@@ -142,8 +142,10 @@
 
   @Override
   public boolean instructionMayHaveSideEffects(AppView<?> appView, DexType context) {
+    DexItemFactory dexItemFactory = appView.dexItemFactory();
     if (!appView.enableWholeProgramOptimizations()) {
-      return true;
+      return !(dexItemFactory.libraryTypesAssumedToBePresent.contains(clazz)
+          && dexItemFactory.libraryClassesWithoutStaticInitialization.contains(clazz));
     }
 
     if (clazz.isPrimitiveType() || clazz.isArrayType()) {
@@ -157,7 +159,7 @@
     }
 
     if (definition.isLibraryClass()
-        && !appView.dexItemFactory().libraryTypesAssumedToBePresent.contains(clazz)) {
+        && !dexItemFactory.libraryTypesAssumedToBePresent.contains(clazz)) {
       return true;
     }
 
@@ -178,7 +180,6 @@
     }
 
     // Verify that the object does not have a finalizer.
-    DexItemFactory dexItemFactory = appView.dexItemFactory();
     ResolutionResult finalizeResolutionResult =
         appView.appInfo().resolveMethod(clazz, dexItemFactory.objectMembers.finalize);
     if (finalizeResolutionResult.isSingleResolution()) {
diff --git a/src/main/java/com/android/tools/r8/ir/code/Value.java b/src/main/java/com/android/tools/r8/ir/code/Value.java
index bdaa741..25ed39f 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Value.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Value.java
@@ -864,6 +864,10 @@
     return isConstant() && getConstInstruction().isConstNumber();
   }
 
+  public boolean isConstZero() {
+    return isConstNumber() && definition.asConstNumber().isZero();
+  }
+
   public boolean isConstString() {
     return isConstant() && getConstInstruction().isConstString();
   }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 5d12ded..f4e315b 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -1350,9 +1350,11 @@
       invertConditionalsForTesting(code);
     }
 
-    timing.begin("Rewrite throw NPE");
-    codeRewriter.rewriteThrowNullPointerException(code);
-    timing.end();
+    if (!isDebugMode) {
+      timing.begin("Rewrite throw NPE");
+      codeRewriter.rewriteThrowNullPointerException(code);
+      timing.end();
+    }
 
     timing.begin("Optimize class initializers");
     ClassInitializerDefaultsResult classInitializerDefaultsResult =
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index a8a4101..55fbf2c 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -26,6 +26,7 @@
 import com.android.tools.r8.ir.analysis.equivalence.BasicBlockBehavioralSubsumption;
 import com.android.tools.r8.ir.analysis.type.TypeAnalysis;
 import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
+import com.android.tools.r8.ir.analysis.type.TypeUtils;
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
 import com.android.tools.r8.ir.code.AlwaysMaterializingNop;
 import com.android.tools.r8.ir.code.ArrayLength;
@@ -241,10 +242,61 @@
 
   // Rewrite 'throw new NullPointerException()' to 'throw null'.
   public void rewriteThrowNullPointerException(IRCode code) {
+    boolean shouldRemoveUnreachableBlocks = false;
     for (BasicBlock block : code.blocks) {
       InstructionListIterator it = block.listIterator(code);
       while (it.hasNext()) {
         Instruction instruction = it.next();
+
+        // Check for the patterns 'if (x == null) throw null' and
+        // 'if (x == null) throw new NullPointerException()'.
+        if (instruction.isIf()) {
+          If ifInstruction = instruction.asIf();
+          if (!ifInstruction.isZeroTest()) {
+            continue;
+          }
+
+          Value value = ifInstruction.lhs();
+          if (!value.getTypeLattice().isReference()) {
+            assert value.getTypeLattice().isPrimitive();
+            continue;
+          }
+
+          BasicBlock valueIsNullTarget = ifInstruction.targetFromCondition(0);
+          if (valueIsNullTarget.getPredecessors().size() != 1
+              || !valueIsNullTarget.exit().isThrow()) {
+            continue;
+          }
+
+          Throw throwInstruction = valueIsNullTarget.exit().asThrow();
+          Value exceptionValue = throwInstruction.exception();
+          if (!exceptionValue.isConstZero()
+              && !TypeUtils.isNullPointerException(exceptionValue.getTypeLattice(), appView)) {
+            continue;
+          }
+
+          boolean canDetachValueIsNullTarget = true;
+          for (Instruction i : valueIsNullTarget.instructionsBefore(throwInstruction)) {
+            if (!i.isBlockLocalInstructionWithoutSideEffects(appView, code.method.holder())) {
+              canDetachValueIsNullTarget = false;
+              break;
+            }
+          }
+          if (!canDetachValueIsNullTarget) {
+            continue;
+          }
+
+          rewriteIfToRequireNonNull(
+              code,
+              block,
+              it,
+              ifInstruction,
+              ifInstruction.targetFromCondition(1),
+              valueIsNullTarget,
+              throwInstruction.getPosition());
+          shouldRemoveUnreachableBlocks = true;
+        }
+
         // Check for 'new-instance NullPointerException' with 2 users, not declaring a local and
         // not ending the scope of any locals.
         if (instruction.isNewInstance()
@@ -298,6 +350,12 @@
         }
       }
     }
+    if (shouldRemoveUnreachableBlocks) {
+      Set<Value> affectedValues = code.removeUnreachableBlocks();
+      if (!affectedValues.isEmpty()) {
+        new TypeAnalysis(appView).narrowing(affectedValues);
+      }
+    }
     assert code.isConsistentSSA();
   }
 
@@ -2855,6 +2913,33 @@
     assert block.exit().asGoto().getTarget() == target;
   }
 
+  private void rewriteIfToRequireNonNull(
+      IRCode code,
+      BasicBlock block,
+      InstructionListIterator iterator,
+      If theIf,
+      BasicBlock target,
+      BasicBlock deadTarget,
+      Position position) {
+    deadTarget.unlinkSinglePredecessorSiblingsAllowed();
+    assert theIf == block.exit();
+    iterator.previous();
+    Instruction instruction;
+    DexMethod requireNonNullMethod = appView.dexItemFactory().objectsMethods.requireNonNull;
+    if (appView.options().canUseRequireNonNull() && code.method.method != requireNonNullMethod) {
+      instruction = new InvokeStatic(requireNonNullMethod, null, ImmutableList.of(theIf.lhs()));
+    } else {
+      DexMethod getClassMethod = appView.dexItemFactory().objectMembers.getClass;
+      instruction = new InvokeVirtual(getClassMethod, null, ImmutableList.of(theIf.lhs()));
+    }
+    instruction.setPosition(position);
+    iterator.add(instruction);
+    iterator.next();
+    iterator.replaceCurrentInstruction(new Goto());
+    assert block.exit().isGoto();
+    assert block.exit().asGoto().getTarget() == target;
+  }
+
   private void rewriteIfWithConstZero(IRCode code, BasicBlock block) {
     If theIf = block.exit().asIf();
     if (theIf.isZeroTest()) {
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/ifs/IfThrowNullPointerExceptionTest.java b/src/test/java/com/android/tools/r8/ir/optimize/ifs/IfThrowNullPointerExceptionTest.java
new file mode 100644
index 0000000..2264a29
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/ifs/IfThrowNullPointerExceptionTest.java
@@ -0,0 +1,129 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.ifs;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assume.assumeTrue;
+
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.ir.code.BasicBlock;
+import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.Instruction;
+import com.android.tools.r8.utils.AndroidApiLevel;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import com.google.common.collect.ImmutableList;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class IfThrowNullPointerExceptionTest extends TestBase {
+
+  private final TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection params() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  public IfThrowNullPointerExceptionTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void testD8() throws Exception {
+    assumeTrue(parameters.isDexRuntime());
+    testForD8()
+        .addInnerClasses(IfThrowNullPointerExceptionTest.class)
+        .release()
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .inspect(this::inspect)
+        .run(parameters.getRuntime(), TestClass.class)
+        .assertSuccessWithOutputLines("Caught NPE", "Caught NPE");
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(IfThrowNullPointerExceptionTest.class)
+        .addKeepClassAndMembersRules(TestClass.class)
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .inspect(this::inspect)
+        .run(parameters.getRuntime(), TestClass.class)
+        .assertSuccessWithOutputLines("Caught NPE", "Caught NPE");
+  }
+
+  private void inspect(CodeInspector inspector) {
+    ClassSubject classSubject = inspector.clazz(TestClass.class);
+    assertThat(classSubject, isPresent());
+
+    for (String methodName : ImmutableList.of("testThrowNPE", "testThrowNull")) {
+      MethodSubject methodSubject = classSubject.uniqueMethodWithName(methodName);
+      assertThat(methodSubject, isPresent());
+
+      IRCode code = methodSubject.buildIR();
+      assertEquals(1, code.blocks.size());
+
+      BasicBlock entryBlock = code.entryBlock();
+      assertEquals(3, entryBlock.getInstructions().size());
+      assertTrue(entryBlock.getInstructions().getFirst().isArgument());
+      assertTrue(entryBlock.getInstructions().getLast().isReturn());
+
+      Instruction nullCheckInstruction = entryBlock.getInstructions().get(1);
+      if (parameters.isDexRuntime() && parameters.getApiLevel().isLessThan(AndroidApiLevel.K)) {
+        assertTrue(nullCheckInstruction.isInvokeVirtual());
+        assertEquals(
+            "java.lang.Class java.lang.Object.getClass()",
+            nullCheckInstruction.asInvokeVirtual().getInvokedMethod().toSourceString());
+      } else {
+        assertTrue(nullCheckInstruction.isInvokeStatic());
+        assertEquals(
+            "java.lang.Object java.util.Objects.requireNonNull(java.lang.Object)",
+            nullCheckInstruction.asInvokeStatic().getInvokedMethod().toSourceString());
+      }
+    }
+  }
+
+  static class TestClass {
+
+    public static void main(String[] args) {
+      testThrowNPE(new Object());
+      testThrowNull(new Object());
+
+      try {
+        testThrowNPE(null);
+      } catch (NullPointerException e) {
+        System.out.println("Caught NPE");
+      }
+      try {
+        testThrowNull(null);
+      } catch (NullPointerException e) {
+        System.out.println("Caught NPE");
+      }
+    }
+
+    static void testThrowNPE(Object x) {
+      if (x == null) {
+        throw new NullPointerException();
+      }
+    }
+
+    static void testThrowNull(Object x) {
+      if (x == null) {
+        throw null;
+      }
+    }
+  }
+}