Find more phis that have known-to-be-non-null values.

Currently, after adding a non-null IR, phi users, whose blocks are
dominated by the block with that non-null IR, are identified as
dominated phi users. This is conservative in that we only detected
phis only if all of its known-to-be-non-null operands are dominated,
e.g., phi(v0, v1, v1) -> phi(v0, vn, vn) where phi's block is dominated
by the block with vn.

We should consider incoming edges, instead of phi's block:
if the operand is non-null, and the corresponding predecessor is
dominated by the block with non-null IR, we can selectively replace
that operand with known-to-be-non-null value,
e.g., phi(v0, v1, v1) -> phi(v0, vn, v1) where the 2nd predecessor is
dominated by the block with vn.

Refer to http://b/76202537#comment4 and http://b/76202537#comment5 for
more detailed discussion and examples.

Bug: 76202537
Change-Id: I39e05bfe3ff070b1fdf83f89d528c7facae8973c
diff --git a/src/main/java/com/android/tools/r8/ir/code/Phi.java b/src/main/java/com/android/tools/r8/ir/code/Phi.java
index eadd2db..f5fab7d 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Phi.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Phi.java
@@ -142,7 +142,7 @@
     operands.addAll(copy.subList(current, copy.size()));
   }
 
-  public void replace(int predIndex, Value newValue) {
+  public void replaceOperandAt(int predIndex, Value newValue) {
     Value current = operands.get(predIndex);
     operands.set(predIndex, newValue);
     newValue.addPhiUser(this);
diff --git a/src/main/java/com/android/tools/r8/ir/code/Value.java b/src/main/java/com/android/tools/r8/ir/code/Value.java
index 7e51161..3cfb0c3 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Value.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Value.java
@@ -10,6 +10,7 @@
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.LongInterval;
 import com.google.common.collect.ImmutableSet;
+import it.unimi.dsi.fastutil.ints.IntList;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
@@ -403,7 +404,9 @@
   }
 
   public void replaceSelectiveUsers(
-      Value newValue, Set<Instruction> selectedInstructions, Set<Phi> selectedPhis) {
+      Value newValue,
+      Set<Instruction> selectedInstructions,
+      Map<Phi, IntList> selectedPhisWithPredecessorIndexes) {
     if (this == newValue) {
       return;
     }
@@ -416,10 +419,19 @@
         user.replaceValue(this, newValue);
       }
     }
+    Set<Phi> selectedPhis = selectedPhisWithPredecessorIndexes.keySet();
     for (Phi user : uniquePhiUsers()) {
       if (selectedPhis.contains(user)) {
-        fullyRemovePhiUser(user);
-        user.replaceOperand(this, newValue);
+        long count = user.getOperands().stream().filter(operand -> operand == this).count();
+        IntList positionsToUpdate = selectedPhisWithPredecessorIndexes.get(user);
+        // We may not _fully_ remove this from the phi, e.g., phi(v0, v1, v1) -> phi(v0, vn, v1).
+        if (count == positionsToUpdate.size()) {
+          fullyRemovePhiUser(user);
+        }
+        for (int position : positionsToUpdate) {
+          assert user.getOperand(position) == this;
+          user.replaceOperandAt(position, newValue);
+        }
       }
     }
     if (debugData != null) {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/Devirtualizer.java b/src/main/java/com/android/tools/r8/ir/optimize/Devirtualizer.java
index 362112f..e6de678 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/Devirtualizer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/Devirtualizer.java
@@ -20,6 +20,7 @@
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.optimize.Inliner.Constraint;
 import com.android.tools.r8.shaking.Enqueuer.AppInfoWithLiveness;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import java.util.IdentityHashMap;
 import java.util.ListIterator;
@@ -64,7 +65,7 @@
             InvokeVirtual devirtualizedInvoke = devirtualizedCall.get(origin.asInvokeInterface());
             if (dominatorTree.dominatedBy(block, devirtualizedInvoke.getBlock())) {
               nonNull.src().replaceSelectiveUsers(
-                  devirtualizedInvoke.getReceiver(), ImmutableSet.of(nonNull), ImmutableSet.of());
+                  devirtualizedInvoke.getReceiver(), ImmutableSet.of(nonNull), ImmutableMap.of());
             }
           }
         }
