Support for swapping local registers and move exception register

This adds an option to swap the local registers and the dedicated move exception register when register allocation requires more than 16 registers.

This frees up a 4 bit register.

Bug: b/374715251
Change-Id: I7f788b0e22a2c98727607d7b5f44eda05e0958b5
diff --git a/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java b/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
index 5c44a1a..bd425db 100644
--- a/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
+++ b/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
@@ -238,17 +238,31 @@
   }
 
   // We allocate a dedicated move exception register right after the arguments.
-  // TODO(b/374715251): The move-exception instruction only requires its destination register to
-  //  fit in 8 bits. In some situations, it might be better to use a register which is >= 16 if we
-  //  end up using that many registers.
   private int getMoveExceptionRegister() {
     assert hasDedicatedMoveExceptionRegister();
     return numberOfArgumentRegisters;
   }
 
+  private int getMoveExceptionOffsetForLocalRegisters() {
+    return BooleanUtils.intValue(
+        hasDedicatedMoveExceptionRegister()
+            && isDedicatedMoveExceptionRegisterInLastLocalRegister());
+  }
+
   private boolean isDedicatedMoveExceptionRegisterInFirstLocalRegister() {
     assert hasDedicatedMoveExceptionRegister();
-    return true;
+    if (mode.is4Bit() || mode.is16Bit()) {
+      return true;
+    }
+    if (mode.is8BitRefinement()) {
+      assert numberOf4BitArgumentRegisters > 0;
+      return true;
+    }
+    return !options().getTestingOptions().enableUseLastLocalRegisterAsMoveExceptionRegister;
+  }
+
+  private boolean isDedicatedMoveExceptionRegisterInLastLocalRegister() {
+    return !isDedicatedMoveExceptionRegisterInFirstLocalRegister();
   }
 
   public LinearScanRegisterAllocator(AppView<?> appView, IRCode code) {
@@ -690,11 +704,7 @@
 
   private int[] computeUnusedRegistersFromUsedRegisters(IntSet usedRegisters) {
     assert firstParallelMoveTemporary != NO_REGISTER;
-    int firstLocalRegister =
-        numberOfArgumentRegisters
-            + BooleanUtils.intValue(
-                hasDedicatedMoveExceptionRegister()
-                    && !isDedicatedMoveExceptionRegisterInFirstLocalRegister());
+    int firstLocalRegister = numberOfArgumentRegisters + getMoveExceptionOffsetForLocalRegisters();
     assert verifyRegistersBeforeFirstLocalRegisterAreUsed(firstLocalRegister, usedRegisters);
     int numberOfParallelMoveTemporaryRegisters = registersUsed() - firstParallelMoveTemporary;
     int numberOfLocalRegisters =
@@ -958,15 +968,19 @@
   int unadjustedRealRegisterFromAllocated(int allocated) {
     assert allocated != NO_REGISTER;
     assert allocated >= 0;
-    int register;
     if (allocated < numberOfArgumentRegisters) {
       // For the |numberOfArguments| first registers map to the correct argument register.
-      register = maxRegisterNumber - (numberOfArgumentRegisters - allocated - 1);
+      return maxRegisterNumber - (numberOfArgumentRegisters - allocated - 1);
+    } else if (hasDedicatedMoveExceptionRegister()
+        && isDedicatedMoveExceptionRegisterInLastLocalRegister()
+        && allocated == getMoveExceptionRegister()) {
+      // Move the move-exception register to be the highest local register. We only do this in 8 bit
+      // register allocation since move-exception requires an 8 bit register.
+      return maxRegisterNumber - numberOfArgumentRegisters;
     } else {
       // For everything else use the lower numbers.
-      register = allocated - numberOfArgumentRegisters;
+      return allocated - numberOfArgumentRegisters - getMoveExceptionOffsetForLocalRegisters();
     }
-    return register;
   }
 
   int realRegisterNumberFromAllocated(int allocated) {
@@ -1073,6 +1087,9 @@
           }
         }
       }
+      for (LiveIntervals intervals : moveExceptionIntervals) {
+        assert intervals.getRegisterLimit() == Constants.U8BIT_MAX;
+      }
     }
 
     // Go through each unhandled live interval and find a register for it.
@@ -1369,8 +1386,14 @@
     // Exclude move exception register if the first interval overlaps a move exception interval.
     // It is not necessary to check the remaining consecutive intervals, since we always use
     // register 0 (after remapping) for the argument register.
