Account for current register allocation mode when spilling

This reduces the number of moves in JetNews by 1.6%.

Fixes: b/375142715
Change-Id: I2fdeacccc227f3f27c262d5d08ee1e501d36048d
diff --git a/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java b/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
index f60eb2e..c02e5a1 100644
--- a/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
+++ b/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
@@ -112,6 +112,21 @@
     ALLOW_ARGUMENT_REUSE_U8BIT_RETRY,
     ALLOW_ARGUMENT_REUSE_U16BIT;
 
+    int getMaxRegisterNumber() {
+      switch (this) {
+        case ALLOW_ARGUMENT_REUSE_U4BIT:
+          return Constants.U4BIT_MAX;
+        case ALLOW_ARGUMENT_REUSE_U8BIT:
+        case ALLOW_ARGUMENT_REUSE_U8BIT_REFINEMENT:
+        case ALLOW_ARGUMENT_REUSE_U8BIT_RETRY:
+          return Constants.U8BIT_MAX;
+        case ALLOW_ARGUMENT_REUSE_U16BIT:
+          return Constants.U16BIT_MAX;
+        default:
+          throw new Unreachable();
+      }
+    }
+
     boolean hasRegisterConstraint(LiveIntervals intervals) {
       return hasRegisterConstraint(intervals.getRegisterLimit());
     }
@@ -917,7 +932,7 @@
         }
       }
     } else {
-      assert mode.is4Bit();
+      assert !mode.is16Bit();
     }
 
     switch (mode) {
@@ -1138,10 +1153,10 @@
       advanceStateToLiveIntervals(unhandledInterval);
 
       // Perform the actual allocation.
-      if (!allocateSingleInterval(unhandledInterval)) {
+      if (!allocateSingleInterval(unhandledInterval)
+          || maxRegisterNumber > mode.getMaxRegisterNumber()) {
         return false;
       }
-
       expiredHere.clear();
     }
     assert invariantsHold(mode);
@@ -1173,15 +1188,16 @@
             LiveIntervals split;
             if (argumentInterval.numberOfUsesWithConstraint() == 1) {
               // If there is only one register-constrained use, split before that one use.
-              split = argumentInterval.splitBefore(use.getPosition());
+              split = argumentInterval.splitBefore(use.getPosition(), mode);
             } else {
               // If there are multiple register-constrained users, split right after the definition
               // to make it more likely that arguments get in usable registers from the start.
               // TODO(christofferqa): This is not great if there are many arguments with multiple
               // constrained uses, since we fill up all the low registers immediately, making it
               // likely that we will have to kick them back out before they are actually used.
-              split = argumentInterval
-                  .splitBefore(argumentInterval.getValue().definition.getNumber() + 1);
+              split =
+                  argumentInterval.splitBefore(
+                      argumentInterval.getValue().definition.getNumber() + 1, mode);
             }
             unhandled.add(split);
           }
@@ -1209,11 +1225,11 @@
       MoveException moveException = block.entry().asMoveException();
       LiveIntervals intervals = moveException.outValue().getLiveIntervals();
       if (intervals.getValue().hasAnyUsers()) {
-        LiveIntervals split = intervals.splitAfter(intervals.getValue().getDefinition());
+        LiveIntervals split = intervals.splitAfter(intervals.getValue().getDefinition(), mode);
         unhandled.add(split);
       }
       if (intervals.getStart() < moveException.getNumber()) {
-        intervals = intervals.splitBefore(moveException);
+        intervals = intervals.splitBefore(moveException, mode);
       } else {
         unhandled.remove(intervals);
       }
@@ -1237,12 +1253,12 @@
         if (overlappingIntervals.getStart() == toGapPosition(invoke.getNumber())) {
           invokeRangeIntervals = overlappingIntervals;
         } else {
-          invokeRangeIntervals = overlappingIntervals.splitBefore(invoke);
+          invokeRangeIntervals = overlappingIntervals.splitBefore(invoke, mode);
           unhandled.add(invokeRangeIntervals);
         }
         invokeRangeIntervals.setIsInvokeRangeIntervals();
         if (invoke.getNumber() + 1 < invokeRangeIntervals.getEnd()) {
-          LiveIntervals successorIntervals = invokeRangeIntervals.splitAfter(invoke);
+          LiveIntervals successorIntervals = invokeRangeIntervals.splitAfter(invoke, mode);
           unhandled.add(successorIntervals);
         }
         hasInvokeRangeLiveIntervals = true;