@@ -172,7 +173,7 @@
             }
 
             receiver.replaceSelectiveUsers(
-                newReceiver, ImmutableSet.of(devirtualizedInvoke), ImmutableSet.of());
+                newReceiver, ImmutableSet.of(devirtualizedInvoke), ImmutableMap.of());
             // TODO(b/72693244): Analyze it when creating a new Value or after replace*Users
             typeEnvironment.enqueue(newReceiver);
             typeEnvironment.analyze();
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/NonNullTracker.java b/src/main/java/com/android/tools/r8/ir/optimize/NonNullTracker.java
index 9aff28e..8af4486 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/NonNullTracker.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/NonNullTracker.java
@@ -17,7 +17,13 @@
 import com.android.tools.r8.ir.code.ValueType;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Sets;
+import it.unimi.dsi.fastutil.ints.IntArrayList;
+import it.unimi.dsi.fastutil.ints.IntList;
+import java.util.IdentityHashMap;
+import java.util.Iterator;
+import java.util.List;
 import java.util.ListIterator;
+import java.util.Map;
 import java.util.Set;
 import java.util.function.Predicate;
 
@@ -122,7 +128,7 @@
         // propagated through dominance.
         Set<Instruction> users = knownToBeNonNullValue.uniqueUsers();
         Set<Instruction> dominatedUsers = Sets.newIdentityHashSet();
-        Set<Phi> dominatedPhiUsers = Sets.newIdentityHashSet();
+        Map<Phi, IntList> dominatedPhiUsersWithPotisions = new IdentityHashMap<>();
         DominatorTree dominatorTree = new DominatorTree(code);
         Set<BasicBlock> dominatedBlocks = Sets.newIdentityHashSet();
         for (BasicBlock dominatee : dominatorTree.dominatedBlocks(blockWithNonNullInstruction)) {
@@ -142,12 +148,14 @@
           }
         }
         for (Phi user : knownToBeNonNullValue.uniquePhiUsers()) {
-          if (dominatedBlocks.contains(user.getBlock())) {
-            dominatedPhiUsers.add(user);
+          IntList dominatedPredecessorIndexes =
+              findDominatedPredecessorIndexesInPhi(user, knownToBeNonNullValue, dominatedBlocks);
+          if (!dominatedPredecessorIndexes.isEmpty()) {
+            dominatedPhiUsersWithPotisions.put(user, dominatedPredecessorIndexes);
           }
         }
         knownToBeNonNullValue.replaceSelectiveUsers(
-            nonNullValue, dominatedUsers, dominatedPhiUsers);
+            nonNullValue, dominatedUsers, dominatedPhiUsersWithPotisions);
       }
 
       // Add non-null on top of the successor block if the current block ends with a null check.
