Try to use an existing, free register when spilling

Bug: 70650543
Change-Id: I287b3dacb7de73bcafe3f5ee79cfd92eb54b6e1e
diff --git a/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java b/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
index a9df6a3..8fd09d0 100644
--- a/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
+++ b/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
@@ -40,8 +40,10 @@
 import it.unimi.dsi.fastutil.ints.Int2ReferenceMap;
 import it.unimi.dsi.fastutil.ints.Int2ReferenceMap.Entry;
 import it.unimi.dsi.fastutil.ints.Int2ReferenceOpenHashMap;
+import it.unimi.dsi.fastutil.ints.IntArrayList;
 import it.unimi.dsi.fastutil.ints.IntArraySet;
 import it.unimi.dsi.fastutil.ints.IntIterator;
+import it.unimi.dsi.fastutil.ints.IntList;
 import it.unimi.dsi.fastutil.ints.IntSet;
 import it.unimi.dsi.fastutil.objects.Reference2IntArrayMap;
 import it.unimi.dsi.fastutil.objects.Reference2IntMap;
@@ -144,6 +146,10 @@
   // List of intervals that no register has been allocated to sorted by first live range.
   protected PriorityQueue<LiveIntervals> unhandled = new PriorityQueue<>();
 
+  // The registers that have been released as a result of advancing to the next live intervals.
+  // A register is released if an active or inactive interval becomes handled.
+  private IntList expiredHere = new IntArrayList();
+
   // List of intervals for the result of move-exception instructions.
   // Always empty in mode ALLOW_ARGUMENT_REUSE.
   private List<LiveIntervals> moveExceptionIntervals = new ArrayList<>();