@@ -2039,11 +2055,18 @@
       // of finding another candidate to spill via allocateBlockedRegister.
       assert unhandledInterval.hasUses();
       if (!unhandledInterval.getUses().first().hasConstraint()) {
-        int nextConstrainedPosition = unhandledInterval.firstUseWithConstraint(mode).getPosition();
-        int register = getSpillRegister(unhandledInterval, null);
-        LiveIntervals split = unhandledInterval.splitBefore(nextConstrainedPosition);
-        assignFreeRegisterToUnhandledInterval(unhandledInterval, register);
-        unhandled.add(split);
+        if (mode.hasRegisterConstraint(unhandledInterval)) {
+          int nextConstrainedPosition =
+              unhandledInterval.firstUseWithConstraint(mode).getPosition();
+          int register = getSpillRegister(unhandledInterval, null);
+          LiveIntervals split = unhandledInterval.splitBefore(nextConstrainedPosition, mode);
+          assignFreeRegisterToUnhandledInterval(unhandledInterval, register);
+          unhandled.add(split);
+        } else {
+          assert unhandledInterval.firstUseWithConstraint(mode) == null;
+          int register = getSpillRegister(unhandledInterval, null);
+          assignFreeRegisterToUnhandledInterval(unhandledInterval, register);
+        }
       } else {
         allocateBlockedRegister(unhandledInterval, registerConstraint);
       }
@@ -2063,7 +2086,7 @@
         // The candidate is free for the beginning of an interval. We split the interval
         // and use the register for as long as we can.
         int registerConstraintBeforeSplit = unhandledInterval.getRegisterLimit();
-        LiveIntervals split = unhandledInterval.splitBefore(largestFreePosition);
+        LiveIntervals split = unhandledInterval.splitBefore(largestFreePosition, mode);
         assert split != unhandledInterval;
         unhandled.add(split);
 
@@ -2378,7 +2401,7 @@
     if (!expiredHere.isEmpty()) {
       return false;
     }
-    LiveIntervals split = blockingInterval.splitBefore(unhandledInterval.getStart());
+    LiveIntervals split = blockingInterval.splitBefore(unhandledInterval.getStart(), mode);
     freeOccupiedRegistersForIntervals(blockingInterval);
     assignFreeRegisterToUnhandledInterval(unhandledInterval, blockingInterval.getRegister());
     active.remove(blockingInterval);
@@ -2621,7 +2644,9 @@
           if (activeRegister + i <= registerConstraint) {
             int unhandledStart = unhandledInterval.getStart();
             usePositions.set(
-                activeRegister + i, intervals.firstUseAfter(unhandledStart), intervals);
+                activeRegister + i,
+                intervals.firstUseWithConstraintAfter(unhandledStart, mode),
+                intervals);
           }
         }
       }
@@ -2633,7 +2658,8 @@
       if (inactiveRegister <= registerConstraint && intervals.overlaps(unhandledInterval)) {
         for (int i = 0; i < intervals.requiredRegisters(); i++) {
           if (inactiveRegister + i <= registerConstraint) {
-            int firstUse = intervals.firstUseAfter(unhandledInterval.getStart());
+            int firstUse =
+                intervals.firstUseWithConstraintAfter(unhandledInterval.getStart(), mode);
             if (firstUse < usePositions.get(inactiveRegister + i)) {
               usePositions.set(inactiveRegister + i, firstUse, intervals);
             }
@@ -2723,7 +2749,7 @@
       // All active and inactive intervals are used before current. Therefore, it is best to spill
       // current itself.
       int splitPosition = unhandledInterval.getFirstUse();
-      LiveIntervals split = unhandledInterval.splitBefore(splitPosition);
+      LiveIntervals split = unhandledInterval.splitBefore(splitPosition, mode);
       assert split != unhandledInterval;
       // Experiments show that it has a positive impact on code size to use a fresh register here.
       int registerNumber = getNewSpillRegister(unhandledInterval);
@@ -2743,7 +2769,7 @@
         assignRegisterAndSpill(unhandledInterval, candidate, needsRegisterPair);
       } else {
         // Spilling only makes a register available for the first part of current.
-        LiveIntervals splitChild = unhandledInterval.splitBefore(blockedPosition);
+        LiveIntervals splitChild = unhandledInterval.splitBefore(blockedPosition, mode);
         unhandled.add(splitChild);
         assignRegisterAndSpill(unhandledInterval, candidate, needsRegisterPair);
       }
@@ -2787,7 +2813,7 @@
           // of the inactive interval and therefore do not have to split here.
           int nextUsePosition = intervals.firstUseAfter(unhandledInterval.getStart());
           if (nextUsePosition != Integer.MAX_VALUE) {
-            LiveIntervals split = intervals.splitBefore(nextUsePosition);
+            LiveIntervals split = intervals.splitBefore(nextUsePosition, mode);
             split.setRegister(intervals.getRegister());
             newInactive.add(split);
           }
@@ -2801,7 +2827,7 @@
         } else {
           // The inactive live intervals is in a live range hole. Split the interval and
           // put the ranges after the hole into the unhandled set for register reassignment.
-          LiveIntervals split = intervals.splitBefore(unhandledInterval.getStart());
+          LiveIntervals split = intervals.splitBefore(unhandledInterval.getStart(), mode);
           unhandled.add(split);
         }
       }
