Further unsplit live arguments after register allocation

This makes two optimizations to unsplitArguments():

* Instead of only unsplitting the split children of the argument live intervals when all the split children can be unsplit, this unsplits all of the split children that are eligible for unsplitting.

* Instead of pessimistically using that each argument will be assigned a register that is <= the maximum register, this fetches the unadjusted real register for each argument value.

This removes 1.8% of all moves in JetNews (only 0.3% of all moves in Composable functions).

Fixes: b/376654519
Change-Id: I5d92ca1c092119815bf97a1e61d06399ee6d1dc2
diff --git a/src/main/java/com/android/tools/r8/ir/regalloc/InsertMovesResult.java b/src/main/java/com/android/tools/r8/ir/regalloc/InsertMovesResult.java
new file mode 100644
index 0000000..574ebee
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/regalloc/InsertMovesResult.java
@@ -0,0 +1,25 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.regalloc;
+
+public class InsertMovesResult {
+
+  private final LinearScanRegisterAllocator allocator;
+  private final int numberOfParallelMoveTemporaryRegisters;
+
+  public InsertMovesResult(
+      LinearScanRegisterAllocator allocator, int numberOfParallelMoveTemporaryRegisters) {
+    this.allocator = allocator;
+    this.numberOfParallelMoveTemporaryRegisters = numberOfParallelMoveTemporaryRegisters;
+  }
+
+  public int getNumberOfParallelMoveTemporaryRegisters() {
+    return numberOfParallelMoveTemporaryRegisters;
+  }
+
+  public void revert() {
+    allocator.removeSpillAndPhiMoves();
+    allocator.removeParallelMoveTemporaryRegisters();
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java b/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
index 58af88f..592b6bc 100644
--- a/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
+++ b/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
@@ -66,7 +66,9 @@
 import it.unimi.dsi.fastutil.ints.IntSet;
 import it.unimi.dsi.fastutil.objects.Reference2IntArrayMap;
 import it.unimi.dsi.fastutil.objects.Reference2IntMap;
+import it.unimi.dsi.fastutil.objects.Reference2IntOpenHashMap;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
@@ -672,14 +674,46 @@
   // Compute a table that for each register numbers contains the number of previous register
   // numbers that were unused. This table is then used to slide down the actual registers
   // used to fill the gaps.
-  private boolean computeUnusedRegisters() {
+  private void computeUnusedRegisters() {
+    unusedRegisters = internalComputeUnusedRegisters();
+  }
+
+  private void recomputeUnusedRegisters() {
+    int[] newUnusedRegisters = internalComputeUnusedRegisters();
+    assert verifyNoUsesOfPreviouslyUnusedRegisters(newUnusedRegisters);
+    unusedRegisters = newUnusedRegisters;
+  }
+
+  private int[] internalComputeUnusedRegisters() {
     if (mode.is4Bit() || registersUsed() == 0) {
-      return false;
+      return null;
     }
     // Compute the table based on the set of used registers.
     IntSet usedRegisters = computeUsedRegisters();
-    unusedRegisters = computeUnusedRegistersFromUsedRegisters(usedRegisters);
-    return ArrayUtils.lastOrDefault(unusedRegisters, 0) > 0;
+    return computeUnusedRegistersFromUsedRegisters(usedRegisters);
+  }
+
+  private boolean verifyNoChangesToUnusedRegisters() {
+    assert Arrays.equals(unusedRegisters, internalComputeUnusedRegisters());
+    return true;
+  }
+
+  private boolean verifyNoUsesOfPreviouslyUnusedRegisters(int[] newUnusedRegisters) {
+    // We only recompute the unused registers when at least one argument live intervals was unsplit,
+    // thus we always compute a non-null unused registers result.
+    assert unusedRegisters != null;
+    assert newUnusedRegisters != null;
+    assert unusedRegisters.length == newUnusedRegisters.length;
+    int previousNumberOfUnusedRegisters = 0;
+    int previousNumberOfNewUnusedRegisters = 0;
+    for (int i = 0; i < unusedRegisters.length; i++) {
+      boolean wasRegisterUnused = previousNumberOfUnusedRegisters != unusedRegisters[i];
+      boolean isRegisterUnused = previousNumberOfNewUnusedRegisters != newUnusedRegisters[i];
+      assert !wasRegisterUnused || isRegisterUnused;
+      previousNumberOfUnusedRegisters = unusedRegisters[i];
+      previousNumberOfNewUnusedRegisters = newUnusedRegisters[i];
+    }
+    return true;
   }
 
   private IntSet computeUsedRegisters() {
@@ -693,8 +727,10 @@
     }
     // Additionally, we have used temporary registers for parallel move scheduling, those
     // are used as well.
-    for (int i = firstParallelMoveTemporary; i < maxRegisterNumber + 1; i++) {
-      usedRegisters.add(i);
+    if (firstParallelMoveTemporary != NO_REGISTER) {
+      for (int i = firstParallelMoveTemporary; i < maxRegisterNumber + 1; i++) {
+        usedRegisters.add(i);
+      }
     }
     return usedRegisters;
   }
@@ -708,12 +744,13 @@
   }
 
   private int[] computeUnusedRegistersFromUsedRegisters(IntSet usedRegisters) {
-    assert firstParallelMoveTemporary != NO_REGISTER;
     int firstLocalRegister = numberOfArgumentRegisters + getMoveExceptionOffsetForLocalRegisters();
     assert verifyRegistersBeforeFirstLocalRegisterAreUsed(firstLocalRegister, usedRegisters);
-    int numberOfParallelMoveTemporaryRegisters = registersUsed() - firstParallelMoveTemporary;
+    int registersUsed = unadjustedRegistersUsed();
+    int numberOfParallelMoveTemporaryRegisters =
+        firstParallelMoveTemporary != NO_REGISTER ? registersUsed - firstParallelMoveTemporary : 0;
     int numberOfLocalRegisters =
-        registersUsed() - firstLocalRegister - numberOfParallelMoveTemporaryRegisters;
+        registersUsed - firstLocalRegister - numberOfParallelMoveTemporaryRegisters;
     int unused = 0;
     int[] unusedRegisters = new int[numberOfLocalRegisters];
     for (int i = 0; i < numberOfLocalRegisters; i++) {
@@ -746,6 +783,10 @@
     return numberOfRegister;
   }
 
+  private int unadjustedRegistersUsed() {
+    return maxRegisterNumber + 1;
+  }
+
   @Override
   public int getRegisterForValue(Value value, int instructionNumber) {
     if (value.isFixedRegisterValue()) {
@@ -825,15 +866,55 @@
 
     boolean succeeded = performLinearScan(mode);
     if (succeeded) {
-      insertMoves();
-      // Now that we know the max register number we can compute whether it is safe to use
-      // argument registers in place. If it is, we redo move insertion to get rid of the moves
-      // caused by splitting of the argument registers.
-      if (unsplitArguments()) {
-        removeSpillAndPhiMoves();
-        insertMoves();
-      }
+      InsertMovesResult insertMovesResult = insertMoves();
+
+      // We only compute unused information for local registers (temporary registers for parallel
+      // move scheduling and argument registers are never unused). Therefore, we can already compute
+      // unused registers now. This can help lead to more aggressive argument unsplitting, since
+      // this effectively lowers the real argument registers.
       computeUnusedRegisters();
+
+      // After having finished move insertion (which can allocate temporary registers for parallel
+      // move scheduling), we now know the final registers of the arguments. If we have moved some
+      // arguments down to low registers, but the input argument register itself ended up being in a
+      // low register, then we can avoid the move into a low register by just using the argument
+      // register directly. This is achieved by updating the register assignment to argument split
+      // live intervals.
+      UnsplitArgumentsResult unsplitArgumentsResult = unsplitArguments();
+      if (unsplitArgumentsResult != null) {
+        // If any changes were made, we need to redo move insertion.
+        insertMovesResult.revert();
+        InsertMovesResult newInsertMovesResult = insertMoves();
+        int iterations = 0;
+
+        // In some cases, the new move insertion may lead to more temporary registers being used for
+        // parallel move scheduling. This is rare (e.g., never happens when compiling JetNews). If
+        // that happened, the argument registers are now in higher registers, meaning we may have
+        // invalidated the argument unsplitting. We therefore (partially) revert the argument
+        // unsplitting and redo move insertion.
+        while (newInsertMovesResult.getNumberOfParallelMoveTemporaryRegisters()
+            > insertMovesResult.getNumberOfParallelMoveTemporaryRegisters()) {
+          assert iterations < 5;
+          boolean changed = unsplitArgumentsResult.revertPartial();
+          if (changed) {
+            // We invalidated the unsplit arguments optimization (or some of it). Redo move
+            // insertion and check again.
+            insertMovesResult = newInsertMovesResult;
+            insertMovesResult.revert();
+            newInsertMovesResult = insertMoves();
+          } else {
+            // Although we used more parallel move temporary registers this did not invalidate the
+            // unsplit arguments result.
+            break;
+          }
+          iterations++;
+        }
+        if (unsplitArgumentsResult.isFullyReverted()) {
+          assert verifyNoChangesToUnusedRegisters();
+        } else {
+          recomputeUnusedRegisters();
+        }
+      }
     } else {
       assert mode.is4Bit();
     }
@@ -896,52 +977,40 @@
   // we can get the argument into low enough registers at uses that require low numbers. After
   // register allocation we can check if it is safe to just use the argument register itself
   // for all uses and thereby avoid moving argument values around.
-  // TODO(b/376654519): This unsplits the entire argument live intervals or does nothing. Couldn't
-  //  we save some moves by partially unsplitting the argument live intervals?
-  private boolean unsplitArguments() {
+  private UnsplitArgumentsResult unsplitArguments() {
     if (mode.is4Bit()) {
-      return false;
+      return null;
     }
-    boolean argumentRegisterUnsplit = false;
+    Reference2IntMap<LiveIntervals> originalRegisterAssignment = new Reference2IntOpenHashMap<>();
+    originalRegisterAssignment.defaultReturnValue(NO_REGISTER);
     for (Value current = firstArgumentValue;
         current != null;
         current = current.getNextConsecutive()) {
       LiveIntervals intervals = current.getLiveIntervals();
+      int conservativeRealRegisterEnd = realRegisterNumberFromAllocated(intervals.getRegisterEnd());
       assert !mode.hasRegisterConstraint(intervals)
           || (mode.is8BitRefinement()
               && intervals.getRegisterEnd() < numberOf4BitArgumentRegisters);
-      boolean canUseArgumentRegister = true;
-      boolean couldUseArgumentRegister = true;
       for (LiveIntervals child : intervals.getSplitChildren()) {
-        if (child.isInvokeRangeIntervals()) {
-          canUseArgumentRegister = false;
-          break;
-        }
-        int registerConstraint = child.getRegisterLimit();
-        if (registerConstraint < Constants.U16BIT_MAX) {
-          couldUseArgumentRegister = false;
-
-          if (registerConstraint < highestUsedRegister()) {
-            canUseArgumentRegister = false;
-            break;
-          }
-        }
-      }
-      if (canUseArgumentRegister && !couldUseArgumentRegister) {
-        // Only return true if there is a constrained use where it turns out that we can use the
-        // original argument register. This way we will not unnecessarily redo move insertion.
-        argumentRegisterUnsplit = true;
-        for (LiveIntervals child : intervals.getSplitChildren()) {
+        if (!child.isInvokeRangeIntervals()
+            && conservativeRealRegisterEnd <= child.getRegisterLimit()
+            && child.getRegister() != intervals.getRegister()) {
+          originalRegisterAssignment.put(child, child.getRegister());
           child.clearRegisterAssignment();
           child.setRegister(intervals.getRegister());
-          child.setSpilled(false);
+          // If the child could be spilled then we would need to unset it here + update
+          // UnsplitArgumentsResult#revertPartial to account for this.
+          assert !child.isSpilled();
         }
       }
     }
-    return argumentRegisterUnsplit;
+    if (!originalRegisterAssignment.isEmpty()) {
+      return new UnsplitArgumentsResult(this, originalRegisterAssignment);
+    }
+    return null;
   }
 
-  private void removeSpillAndPhiMoves() {
+  void removeSpillAndPhiMoves() {
     for (BasicBlock block : code.blocks) {
       InstructionListIterator it = block.listIterator(code);
       while (it.hasNext()) {
@@ -953,6 +1022,13 @@
     }
   }
 
+  void removeParallelMoveTemporaryRegisters() {
+    if (firstParallelMoveTemporary != NO_REGISTER) {
+      maxRegisterNumber = firstParallelMoveTemporary - 1;
+      firstParallelMoveTemporary = NO_REGISTER;
+    }
+  }
+
   private static boolean isSpillInstruction(Instruction instruction) {
     Value outValue = instruction.outValue();
     if (outValue != null && outValue.isFixedRegisterValue()) {
@@ -972,6 +1048,7 @@
     maxRegisterNumber = -1;
     active.clear();
     expiredHere.clear();
+    firstParallelMoveTemporary = NO_REGISTER;
     inactive.clear();
     unhandled.clear();
     moveExceptionIntervals.clear();
@@ -2921,7 +2998,9 @@
     }
   }
 
-  private void insertMoves() {
+  // Returns the number of added parallel move temporary registers.
+  private InsertMovesResult insertMoves() {
+    assert firstParallelMoveTemporary == NO_REGISTER;
     computeRematerializableBits();
 
     SpillMoveSet spillMoves = new SpillMoveSet(this, code, appView);
@@ -2943,8 +3022,14 @@
     }
 
     resolveControlFlow(spillMoves);
-    firstParallelMoveTemporary = maxRegisterNumber + 1;
-    maxRegisterNumber += spillMoves.scheduleAndInsertMoves(maxRegisterNumber + 1);
+    int firstParallelMoveTemporaryRegister = maxRegisterNumber + 1;
+    int numberOfParallelMoveTemporaryRegisters =
+        spillMoves.scheduleAndInsertMoves(firstParallelMoveTemporaryRegister);
+    if (numberOfParallelMoveTemporaryRegisters > 0) {
+      firstParallelMoveTemporary = firstParallelMoveTemporaryRegister;
+      maxRegisterNumber += numberOfParallelMoveTemporaryRegisters;
+    }
+    return new InsertMovesResult(this, numberOfParallelMoveTemporaryRegisters);
   }
 
   private void computeRematerializableBits() {
diff --git a/src/main/java/com/android/tools/r8/ir/regalloc/UnsplitArgumentsResult.java b/src/main/java/com/android/tools/r8/ir/regalloc/UnsplitArgumentsResult.java
new file mode 100644
index 0000000..5b0810d
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/regalloc/UnsplitArgumentsResult.java
@@ -0,0 +1,58 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.regalloc;
+
+import static com.android.tools.r8.ir.regalloc.LiveIntervals.NO_REGISTER;
+
+import com.android.tools.r8.ir.code.Value;
+import it.unimi.dsi.fastutil.objects.Reference2IntMap;
+
+public class UnsplitArgumentsResult {
+
+  private final LinearScanRegisterAllocator allocator;
+  private final Reference2IntMap<LiveIntervals> originalRegisterAssignment;
+
+  public UnsplitArgumentsResult(
+      LinearScanRegisterAllocator allocator,
+      Reference2IntMap<LiveIntervals> originalRegisterAssignment) {
+    assert originalRegisterAssignment.defaultReturnValue() == NO_REGISTER;
+    this.allocator = allocator;
+    this.originalRegisterAssignment = originalRegisterAssignment;
+  }
+
+  public boolean isFullyReverted() {
+    return originalRegisterAssignment.isEmpty();
+  }
+
+  // Returns true if any changes were made.
+  public boolean revertPartial() {
+    boolean changed = false;
+    for (Value argument = allocator.firstArgumentValue;
+        argument != null;
+        argument = argument.getNextConsecutive()) {
+      for (LiveIntervals child : argument.getLiveIntervals().getSplitChildren()) {
+        changed |= revertPartial(child);
+      }
+    }
+    return changed;
+  }
+
+  private boolean revertPartial(LiveIntervals intervals) {
+    int originalRegister = originalRegisterAssignment.getInt(intervals);
+    if (originalRegister == NO_REGISTER) {
+      // This live intervals was not affected by the unsplit arguments optimization.
+      return false;
+    }
+    int conservativeRealRegisterEnd =
+        allocator.realRegisterNumberFromAllocated(intervals.getRegisterEnd());
+    if (conservativeRealRegisterEnd <= intervals.getRegister()) {
+      return false;
+    }
+    // Apply revert.
+    intervals.clearRegisterAssignment();
+    intervals.setRegister(originalRegister);
+    originalRegisterAssignment.removeInt(intervals);
+    return true;
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/ir/regalloc/ArgumentInLowRegisterWithMoreThan16RegistersTest.java b/src/test/java/com/android/tools/r8/ir/regalloc/ArgumentInLowRegisterWithMoreThan16RegistersTest.java
index 24fbe6d..e5a34c5 100644
--- a/src/test/java/com/android/tools/r8/ir/regalloc/ArgumentInLowRegisterWithMoreThan16RegistersTest.java
+++ b/src/test/java/com/android/tools/r8/ir/regalloc/ArgumentInLowRegisterWithMoreThan16RegistersTest.java
@@ -44,7 +44,7 @@
                   inspector.clazz(Main.class).uniqueInstanceInitializer();
               assertThat(testMethodSubject, isPresent());
               assertEquals(
-                  2,
+                  1,
                   testMethodSubject
                       .streamInstructions()
                       .filter(InstructionSubject::isMove)