@@ -826,6 +832,7 @@
     // Go through each unhandled live interval and find a register for it.
     while (!unhandled.isEmpty()) {
       assert invariantsHold(mode);
+      expiredHere.clear();
 
       LiveIntervals unhandledInterval = unhandled.poll();
 
@@ -849,6 +856,12 @@
         if (start >= activeIntervals.getEnd()) {
           activeIterator.remove();
           freeOccupiedRegistersForIntervals(activeIntervals);
+          if (start == activeIntervals.getEnd()) {
+            expiredHere.add(activeIntervals.getRegister());
+            if (activeIntervals.getType().isWide()) {
+              expiredHere.add(activeIntervals.getRegister() + 1);
+            }
+          }
         } else if (!activeIntervals.overlapsPosition(start)) {
           activeIterator.remove();
           assert activeIntervals.getRegister() != NO_REGISTER;
@@ -863,6 +876,12 @@
         LiveIntervals inactiveIntervals = inactiveIterator.next();
         if (start >= inactiveIntervals.getEnd()) {
           inactiveIterator.remove();
+          if (start == inactiveIntervals.getEnd()) {
+            expiredHere.add(inactiveIntervals.getRegister());
+            if (inactiveIntervals.getType().isWide()) {
+              expiredHere.add(inactiveIntervals.getRegister() + 1);
+            }
+          }
         } else if (inactiveIntervals.overlapsPosition(start)) {
           inactiveIterator.remove();
           assert inactiveIntervals.getRegister() != NO_REGISTER;
@@ -1153,17 +1172,137 @@
     return false;
   }
 
-  private int getSpillRegister(LiveIntervals intervals) {
+  private int getNewSpillRegister(LiveIntervals intervals) {
     if (intervals.isArgumentInterval()) {
       return intervals.getSplitParent().getRegister();
     }
 
     int register = maxRegisterNumber + 1;
     increaseCapacity(maxRegisterNumber + intervals.requiredRegisters());
+    return register;
+  }
+
+  private int getSpillRegister(LiveIntervals intervals, IntList excludedRegisters) {
+    if (intervals.isArgumentInterval()) {
+      return intervals.getSplitParent().getRegister();
+    }
+
+    TreeSet<Integer> previousFreeRegisters = new TreeSet<>(freeRegisters);
+    int previousMaxRegisterNumber = maxRegisterNumber;
+    freeRegisters.removeAll(expiredHere);
+    if (excludedRegisters != null) {
+      freeRegisters.removeAll(excludedRegisters);
+    }
+
+    // Check if we can use a register that was previously used as a register for intervals.
+    // This could lead to fewer moves during resolution.
+    int register = -1;
+    for (LiveIntervals split : intervals.getSplitParent().getSplitChildren()) {
+      int candidate = split.getRegister();
+      if (candidate != NO_REGISTER
+          && registersAreFreeAndConsecutive(candidate, intervals.getType().isWide())
+          && maySpillLiveIntervalsToRegister(intervals, candidate, previousMaxRegisterNumber)) {
+        register = candidate;
+        break;
+      }
+    }
+
+    if (register == -1) {
+      do {
+        // If the register needs to fit in 4 bits at the next use, then prioritize a small register.
+        // If we can find a small register, we do not need to insert a move at the next use.
+        boolean prioritizeSmallRegisters =
+            !intervals.getUses().isEmpty()
+                && intervals.getUses().first().getLimit() == Constants.U4BIT_MAX;
+        register =
+            getFreeConsecutiveRegisters(intervals.requiredRegisters(), prioritizeSmallRegisters);
+      } while (!maySpillLiveIntervalsToRegister(intervals, register, previousMaxRegisterNumber));
+    }
+
+    // Going to spill to the register (pair).
+    freeRegisters = previousFreeRegisters;
+    // If getFreeConsecutiveRegisters had to increment |maxRegisterNumber|, we need to update
+    // freeRegisters.
+    for (int i = previousMaxRegisterNumber + 1; i <= maxRegisterNumber; ++i) {
+      freeRegisters.add(i);
+    }
     assert registersAreFree(register, intervals.getType().isWide());
     return register;
   }
 
+  private boolean maySpillLiveIntervalsToRegister(
+      LiveIntervals intervals, int register, int previousMaxRegisterNumber) {
+    if (register > previousMaxRegisterNumber) {
+      // Nothing can prevent us from spilling to an entirely fresh register.
+      return true;
+    }
+
+    // If we are about to spill to an argument register, we need to be careful that the live range
+    // that is being spilled does not overlap with the live range of the corresponding argument.
+    //
+    // Note that this is *not* guaranteed when overlapsInactiveIntervals is null, because it is
+    // possible that some live ranges of the argument are still in the unhandled set.
+    if (register < numberOfArgumentRegisters) {
+      // Find the first argument value that uses the given register.
+      LiveIntervals argumentLiveIntervals = firstArgumentValue.getLiveIntervals();
+      while (!argumentLiveIntervals.usesRegister(register, intervals.getType().isWide())) {
+        argumentLiveIntervals = argumentLiveIntervals.getNextConsecutive();
+        assert argumentLiveIntervals != null;
+      }
+      do {
+        if (argumentLiveIntervals.anySplitOverlaps(intervals)) {
+          // Remove so that next invocation of getFreeConsecutiveRegisters does not consider this.
+          freeRegisters.remove(register);
+          // We have just established that there is an overlap between the live range of the
+          // current argument and the live range we need to find a register for. Therefore, if
+          // the argument is wide, and the current register corresponds to the low register of the
+          // argument, we know that the subsequent register will not work either.
+          if (register == argumentLiveIntervals.getRegister()
+              && argumentLiveIntervals.getType().isWide()) {
+            freeRegisters.remove(register + 1);
+          }
+          return false;
+        }
+        // The next argument live interval may also use the register, if it is a wide register pair.
+        argumentLiveIntervals = argumentLiveIntervals.getNextConsecutive();
+      } while (argumentLiveIntervals != null
+          && argumentLiveIntervals.usesRegister(register, intervals.getType().isWide()));
+    }
+
+    // Check for overlap with inactive intervals.
+    LiveIntervals overlapsInactiveIntervals = null;
+    for (LiveIntervals inactiveIntervals : inactive) {
+      if (inactiveIntervals.usesRegister(register, intervals.getType().isWide())
+          && intervals.overlaps(inactiveIntervals)) {
+        overlapsInactiveIntervals = inactiveIntervals;
+        break;
+      }
+    }
+    if (overlapsInactiveIntervals != null) {
+      // Remove so that next invocation of getFreeConsecutiveRegisters does not consider this.
+      freeRegisters.remove(register);
+      if (register == overlapsInactiveIntervals.getRegister()
+          && overlapsInactiveIntervals.getType().isWide()) {
+        freeRegisters.remove(register + 1);
+      }
+      return false;
+    }
+
+    // Check for overlap with the move exception interval.
+    boolean overlapsMoveExceptionInterval =
+        hasDedicatedMoveExceptionRegister()
+            && (register == getMoveExceptionRegister()
+                || (intervals.getType().isWide() && register + 1 == getMoveExceptionRegister()))
+            && overlapsMoveExceptionInterval(intervals);
+    if (overlapsMoveExceptionInterval) {
+      // Remove so that next invocation of getFreeConsecutiveRegisters does not consider this.
+      freeRegisters.remove(register);
+      return false;
+    }
+
+    return true;
+  }
+
   private int toInstructionPosition(int position) {
     return position % 2 == 0 ? position : position + 1;
   }
@@ -1496,7 +1635,7 @@
       // of finding another candidate to spill via allocateBlockedRegister.
       if (!unhandledInterval.getUses().first().hasConstraint()) {
         int nextConstrainedPosition = unhandledInterval.firstUseWithConstraint().getPosition();
-        int register = getSpillRegister(unhandledInterval);
+        int register = getSpillRegister(unhandledInterval, null);
         LiveIntervals split = unhandledInterval.splitBefore(nextConstrainedPosition);
         assignFreeRegisterToUnhandledInterval(unhandledInterval, register);
         unhandled.add(split);
@@ -1839,7 +1978,8 @@
       int splitPosition = unhandledInterval.getFirstUse();
       LiveIntervals split = unhandledInterval.splitBefore(splitPosition);
       assert split != unhandledInterval;
-      int registerNumber = getSpillRegister(unhandledInterval);
+      // Experiments show that it has a positive impact on code size to use a fresh register here.
+      int registerNumber = getNewSpillRegister(unhandledInterval);
       assignFreeRegisterToUnhandledInterval(unhandledInterval, registerNumber);
       unhandledInterval.setSpilled(true);
       unhandled.add(split);
@@ -1926,6 +2066,18 @@
       LiveIntervals unhandledInterval, int candidate, boolean candidateIsWide) {
     assert unhandledInterval.getRegister() == NO_REGISTER;
     assert atLeastOneOfRegistersAreTaken(candidate, candidateIsWide);
+    // Registers that we cannot choose for spilling.
+    IntList excludedRegisters = new IntArrayList(candidateIsWide ? 2 : 1);
+    excludedRegisters.add(candidate);
+    if (candidateIsWide) {
+      excludedRegisters.add(candidate + 1);
+    }
+    if (unhandledInterval.isArgumentInterval()
+        && unhandledInterval != unhandledInterval.getSplitParent()) {
+      // This live interval will become active in its original argument register and in the
+      // candidate register simultaneously.
+      unhandledInterval.getSplitParent().forEachRegister(excludedRegisters::add);
+    }
     // Spill overlapping active intervals.
     List<LiveIntervals> newActive = new ArrayList<>();
     Iterator<LiveIntervals> activeIterator = active.iterator();
@@ -1934,7 +2086,7 @@
       assert registersForIntervalsAreTaken(intervals);
       if (intervals.usesRegister(candidate, candidateIsWide)) {
         activeIterator.remove();
-        int registerNumber = getSpillRegister(intervals);
+        int registerNumber = getSpillRegister(intervals, excludedRegisters);
         // Important not to free the registers for intervals before finding a spill register,
         // because we might otherwise end up spilling to the current registers of intervals,
         // depending on getSpillRegister.
@@ -2662,8 +2814,33 @@
   }
 
   private int getFreeConsecutiveRegisters(int numberOfRegisters) {
+    return getFreeConsecutiveRegisters(numberOfRegisters, false);
+  }
+
+  private int getFreeConsecutiveRegisters(int numberOfRegisters, boolean prioritizeSmallRegisters) {
     int oldMaxRegisterNumber = maxRegisterNumber;
-    Iterator<Integer> freeRegistersIterator = freeRegisters.iterator();
+    TreeSet<Integer> freeRegistersWithDesiredOrdering = this.freeRegisters;
+    if (prioritizeSmallRegisters) {
+      freeRegistersWithDesiredOrdering =
+          new TreeSet<>(
+              (Integer x, Integer y) -> {
+                boolean xIsArgument = x < numberOfArgumentRegisters;
+                boolean yIsArgument = y < numberOfArgumentRegisters;
+                // If x is an argument and y is not, then prioritize y.
+                if (xIsArgument && !yIsArgument) {
+                  return 1;
+                }
+                // If x is not an argument and y is, then prioritize x.
+                if (!xIsArgument && yIsArgument) {
+                  return -1;
+                }
+                // Otherwise use their normal ordering.
+                return x - y;
+              });
+      freeRegistersWithDesiredOrdering.addAll(this.freeRegisters);
+    }
+
+    Iterator<Integer> freeRegistersIterator = freeRegistersWithDesiredOrdering.iterator();
     int first = getNextFreeRegister(freeRegistersIterator);
     int current = first;
     while (current - first + 1 != numberOfRegisters) {
@@ -2692,6 +2869,22 @@
     return first;
   }
 
+  private boolean registersAreFreeAndConsecutive(int register, boolean registerIsWide) {
+    if (!freeRegisters.contains(register)) {
+      return false;
+    }
+    if (registerIsWide) {
+      if (!freeRegisters.contains(register + 1)) {
+        return false;
+      }
+      if (register == numberOfArgumentRegisters - 1) {
+        // Will not be consecutive after reordering the arguments and temporaries.
+        return false;
+      }
+    }
+    return true;
+  }
+
   private int getNextFreeRegister(Iterator<Integer> freeRegistersIterator) {
     if (freeRegistersIterator.hasNext()) {
       return freeRegistersIterator.next();