-    if (overlapsMoveExceptionInterval(start) && freeRegisters.remove(getMoveExceptionRegister())) {
-      excludedRegisters.add(getMoveExceptionRegister());
+    if (hasDedicatedMoveExceptionRegister()) {
+      boolean canUseMoveExceptionRegisterForLinkedIntervals =
+          isDedicatedMoveExceptionRegisterInFirstLocalRegister()
+              && !overlapsMoveExceptionInterval(start);
+      if (!canUseMoveExceptionRegisterForLinkedIntervals
+          && freeRegisters.remove(getMoveExceptionRegister())) {
+        excludedRegisters.add(getMoveExceptionRegister());
+      }
     }
     // Select registers.
     int numberOfRegisters = start.numberOfConsecutiveRegisters();
@@ -1798,6 +1821,9 @@
       // Since we swap the argument registers and the temporary registers after register allocation,
       // we can allow the use of number of arguments more registers.
       registerConstraint += numberOfArgumentRegisters;
+      // If we swap the locals and the dedicated move exception register we can allow the use of
+      // one additional register.
+      registerConstraint += getMoveExceptionOffsetForLocalRegisters();
     }
 
     RegisterPositions freePositions = computeFreePositions(unhandledInterval, registerConstraint);
@@ -1935,10 +1961,15 @@
     // register. If we cannot find a free valid register for the move exception value we have no
     // place to put a spill move (because the move exception instruction has to be the
     // first instruction in the handler block).
-    if (overlapsMoveExceptionInterval(unhandledInterval)) {
-      int moveExceptionRegister = getMoveExceptionRegister();
-      if (moveExceptionRegister <= registerConstraint) {
-        freePositions.setBlocked(moveExceptionRegister);
+    if (hasDedicatedMoveExceptionRegister()) {
+      if (unhandledInterval.getRegisterLimit() == Constants.U4BIT_MAX
+          && isDedicatedMoveExceptionRegisterInLastLocalRegister()) {
+        freePositions.setBlocked(getMoveExceptionRegister());
+      } else if (overlapsMoveExceptionInterval(unhandledInterval)) {
+        int moveExceptionRegister = getMoveExceptionRegister();
+        if (moveExceptionRegister <= registerConstraint) {
+          freePositions.setBlocked(moveExceptionRegister);
+        }
       }
     }
 
@@ -2130,6 +2161,13 @@
           // The last register of the method is |i|, so we cannot use the pair (|i|, |i+1|).
           continue;
         }
+        if (hasDedicatedMoveExceptionRegister()
+            && isDedicatedMoveExceptionRegisterInLastLocalRegister()
+            && i == getMoveExceptionRegister()) {
+          // After register allocation we swap the dedicated move-exception register and all other
+          // local registers, so we cannot use the pair (|i|, |i+1|).
+          continue;
+        }
         if (i >= registerConstraint) {
           break;
         }
@@ -2286,8 +2324,13 @@
     }
 
     // Disallow reuse of the move exception register if we have reserved one.
