Add non-null IR _after_ inlining.

non-null IRs have been added when building IR code for inlinee. Since
blocks in the inlinee aren't linked to the caller yet, users of non-null
value in the caller were not linked either, i.e., not shown to inlinee,
and thus those uses were not properly updated or sometimes addition of
non-null IR itself has been skipped.

The key change in this CL is to run NonNullTracker *after* linking the
inlinee blocks. Also, found a number of small bugs in NonNullTracker and
corresponding tests.

Bug: 76200247
Change-Id: I02d56e6a5ae8e06da5956d969f24b6267ba4c1ea
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index 3feb59a..bed4df3 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -2023,7 +2023,8 @@
               simplifyIfWithKnownCondition(code, block, theIf, cond, color);
             }
           }
-        } else if (theIf.isZeroTest() && !inValues.get(0).isConstNumber()) {
+        } else if (theIf.isZeroTest() && !inValues.get(0).isConstNumber()
+            && (theIf.getType() == Type.EQ || theIf.getType() == Type.NE)) {
           if (inValues.get(0).isNeverNull()) {
             simplifyIfWithKnownCondition(code, block, theIf, 1, color);
           } else {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java b/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
index a5ac28c..7eb771c 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
@@ -55,7 +55,10 @@
 
   private final Set<DexMethod> blackList = Sets.newIdentityHashSet();
 
-  public Inliner(AppInfoWithLiveness appInfo, GraphLense graphLense, InternalOptions options) {
+  public Inliner(
+      AppInfoWithLiveness appInfo,
+      GraphLense graphLense,
+      InternalOptions options) {
     this.appInfo = appInfo;
     this.graphLense = graphLense;
     this.options = options;
@@ -269,9 +272,6 @@
       if (!target.isProcessed()) {
         new LensCodeRewriter(graphLense, appInfo).rewrite(code, target);
       }
-      if (options.enableNonNullTracking) {
-        new NonNullTracker().addNonNull(code);
-      }
       return code;
     }
   }
@@ -443,7 +443,17 @@
               iterator.previous();
               instruction_allowance -= numberOfInstructions(inlinee);
               if (instruction_allowance >= 0 || result.ignoreInstructionBudget()) {
-                iterator.inlineInvoke(code, inlinee, blockIterator, blocksToRemove, downcast);
+                BasicBlock invokeSuccessor =
+                    iterator.inlineInvoke(code, inlinee, blockIterator, blocksToRemove, downcast);
+                if (options.enableNonNullTracking) {
+                  // Move the cursor back to where the inlinee blocks are added.
+                  blockIterator = code.blocks.listIterator(code.blocks.indexOf(block));
+                  // Kick off the tracker to add non-null IRs only to the inlinee blocks.
+                  new NonNullTracker()
+                      .addNonNullInPart(code, blockIterator, inlinee.blocks::contains);
+                  // Move the cursor forward to where the inlinee blocks end.
+                  blockIterator = code.blocks.listIterator(code.blocks.indexOf(invokeSuccessor));
+                }
                 // Update type env for inlined blocks.
                 typeEnvironment.analyzeBlocks(inlinee.topologicallySortedBlocks());
                 // TODO(b/69964136): need a test where refined env in inlinee affects the caller.
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/NonNullTracker.java b/src/main/java/com/android/tools/r8/ir/optimize/NonNullTracker.java
index 4eae3d2..9aff28e 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/NonNullTracker.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/NonNullTracker.java
@@ -19,6 +19,7 @@
 import com.google.common.collect.Sets;
 import java.util.ListIterator;
 import java.util.Set;
+import java.util.function.Predicate;
 
 public class NonNullTracker {
 
@@ -56,9 +57,16 @@
   }
 
   public void addNonNull(IRCode code) {
-    ListIterator<BasicBlock> blocks = code.blocks.listIterator();
-    while (blocks.hasNext()) {
-      BasicBlock block = blocks.next();
+    addNonNullInPart(code, code.blocks.listIterator(), b -> true);
+  }
+
+  public void addNonNullInPart(
+      IRCode code, ListIterator<BasicBlock> blockIterator, Predicate<BasicBlock> blockTester) {
+    while (blockIterator.hasNext()) {
+      BasicBlock block = blockIterator.next();
+      if (!blockTester.test(block)) {
+        continue;
+      }
       // Add non-null after instructions that implicitly indicate receiver/array is not null.
       InstructionListIterator iterator = block.listIterator();
       while (iterator.hasNext()) {
@@ -90,7 +98,7 @@
         // A: ...y // blockWithNonNullInstruction
         //
         BasicBlock blockWithNonNullInstruction =
-            block.hasCatchHandlers() ? iterator.split(code, blocks) : block;
+            block.hasCatchHandlers() ? iterator.split(code, blockIterator) : block;
         // Next, add non-null fake IR, e.g.,
         // ...x
         // invoke(rcv, ...)
@@ -191,7 +199,7 @@
                 }
               }
               // Avoid adding a non-null for the value without meaningful users.
-              if (!dominatedUsers.isEmpty() && !dominatedPhiUsers.isEmpty()) {
+              if (!dominatedUsers.isEmpty() || !dominatedPhiUsers.isEmpty()) {
                 Value nonNullValue = code.createValue(
                     knownToBeNonNullValue.outType(), knownToBeNonNullValue.getLocalInfo());
                 NonNull nonNull = new NonNull(nonNullValue, knownToBeNonNullValue, theIf);
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/NonNullTrackerTest.java b/src/test/java/com/android/tools/r8/ir/optimize/NonNullTrackerTest.java
index 646cfaf..b6109d4 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/NonNullTrackerTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/NonNullTrackerTest.java
@@ -185,6 +185,6 @@
     buildAndTest(NonNullAfterNullCheck.class, bar, 1, this::checkInvokeGetsNonNullReceiver);
     MethodSignature baz =
         new MethodSignature("baz", "int", new String[]{"java.lang.String"});
-    buildAndTest(NonNullAfterNullCheck.class, baz, 1, this::checkInvokeGetsNullReceiver);
+    buildAndTest(NonNullAfterNullCheck.class, baz, 2, this::checkInvokeGetsNullReceiver);
   }
 }
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/SimplifyIfNotNullTest.java b/src/test/java/com/android/tools/r8/ir/optimize/SimplifyIfNotNullTest.java
index 92e43d0..525e21b 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/SimplifyIfNotNullTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/SimplifyIfNotNullTest.java
@@ -30,8 +30,8 @@
 
   private void buildAndTest(Class<?> testClass, List<MethodSignature> signatures) throws Exception {
     AndroidApp app = buildAndroidApp(ToolHelper.getClassAsBytes(testClass));
-    AndroidApp r8Result = compileWithR8(
-        app, keepMainProguardConfiguration(testClass), o -> o.enableInlining = false);
+    AndroidApp r8Result = compileWithR8(app,
+        "-keep class " + testClass.getCanonicalName() + " { *; }");
     DexInspector dexInspector = new DexInspector(r8Result);
     for (MethodSignature signature : signatures) {
       DexEncodedMethod method =
diff --git a/src/test/java/com/android/tools/r8/kotlin/SimplifyIfNotNullKotlinTest.java b/src/test/java/com/android/tools/r8/kotlin/SimplifyIfNotNullKotlinTest.java
index e324dee..ab6e949 100644
--- a/src/test/java/com/android/tools/r8/kotlin/SimplifyIfNotNullKotlinTest.java
+++ b/src/test/java/com/android/tools/r8/kotlin/SimplifyIfNotNullKotlinTest.java
@@ -43,9 +43,8 @@
       long count = Arrays.stream(dexCode.instructions)
           .filter(SimplifyIfNotNullKotlinTest::isIf).count();
       if (allowAccessModification) {
-        // TODO(b/76200247): 6 -> 5
         // Three null-check's from inlined checkParameterIsNotNull for receiver and two arguments.
-        assertEquals(6, count);
+        assertEquals(5, count);
       } else {
         // One after Iterator#hasNext, and another in the filter predicate: sinceYear != null.
         assertEquals(2, count);