Make non-null-param-or-throw-NPE analysis aware of requireNonNull()

Change-Id: Ie3d053609c36cb388f64fd6959a4d0c4c48b979d
diff --git a/src/main/java/com/android/tools/r8/ir/code/ArrayGet.java b/src/main/java/com/android/tools/r8/ir/code/ArrayGet.java
index 0c118fb..17dcc5e 100644
--- a/src/main/java/com/android/tools/r8/ir/code/ArrayGet.java
+++ b/src/main/java/com/android/tools/r8/ir/code/ArrayGet.java
@@ -251,7 +251,7 @@
   }
 
   @Override
-  public boolean throwsNpeIfValueIsNull(Value value, DexItemFactory dexItemFactory) {
+  public boolean throwsNpeIfValueIsNull(Value value, AppView<?> appView, DexType context) {
     return array() == value;
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/code/ArrayLength.java b/src/main/java/com/android/tools/r8/ir/code/ArrayLength.java
index b2f908c..12ff943 100644
--- a/src/main/java/com/android/tools/r8/ir/code/ArrayLength.java
+++ b/src/main/java/com/android/tools/r8/ir/code/ArrayLength.java
@@ -8,7 +8,6 @@
 import com.android.tools.r8.cf.code.CfArrayLength;
 import com.android.tools.r8.dex.Constants;
 import com.android.tools.r8.graph.AppView;
-import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.ir.analysis.AbstractError;
 import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
@@ -140,7 +139,7 @@
   }
 
   @Override
-  public boolean throwsNpeIfValueIsNull(Value value, DexItemFactory dexItemFactory) {
+  public boolean throwsNpeIfValueIsNull(Value value, AppView<?> appView, DexType context) {
     return array() == value;
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/code/ArrayPut.java b/src/main/java/com/android/tools/r8/ir/code/ArrayPut.java
index b8d8526..8ab298a 100644
--- a/src/main/java/com/android/tools/r8/ir/code/ArrayPut.java
+++ b/src/main/java/com/android/tools/r8/ir/code/ArrayPut.java
@@ -15,7 +15,6 @@
 import com.android.tools.r8.dex.Constants;
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.AppView;
-import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.ir.analysis.type.ArrayTypeLatticeElement;
 import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
@@ -252,7 +251,7 @@
   }
 
   @Override
-  public boolean throwsNpeIfValueIsNull(Value value, DexItemFactory dexItemFactory) {
+  public boolean throwsNpeIfValueIsNull(Value value, AppView<?> appView, DexType context) {
     return array() == value;
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/code/InstanceGet.java b/src/main/java/com/android/tools/r8/ir/code/InstanceGet.java
index 2a7f621..b070660 100644
--- a/src/main/java/com/android/tools/r8/ir/code/InstanceGet.java
+++ b/src/main/java/com/android/tools/r8/ir/code/InstanceGet.java
@@ -18,7 +18,6 @@
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexField;
-import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.ir.analysis.ClassInitializationAnalysis;
 import com.android.tools.r8.ir.analysis.ClassInitializationAnalysis.AnalysisAssumption;
@@ -200,7 +199,7 @@
   }
 
   @Override
-  public boolean throwsNpeIfValueIsNull(Value value, DexItemFactory dexItemFactory) {
+  public boolean throwsNpeIfValueIsNull(Value value, AppView<?> appView, DexType context) {
     return object() == value;
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/code/InstancePut.java b/src/main/java/com/android/tools/r8/ir/code/InstancePut.java
index d073a90..00a5f30 100644
--- a/src/main/java/com/android/tools/r8/ir/code/InstancePut.java
+++ b/src/main/java/com/android/tools/r8/ir/code/InstancePut.java
@@ -18,7 +18,6 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexEncodedField;
 import com.android.tools.r8.graph.DexField;
-import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.ir.analysis.ClassInitializationAnalysis;
 import com.android.tools.r8.ir.analysis.ClassInitializationAnalysis.AnalysisAssumption;
@@ -228,7 +227,7 @@
   }
 
   @Override
-  public boolean throwsNpeIfValueIsNull(Value value, DexItemFactory dexItemFactory) {
+  public boolean throwsNpeIfValueIsNull(Value value, AppView<?> appView, DexType context) {
     return object() == value;
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/code/Instruction.java b/src/main/java/com/android/tools/r8/ir/code/Instruction.java
index 8cf2f40..b1459be 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Instruction.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Instruction.java
@@ -9,7 +9,6 @@
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DebugLocalInfo;
-import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.ir.analysis.AbstractError;
 import com.android.tools.r8.ir.analysis.ClassInitializationAnalysis.AnalysisAssumption;
@@ -1399,11 +1398,12 @@
    * given value is null at runtime execution.
    *
    * @param value the value representing an object that may be null at runtime execution.
-   * @param dexItemFactory where pre-defined descriptors are retrieved
+   * @param appView where pre-defined descriptors are retrieved
+   * @param context
    * @return true if the instruction throws NullPointerException if value is null at runtime, false
    *     otherwise.
    */
-  public boolean throwsNpeIfValueIsNull(Value value, DexItemFactory dexItemFactory) {
+  public boolean throwsNpeIfValueIsNull(Value value, AppView<?> appView, DexType context) {
     return false;
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/code/InvokeMethod.java b/src/main/java/com/android/tools/r8/ir/code/InvokeMethod.java
index d6219f4..401b4df 100644
--- a/src/main/java/com/android/tools/r8/ir/code/InvokeMethod.java
+++ b/src/main/java/com/android/tools/r8/ir/code/InvokeMethod.java
@@ -26,6 +26,7 @@
 import com.android.tools.r8.ir.regalloc.RegisterAllocator;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Sets;
+import java.util.BitSet;
 import java.util.Collection;
 import java.util.List;
 import java.util.Set;
@@ -221,4 +222,17 @@
     assert lookupDirectTargetOnItself == hierarchyResult;
     return true;
   }
+
+  @Override
+  public boolean throwsNpeIfValueIsNull(Value value, AppView<?> appView, DexType context) {
+    DexEncodedMethod singleTarget = lookupSingleTarget(appView, context);
+    if (singleTarget != null) {
+      BitSet nonNullParamOrThrow = singleTarget.getOptimizationInfo().getNonNullParamOrThrow();
+      if (nonNullParamOrThrow != null) {
+        int argumentIndex = inValues.indexOf(value);
+        return argumentIndex >= 0 && nonNullParamOrThrow.get(argumentIndex);
+      }
+    }
+    return false;
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/code/InvokeMethodWithReceiver.java b/src/main/java/com/android/tools/r8/ir/code/InvokeMethodWithReceiver.java
index a42988e..e15a387 100644
--- a/src/main/java/com/android/tools/r8/ir/code/InvokeMethodWithReceiver.java
+++ b/src/main/java/com/android/tools/r8/ir/code/InvokeMethodWithReceiver.java
@@ -5,7 +5,6 @@
 
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexEncodedMethod;
-import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.ir.analysis.ClassInitializationAnalysis;
@@ -52,8 +51,8 @@
   }
 
   @Override
-  public boolean throwsNpeIfValueIsNull(Value value, DexItemFactory dexItemFactory) {
-    return getReceiver() == value;
+  public boolean throwsNpeIfValueIsNull(Value value, AppView<?> appView, DexType context) {
+    return value == getReceiver() || super.throwsNpeIfValueIsNull(value, appView, context);
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/ir/code/Monitor.java b/src/main/java/com/android/tools/r8/ir/code/Monitor.java
index 78b9b64..598bc6c 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Monitor.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Monitor.java
@@ -12,7 +12,6 @@
 import com.android.tools.r8.code.MonitorExit;
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.AppView;
-import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.ir.conversion.CfBuilder;
 import com.android.tools.r8.ir.conversion.DexBuilder;
@@ -142,7 +141,7 @@
   }
 
   @Override
-  public boolean throwsNpeIfValueIsNull(Value value, DexItemFactory dexItemFactory) {
+  public boolean throwsNpeIfValueIsNull(Value value, AppView<?> appView, DexType context) {
     return object() == value;
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/code/Throw.java b/src/main/java/com/android/tools/r8/ir/code/Throw.java
index 8a513c2..115e9fd 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Throw.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Throw.java
@@ -6,7 +6,7 @@
 import com.android.tools.r8.cf.LoadStoreHelper;
 import com.android.tools.r8.cf.code.CfThrow;
 import com.android.tools.r8.dex.Constants;
-import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
 import com.android.tools.r8.ir.conversion.CfBuilder;
@@ -87,7 +87,7 @@
   }
 
   @Override
-  public boolean throwsNpeIfValueIsNull(Value value, DexItemFactory dexItemFactory) {
+  public boolean throwsNpeIfValueIsNull(Value value, AppView<?> appView, DexType context) {
     if (exception() == value) {
       return true;
     }
@@ -105,7 +105,7 @@
     if (!aliasedValue.isPhi()) {
       Instruction definition = aliasedValue.getDefinition();
       if (definition.isNewInstance()
-          && definition.asNewInstance().clazz == dexItemFactory.npeType) {
+          && definition.asNewInstance().clazz == appView.dexItemFactory().npeType) {
         // throw new NullPointerException()
         return true;
       }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index 55fbf2c..60813ab 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -251,6 +251,10 @@
         // Check for the patterns 'if (x == null) throw null' and
         // 'if (x == null) throw new NullPointerException()'.
         if (instruction.isIf()) {
+          if (appView.dexItemFactory().objectsMethods.isRequireNonNullMethod(code.method.method)) {
+            continue;
+          }
+
           If ifInstruction = instruction.asIf();
           if (!ifInstruction.isZeroTest()) {
             continue;
@@ -287,7 +291,6 @@
           }
 
           rewriteIfToRequireNonNull(
-              code,
               block,
               it,
               ifInstruction,
@@ -2914,7 +2917,6 @@
   }
 
   private void rewriteIfToRequireNonNull(
-      IRCode code,
       BasicBlock block,
       InstructionListIterator iterator,
       If theIf,
@@ -2925,8 +2927,8 @@
     assert theIf == block.exit();
     iterator.previous();
     Instruction instruction;
-    DexMethod requireNonNullMethod = appView.dexItemFactory().objectsMethods.requireNonNull;
-    if (appView.options().canUseRequireNonNull() && code.method.method != requireNonNullMethod) {
+    if (appView.options().canUseRequireNonNull()) {
+      DexMethod requireNonNullMethod = appView.dexItemFactory().objectsMethods.requireNonNull;
       instruction = new InvokeStatic(requireNonNullMethod, null, ImmutableList.of(theIf.lhs()));
     } else {
       DexMethod getClassMethod = appView.dexItemFactory().objectMembers.getClass;
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/info/MethodOptimizationInfoCollector.java b/src/main/java/com/android/tools/r8/ir/optimize/info/MethodOptimizationInfoCollector.java
index 71e7a85..9a01028 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/info/MethodOptimizationInfoCollector.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/info/MethodOptimizationInfoCollector.java
@@ -782,7 +782,7 @@
           if (isInstantiationOfNullPointerException(instr, it, appView.dexItemFactory())) {
             it.next(); // Skip call to NullPointerException.<init>.
             return InstructionEffect.NO_EFFECT;
-          } else if (instr.throwsNpeIfValueIsNull(value, appView.dexItemFactory())) {
+          } else if (instr.throwsNpeIfValueIsNull(value, appView, code.method.holder())) {
             // In order to preserve NPE semantic, the exception must not be caught by any handler.
             // Therefore, we must ignore this instruction if it is covered by a catch handler.
             // Note: this is a conservative approach where we consider that any catch handler could
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackIgnore.java b/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackIgnore.java
index 5d0f4ef..66aefb4 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackIgnore.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackIgnore.java
@@ -123,8 +123,7 @@
   }
 
   @Override
-  public void setNonNullParamOrThrow(DexEncodedMethod method, BitSet facts) {
-  }
+  public void setNonNullParamOrThrow(DexEncodedMethod method, BitSet facts) {}
 
   @Override
   public void setNonNullParamOnNormalExits(DexEncodedMethod method, BitSet facts) {
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/ObjectsRequireNonNullTest.java b/src/test/java/com/android/tools/r8/ir/optimize/ObjectsRequireNonNullTest.java
index a963edf..a1c69fa 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/ObjectsRequireNonNullTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/ObjectsRequireNonNullTest.java
@@ -191,7 +191,8 @@
 
       unknownArg(instance);
       try {
-        unknownArg(null);
+        Foo alwaysNull = System.currentTimeMillis() > 0 ? null : instance;
+        unknownArg(alwaysNull);
         throw new AssertionError("Expected NullPointerException");
       } catch (NullPointerException npe) {
         System.out.println("Expected NPE");