@@ -2838,7 +2864,7 @@
         // because we might otherwise end up spilling to the current registers of intervals,
         // depending on getSpillRegister.
         freeOccupiedRegistersForIntervals(intervals);
-        LiveIntervals splitChild = intervals.splitBefore(unhandledInterval.getStart());
+        LiveIntervals splitChild = intervals.splitBefore(unhandledInterval.getStart(), mode);
         assignRegister(splitChild, registerNumber);
         splitChild.setSpilled(true);
         takeFreeRegistersForIntervals(splitChild);
@@ -2855,7 +2881,7 @@
         if (splitChild.hasUses()) {
           if (splitChild.isLinked() && !splitChild.isArgumentInterval()) {
             // Spilling a value with a pinned register. We need to move back at the next use.
-            LiveIntervals splitOfSplit = splitChild.splitBefore(splitChild.getFirstUse());
+            LiveIntervals splitOfSplit = splitChild.splitBefore(splitChild.getFirstUse(), mode);
             splitOfSplit.setRegister(intervals.getRegister());
             inactive.add(splitOfSplit);
           } else if (intervals.getValue().isConstNumber()) {
@@ -2881,8 +2907,12 @@
     // that is yet, and therefore we split before the next use to make sure we get a usable
     // register at the next use.
     if (!spilled.getUses().isEmpty()) {
-      LiveIntervals split = spilled.splitBefore(spilled.getUses().first().getPosition());
-      unhandled.add(split);
+      LiveIntervals split = spilled.splitBefore(spilled.getUses().first().getPosition(), mode);
+      if (split != spilled) {
+        unhandled.add(split);
+      } else {
+        spilled.setRegister(spilled.getSplitParent().getRegister());
+      }
     }
   }
 
@@ -2901,24 +2931,20 @@
         registerNumber = Constants.U16BIT_MAX;
       }
     }
-    LiveIntervalsUse firstUseWithLowerLimit = null;
-    boolean hasUsesBeforeFirstUseWithLowerLimit = false;
-    int highestRegisterNumber = registerNumber + spilled.requiredRegisters() - 1;
-    for (LiveIntervalsUse use : spilled.getUses()) {
-      if (highestRegisterNumber > use.getLimit()) {
-        firstUseWithLowerLimit = use;
-        break;
+    LiveIntervalsUse firstUseWithConstraint = spilled.firstUseWithConstraint(mode);
+    if (firstUseWithConstraint != null) {
+      int register = spilled.getRegister();
+      LiveIntervals splitOfSplit = spilled.splitBefore(firstUseWithConstraint.getPosition(), mode);
+      if (splitOfSplit != spilled) {
+        unhandled.add(splitOfSplit);
       } else {
-        hasUsesBeforeFirstUseWithLowerLimit = true;
+        assert !spilled.hasRegister();
+        spilled.setRegister(register);
+        if (spilled.hasUses()) {
+          spilled.setSpilled(false);
+        }
       }
     }
-    if (hasUsesBeforeFirstUseWithLowerLimit) {
-      spilled.setSpilled(false);
-    }
-    if (firstUseWithLowerLimit != null) {
-      LiveIntervals splitOfSplit = spilled.splitBefore(firstUseWithLowerLimit.getPosition());
-      unhandled.add(splitOfSplit);
-    }
   }
 
   private void splitRangesForSpilledConstant(LiveIntervals spilled, int spillRegister) {
@@ -2932,10 +2958,14 @@
     assert !spilled.isLinked() || spilled.isArgumentInterval();
     // Do not split range if constant is reused by one of the eleven following instruction.
     int maxGapSize = 11 * INSTRUCTION_NUMBER_DELTA;
-    if (!spilled.getUses().isEmpty()) {
+    LiveIntervalsUse firstUseWithConstraint = spilled.firstUseWithConstraint(mode);
+    if (firstUseWithConstraint != null) {
       // Split at first use after the spill position and add to unhandled to get a register
       // assigned for rematerialization.
-      LiveIntervals split = spilled.splitBefore(spilled.getFirstUse());
+      LiveIntervals split = spilled.splitBefore(firstUseWithConstraint.getPosition(), mode);
+      if (spilled.hasUses()) {
+        spilled.setSpilled(false);
+      }
       unhandled.add(split);
       // Now repeatedly split for each use that is more than maxGapSize away from the previous use.
       boolean changed = true;
@@ -2946,14 +2976,14 @@
           if (use.getPosition() - previousUse > maxGapSize) {
             // Found a use that is more than gap size away from the previous use. Split after
             // the previous use.
-            split = split.splitBefore(previousUse + INSTRUCTION_NUMBER_DELTA);
+            split = split.splitBefore(previousUse + INSTRUCTION_NUMBER_DELTA, mode);
             // If the next use is not at the start of the new split, we split again at the next use
             // and spill the gap.
             if (toGapPosition(use.getPosition()) > split.getStart()) {
               assignRegister(split, spillRegister);
               split.setSpilled(true);
               inactive.add(split);
-              split = split.splitBefore(use.getPosition());
+              split = split.splitBefore(use.getPosition(), mode);
             }
             // |split| now starts at the next use - add it to unhandled to get a register
             // assigned for rematerialization.
@@ -2965,6 +2995,8 @@
           previousUse = use.getPosition();
         }
       }
+    } else if (spilled.hasUses()) {
+      spilled.setSpilled(false);
     }
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervals.java b/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervals.java
index 64d42b0..6eadee6 100644
--- a/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervals.java
+++ b/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervals.java
@@ -427,6 +427,18 @@
     return Integer.MAX_VALUE;
   }
 
+  public int firstUseWithConstraintAfter(int unhandledStart, ArgumentReuseMode mode) {
+    if (isInvokeRangeIntervals()) {
+      return getFirstUse();
+    }
+    for (LiveIntervalsUse use : uses) {
+      if (use.hasConstraint(mode) && use.getPosition() >= unhandledStart) {
+        return use.getPosition();
+      }
+    }
+    return Integer.MAX_VALUE;
+  }
+
   public boolean hasUses() {
     return !uses.isEmpty();
   }
@@ -451,17 +463,19 @@
     }
   }
 