@@ -185,7 +193,7 @@
             if (dominatorTree.dominatedBy(target, block)) {
               // Collect users of the original value that are dominated by the target block.
               Set<Instruction> dominatedUsers = Sets.newIdentityHashSet();
-              Set<Phi> dominatedPhiUsers = Sets.newIdentityHashSet();
+              Map<Phi, IntList> dominatedPhiUsersWithPositions = new IdentityHashMap<>();
               Set<BasicBlock> dominatedBlocks =
                   Sets.newHashSet(dominatorTree.dominatedBlocks(target));
               for (Instruction user : knownToBeNonNullValue.uniqueUsers()) {
@@ -194,12 +202,14 @@
                 }
               }
               for (Phi user : knownToBeNonNullValue.uniquePhiUsers()) {
-                if (dominatedBlocks.contains(user.getBlock())) {
-                  dominatedPhiUsers.add(user);
+                IntList dominatedPredecessorIndexes = findDominatedPredecessorIndexesInPhi(
+                    user, knownToBeNonNullValue, dominatedBlocks);
+                if (!dominatedPredecessorIndexes.isEmpty()) {
+                  dominatedPhiUsersWithPositions.put(user, dominatedPredecessorIndexes);
                 }
               }
               // Avoid adding a non-null for the value without meaningful users.
-              if (!dominatedUsers.isEmpty() || !dominatedPhiUsers.isEmpty()) {
+              if (!dominatedUsers.isEmpty() || !dominatedPhiUsersWithPositions.isEmpty()) {
                 Value nonNullValue = code.createValue(
                     knownToBeNonNullValue.outType(), knownToBeNonNullValue.getLocalInfo());
                 NonNull nonNull = new NonNull(nonNullValue, knownToBeNonNullValue, theIf);
@@ -208,7 +218,7 @@
                 targetIterator.previous();
                 targetIterator.add(nonNull);
                 knownToBeNonNullValue.replaceSelectiveUsers(
-                    nonNullValue, dominatedUsers, dominatedPhiUsers);
+                    nonNullValue, dominatedUsers, dominatedPhiUsersWithPositions);
               }
             }
           }
@@ -217,6 +227,31 @@
     }
   }
 
+  private IntList findDominatedPredecessorIndexesInPhi(
+      Phi user, Value knownToBeNonNullValue, Set<BasicBlock> dominatedBlocks) {
+    assert user.getOperands().contains(knownToBeNonNullValue);
+    List<Value> operands = user.getOperands();
+    List<BasicBlock> predecessors = user.getBlock().getPredecessors();
+    assert operands.size() == predecessors.size();
+
+    IntList predecessorIndexes = new IntArrayList();
+    int index = 0;
+    Iterator<Value> operandIterator = operands.iterator();
+    Iterator<BasicBlock> predecessorIterator = predecessors.iterator();
+    while (operandIterator.hasNext() && predecessorIterator.hasNext()) {
+      Value operand = operandIterator.next();
+      BasicBlock predecessor = predecessorIterator.next();
+      // When this phi is chosen to be known-to-be-non-null value,
+      // check if the corresponding predecessor is dominated by the block where non-null is added.
+      if (operand == knownToBeNonNullValue && dominatedBlocks.contains(predecessor)) {
+        predecessorIndexes.add(index);
+      }
+
+      index++;
+    }
+    return predecessorIndexes;
+  }
+
   public void cleanupNonNull(IRCode code) {
     InstructionIterator it = code.instructionIterator();
     boolean needToCheckTrivialPhis = false;
diff --git a/src/test/java/com/android/tools/r8/kotlin/SimplifyIfNotNullKotlinTest.java b/src/test/java/com/android/tools/r8/kotlin/SimplifyIfNotNullKotlinTest.java
index ab6e949..2b1c33f 100644
--- a/src/test/java/com/android/tools/r8/kotlin/SimplifyIfNotNullKotlinTest.java
+++ b/src/test/java/com/android/tools/r8/kotlin/SimplifyIfNotNullKotlinTest.java
@@ -68,14 +68,8 @@
       DexCode dexCode = getDexCode(testMethod);
       long count = Arrays.stream(dexCode.instructions)
           .filter(SimplifyIfNotNullKotlinTest::isIf).count();
-      if (allowAccessModification) {
-        // TODO(b/76202537): 3 -> 2,
-        //   Yet another null-check from checkParameterIsNotNull should subsume another from ?:
-        assertEquals(3, count);
-      } else {
-        // One null-check from force inlined coalesce and another from ?:
-        assertEquals(2, count);
-      }
+      // One null-check from force inlined coalesce and another from ?:
+      assertEquals(2, count);
     });
   }