-    if (overlapsMoveExceptionInterval(unhandledInterval)) {
-      usePositions.setBlocked(getMoveExceptionRegister());
+    if (hasDedicatedMoveExceptionRegister()) {
+      if (unhandledInterval.getRegisterLimit() == Constants.U4BIT_MAX
+          && isDedicatedMoveExceptionRegisterInLastLocalRegister()) {
+        usePositions.setBlocked(getMoveExceptionRegister());
+      } else if (overlapsMoveExceptionInterval(unhandledInterval)) {
+        usePositions.setBlocked(getMoveExceptionRegister());
+      }
     }
 
     // Treat active and inactive linked argument intervals as pinned. They cannot be given another
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index 17ea3af..76d22ef 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -2486,6 +2486,8 @@
     //  successful (i.e., how often the assumed 4 bit argument registers actually end up being 4
     //  bit). If the failure rate is too high maybe add a some buffer.
     public boolean enableRegisterAllocation8BitRefinement = false;
+    // TODO(b/374715251): Look into enabling this.
+    public boolean enableUseLastLocalRegisterAsMoveExceptionRegister = false;
     public boolean enableKeepInfoCanonicalizer = true;
     public boolean enableBridgeHoistingToSharedSyntheticSuperclass = false;
     public boolean enableCheckCastAndInstanceOfRemoval = true;
diff --git a/src/test/java/com/android/tools/r8/ir/regalloc/MoveExceptionInHighestLocalTest.java b/src/test/java/com/android/tools/r8/ir/regalloc/MoveExceptionInHighestLocalTest.java
new file mode 100644
index 0000000..38c5012
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/regalloc/MoveExceptionInHighestLocalTest.java
@@ -0,0 +1,83 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.regalloc;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.dex.Constants;
+import com.android.tools.r8.dex.code.DexMoveException;
+import com.android.tools.r8.graph.DexCode;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import com.google.common.collect.MoreCollectors;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class MoveExceptionInHighestLocalTest extends TestBase {
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withDefaultDexRuntime().withMaximumApiLevel().build();
+  }
+
+  @Test
+  public void testD8() throws Exception {
+    testForD8()
+        .addInnerClasses(getClass())
+        .addOptionsModification(
+            options -> {
+              options.getTestingOptions().enableRegisterAllocation8BitRefinement = true;
+              options.getTestingOptions().enableUseLastLocalRegisterAsMoveExceptionRegister = true;
+            })
+        .release()
+        .setMinApi(parameters)
+        .compile()
+        .inspect(
+            inspector -> {
+              MethodSubject testMethodSubject =
+                  inspector.clazz(Main.class).uniqueMethodWithOriginalName("test");
+              assertThat(testMethodSubject, isPresent());
+
+              DexCode code = testMethodSubject.getMethod().getCode().asDexCode();
+              DexMoveException moveException =
+                  testMethodSubject
+                      .streamInstructions()
+                      .filter(InstructionSubject::isMoveException)
+                      .collect(MoreCollectors.onlyElement())
+                      .asDexInstruction()
+                      .getInstruction();
+              int expectedMoveExceptionRegister = code.registerSize - code.incomingRegisterSize - 1;
+              assertTrue(expectedMoveExceptionRegister > Constants.U4BIT_MAX);
+              assertEquals(expectedMoveExceptionRegister, moveException.AA);
+            })
+        .runDex2Oat(parameters.getRuntime())
+        .assertNoVerificationErrors();
+  }
+
+  static class Main {
+
+    void test(long a, long b, long c, long d, long e, long f, long g, long h) {
+      try {
+        test(1, 2, 3, 4, 5, 6, 7, 8);
+      } catch (Exception exception) {
+        constrainedUse(exception, exception);
+      }
+    }
+
+    static void constrainedUse(Object a, Object b) {}
+  }
+}
diff --git a/src/test/testbase/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java b/src/test/testbase/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java
index aa3e569..7ad1a8f 100644
--- a/src/test/testbase/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java
+++ b/src/test/testbase/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java
@@ -407,6 +407,11 @@
   }
 
   @Override
+  public boolean isMoveException() {
+    return false;
+  }
+
+  @Override
   public boolean isMoveFrom(int register) {
     return false;
   }
diff --git a/src/test/testbase/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java b/src/test/testbase/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java
index aec79c9..fee7623 100644
--- a/src/test/testbase/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java
+++ b/src/test/testbase/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java
@@ -98,6 +98,7 @@
 import com.android.tools.r8.dex.code.DexMonitorExit;
 import com.android.tools.r8.dex.code.DexMove;
 import com.android.tools.r8.dex.code.DexMove16;
+import com.android.tools.r8.dex.code.DexMoveException;
 import com.android.tools.r8.dex.code.DexMoveFrom16;
 import com.android.tools.r8.dex.code.DexMoveObject;
 import com.android.tools.r8.dex.code.DexMoveObject16;
@@ -679,6 +680,11 @@
   }
 
   @Override
+  public boolean isMoveException() {
+    return instruction instanceof DexMoveException;
+  }
+
+  @Override
   public boolean isMoveFrom(int register) {
     if (instruction instanceof DexMove) {
       DexMove move = getInstruction();
diff --git a/src/test/testbase/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java b/src/test/testbase/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java
index 82b9263..983751f 100644
--- a/src/test/testbase/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java
+++ b/src/test/testbase/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java
@@ -168,6 +168,8 @@
 
   boolean isMove();
 
+  boolean isMoveException();
+
   boolean isMoveFrom(int register);
 
   boolean isMoveResult();