-  public LiveIntervals splitBefore(Instruction instruction) {
-    return splitBefore(instruction.getNumber());
+  public LiveIntervals splitBefore(Instruction instruction, ArgumentReuseMode mode) {
+    return splitBefore(instruction.getNumber(), mode);
   }
 
-  public LiveIntervals splitAfter(Instruction instruction) {
-    return splitBefore(instruction.getNumber() + INSTRUCTION_NUMBER_DELTA);
+  public LiveIntervals splitAfter(Instruction instruction, ArgumentReuseMode mode) {
+    return splitBefore(instruction.getNumber() + INSTRUCTION_NUMBER_DELTA, mode);
   }
 
-  public LiveIntervals splitBefore(int start) {
+  public LiveIntervals splitBefore(int start, ArgumentReuseMode mode) {
     if (toInstructionPosition(start) == toInstructionPosition(getStart())) {
-      assert uses.size() == 0 || getFirstUse() != start;
+      assert uses.isEmpty()
+          || getFirstUse() != start
+          || (!uses.first().hasConstraint(mode) && !isInvokeRangeIntervals());
       register = NO_REGISTER;
       return this;
     }
diff --git a/src/test/java/com/android/tools/r8/ir/regalloc/SpillToHighUnusedArgumentRegisterTest.java b/src/test/java/com/android/tools/r8/ir/regalloc/SpillToHighUnusedArgumentRegisterTest.java
index 15ff3af..90b1d6b 100644
--- a/src/test/java/com/android/tools/r8/ir/regalloc/SpillToHighUnusedArgumentRegisterTest.java
+++ b/src/test/java/com/android/tools/r8/ir/regalloc/SpillToHighUnusedArgumentRegisterTest.java
@@ -5,11 +5,12 @@
 
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
 import static org.hamcrest.MatcherAssert.assertThat;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.assertEquals;
 
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.dex.code.DexMoveFrom16;
 import com.android.tools.r8.dex.code.DexMoveResult;
 import com.android.tools.r8.graph.DexCode;
 import com.android.tools.r8.utils.codeinspector.InstructionSubject;
@@ -56,13 +57,15 @@
                       .asDexInstruction()
                       .getInstruction();
 
-              // TODO(b/375142715): The test no longer spills the `i` value. Look into if the test
-              //  can be tweeked so that `i` is spilled, and validate that it is spilled to the
-              //  unused argument register.
-              assertTrue(
+              DexMoveFrom16 spillMove =
                   testMethodSubject
                       .streamInstructions()
-                      .noneMatch(i -> i.isMoveFrom(moveResult.AA)));
+                      .filter(i -> i.isMoveFrom(moveResult.AA))
+                      .collect(MoreCollectors.onlyElement())
+                      .asDexInstruction()
+                      .getInstruction();
+              int lastArgumentRegister = code.registerSize - 1;
+              assertEquals(lastArgumentRegister, spillMove.AA);
             });
   }