Refine 8 bit register allocation when some arguments are in 4 bits

Bug: b/374266460
Change-Id: Ifaf7bf0fa788ba2373ca4300f9031184b24fa8ea
diff --git a/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java b/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
index 6cb2c2b..71f81ad 100644
--- a/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
+++ b/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
@@ -17,6 +17,7 @@
 import com.android.tools.r8.ir.analysis.type.TypeElement;
 import com.android.tools.r8.ir.code.Add;
 import com.android.tools.r8.ir.code.And;
+import com.android.tools.r8.ir.code.Argument;
 import com.android.tools.r8.ir.code.ArithmeticBinop;
 import com.android.tools.r8.ir.code.BasicBlock;
 import com.android.tools.r8.ir.code.CheckCast;
@@ -96,6 +97,8 @@
   public enum ArgumentReuseMode {
     ALLOW_ARGUMENT_REUSE_U4BIT,
     ALLOW_ARGUMENT_REUSE_U8BIT,
+    ALLOW_ARGUMENT_REUSE_U8BIT_REFINEMENT,
+    ALLOW_ARGUMENT_REUSE_U8BIT_RETRY,
     ALLOW_ARGUMENT_REUSE_U16BIT;
 
     boolean hasRegisterConstraint(LiveIntervals intervals) {
@@ -111,6 +114,8 @@
         case ALLOW_ARGUMENT_REUSE_U4BIT:
           return false;
         case ALLOW_ARGUMENT_REUSE_U8BIT:
+        case ALLOW_ARGUMENT_REUSE_U8BIT_REFINEMENT:
+        case ALLOW_ARGUMENT_REUSE_U8BIT_RETRY:
           return constraint == Constants.U4BIT_MAX;
         case ALLOW_ARGUMENT_REUSE_U16BIT:
           return constraint != Constants.U16BIT_MAX;
@@ -118,6 +123,28 @@
           throw new Unreachable();
       }
     }
+
+    boolean is4Bit() {
+      return this == ALLOW_ARGUMENT_REUSE_U4BIT;
+    }
+
+    boolean is8Bit() {
+      return this == ALLOW_ARGUMENT_REUSE_U8BIT
+          || this == ALLOW_ARGUMENT_REUSE_U8BIT_REFINEMENT
+          || this == ALLOW_ARGUMENT_REUSE_U8BIT_RETRY;
+    }
+
+    boolean is8BitRefinement() {
+      return this == ALLOW_ARGUMENT_REUSE_U8BIT_REFINEMENT;
+    }
+
+    boolean is8BitRetry() {
+      return this == ALLOW_ARGUMENT_REUSE_U8BIT_RETRY;
+    }
+
+    boolean is16Bit() {
+      return this == ALLOW_ARGUMENT_REUSE_U16BIT;
+    }
   }
 
   private static class LocalRange implements Comparable<LocalRange> {
@@ -155,6 +182,9 @@
   private final IRCode code;
   // Number of registers used for arguments.
   protected final int numberOfArgumentRegisters;
+  // Number of argument registers that may be assumed to be in 4 bit registers. This should only be
+  // used when mode is ALLOW_ARGUMENT_REUSE_U8BIT_REFINEMENT.
+  private int numberOf4BitArgumentRegisters = 0;
 
   // Mapping from basic blocks to the set of values live at entry to that basic block.
   private Map<BasicBlock, LiveAtEntrySets> liveAtEntrySets;
@@ -164,7 +194,7 @@
   private Value lastArgumentValue;
 
   // The current register allocation mode.
-  private ArgumentReuseMode mode = ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U4BIT;
+  private ArgumentReuseMode mode;
   // The set of registers that are free for allocation.
   private TreeSet<Integer> freeRegisters = new TreeSet<>();
   // The max register number used.
@@ -206,9 +236,9 @@
   }
 
   // We allocate a dedicated move exception register right after the arguments.
-  // TODO(christofferqa): The move-exception instruction only requires its destination register to
-  // fit in 8 bits. In some situations, it might be better to use a register which is >= 16 if we
-  // end up using that many registers.
+  // TODO(b/374715251): The move-exception instruction only requires its destination register to
+  //  fit in 8 bits. In some situations, it might be better to use a register which is >= 16 if we
+  //  end up using that many registers.
   private int getMoveExceptionRegister() {
     assert hasDedicatedMoveExceptionRegister();
     return numberOfArgumentRegisters;
@@ -221,11 +251,46 @@
     for (Instruction instruction : code.entryBlock().getInstructions()) {
       if (instruction.isArgument()) {
         argumentRegisters += instruction.outValue().requiredRegisters();
+      } else {
+        break;
       }
     }
     numberOfArgumentRegisters = argumentRegisters;
   }
 
+  private boolean retry8BitAllocationWith4BitArgumentRegisters() {
+    assert mode.is8Bit();
+    assert numberOf4BitArgumentRegisters == 0;
+    if (!options().getTestingOptions().enableRegisterAllocation8BitRefinement
+        || code.context().getDefinition().getNumberOfArguments() == 0) {
+      return false;
+    }
+    numberOf4BitArgumentRegisters = computeNumberOf4BitArgumentRegisters();
+    return numberOf4BitArgumentRegisters > 0;
+  }
+
+  private int computeNumberOf4BitArgumentRegisters() {
+    int numberOf4BitArgumentRegisters = 0;
+    Iterator<Argument> argumentIterator = code.argumentIterator();
+    int currentArgumentRegisterStart = registersUsed() - numberOfArgumentRegisters;
+    while (argumentIterator.hasNext()) {
+      Argument argument = argumentIterator.next();
+      int requiredRegisters = argument.outValue().requiredRegisters();
+      int nextArgumentRegisterStart = currentArgumentRegisterStart + requiredRegisters;
+      int currentArgumentRegisterEnd = nextArgumentRegisterStart - 1;
+      if (currentArgumentRegisterEnd <= Constants.U4BIT_MAX) {
+        currentArgumentRegisterStart = nextArgumentRegisterStart;
+        numberOf4BitArgumentRegisters += requiredRegisters;
+      } else {
+        if (currentArgumentRegisterStart <= Constants.U4BIT_MAX) {
+          numberOf4BitArgumentRegisters++;
+        }
+        break;
+      }
+    }
+    return numberOf4BitArgumentRegisters;
+  }
+
   @Override
   public ProgramMethod getProgramMethod() {
     return code.context();
@@ -248,7 +313,7 @@
     ImmutableList<BasicBlock> blocks = computeLivenessInformation();
     performAllocation();
     assert code.isConsistentGraph(appView);
-    assert registersUsed() == 0 || unusedRegisters != null;
+    assert mode.is4Bit() || registersUsed() == 0 || unusedRegisters != null;
     // Even if the method is reachability sensitive, we do not compute debug information after
     // register allocation. We just treat the method as being in debug mode in order to keep
     // locals alive for their entire live range. In release mode the liveness is all that matters
@@ -582,7 +647,7 @@
   // numbers that were unused. This table is then used to slide down the actual registers
   // used to fill the gaps.
   private boolean computeUnusedRegisters() {
-    if (registersUsed() == 0) {
+    if (mode.is4Bit() || registersUsed() == 0) {
       return false;
     }
     // Compute the set of registers that is used based on all live intervals.
@@ -681,14 +746,15 @@
   private void performAllocation() {
     // Will automatically continue to ALLOW_ARGUMENT_REUSE_U8BIT and ALLOW_ARGUMENT_REUSE_U16BIT,
     // if needed.
-    performAllocation(ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U4BIT, false);
+    performAllocation(ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U4BIT);
   }
 
-  private ArgumentReuseMode performAllocation(ArgumentReuseMode mode, boolean isRetry) {
+  private ArgumentReuseMode performAllocation(ArgumentReuseMode mode) {
+    assert numberOf4BitArgumentRegisters == 0 || mode.is8BitRefinement();
     ArgumentReuseMode result = mode;
     this.mode = mode;
 
-    if (isRetry) {
+    if (!mode.is4Bit()) {
       clearRegisterAssignments(mode);
       removeSpillAndPhiMoves();
     }
@@ -698,6 +764,16 @@
     boolean succeeded = performLinearScan(mode);
     if (succeeded) {
       insertMoves();
+      // Now that we know the max register number we can compute whether it is safe to use
+      // argument registers in place. If it is, we redo move insertion to get rid of the moves
+      // caused by splitting of the argument registers.
+      if (unsplitArguments()) {
+        removeSpillAndPhiMoves();
+        insertMoves();
+      }
+      computeUnusedRegisters();
+    } else {
+      assert mode.is4Bit();
     }
 
     switch (mode) {
@@ -707,49 +783,49 @@
             || options().testing.alwaysUsePessimisticRegisterAllocation) {
           // Redo allocation in mode ALLOW_ARGUMENT_REUSE_U8BIT. This may in principle also fail.
           // It is extremely rare that a method will use more than 256 registers, though.
-          result = performAllocation(ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U8BIT, true);
-        } else {
-          // Never has unused registers.
-          assert !computeUnusedRegisters();
+          result = performAllocation(ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U8BIT);
         }
         break;
 
       case ALLOW_ARGUMENT_REUSE_U8BIT:
-        assert succeeded;
-        // Now that we know the max register number we can compute whether it is safe to use
-        // argument registers in place. If it is, we redo move insertion to get rid of the moves
-        // caused by splitting of the argument registers.
-        if (unsplitArguments()) {
-          removeSpillAndPhiMoves();
-          insertMoves();
-        }
-        computeUnusedRegisters();
-
         if (highestUsedRegister() > Constants.U8BIT_MAX
-            || options().testing.alwaysUsePessimisticRegisterAllocation) {
+            || options().getTestingOptions().alwaysUsePessimisticRegisterAllocation) {
           // Redo allocation in mode ALLOW_ARGUMENT_REUSE_U16BIT. This always succeed.
           unusedRegisters = null;
-          result = performAllocation(ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U16BIT, true);
+          result = performAllocation(ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U16BIT);
+          break;
         }
+
+        if (retry8BitAllocationWith4BitArgumentRegisters()) {
+          // Refine register allocation result using the knowledge that some of the argument
+          // registers are 4 bit registers.
+          unusedRegisters = null;
+          result = performAllocation(ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U8BIT_REFINEMENT);
+        }
+        break;
+
+      case ALLOW_ARGUMENT_REUSE_U8BIT_REFINEMENT:
+        if (highestUsedRegister() > Constants.U8BIT_MAX
+            || numberOf4BitArgumentRegisters > computeNumberOf4BitArgumentRegisters()) {
+          // Redo allocation in mode ALLOW_ARGUMENT_REUSE_U8BIT_RETRY. This always succeeds.
+          numberOf4BitArgumentRegisters = 0;
+          unusedRegisters = null;
+          result = performAllocation(ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U8BIT_RETRY);
+        }
+        break;
+
+      case ALLOW_ARGUMENT_REUSE_U8BIT_RETRY:
+        assert highestUsedRegister() <= Constants.U8BIT_MAX;
         break;
 
       case ALLOW_ARGUMENT_REUSE_U16BIT:
-        assert succeeded;
-        // Now that we know the max register number we can compute whether it is safe to use
-        // argument registers in place. If it is, we redo move insertion to get rid of the moves
-        // caused by splitting of the argument registers.
-        if (unsplitArguments()) {
-          removeSpillAndPhiMoves();
-          insertMoves();
-        }
-        computeUnusedRegisters();
+        assert highestUsedRegister() <= Constants.U16BIT_MAX;
         break;
     }
 
-    assert result != ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U4BIT
-        || highestUsedRegister() <= Constants.U4BIT_MAX;
-    assert result != ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U8BIT
-        || highestUsedRegister() <= Constants.U8BIT_MAX;
+    assert !result.is4Bit() || highestUsedRegister() <= Constants.U4BIT_MAX;
+    assert !result.is8Bit() || highestUsedRegister() <= Constants.U8BIT_MAX;
+    assert !result.is16Bit() || highestUsedRegister() <= Constants.U16BIT_MAX;
 
     return result;
   }
@@ -759,11 +835,17 @@
   // register allocation we can check if it is safe to just use the argument register itself
   // for all uses and thereby avoid moving argument values around.
   private boolean unsplitArguments() {
+    if (mode.is4Bit()) {
+      return false;
+    }
     boolean argumentRegisterUnsplit = false;
-    Value current = firstArgumentValue;
-    while (current != null) {
+    for (Value current = firstArgumentValue;
+        current != null;
+        current = current.getNextConsecutive()) {
       LiveIntervals intervals = current.getLiveIntervals();
-      assert !mode.hasRegisterConstraint(intervals);
+      assert !mode.hasRegisterConstraint(intervals)
+          || (mode.is8BitRefinement()
+              && intervals.getRegisterEnd() < numberOf4BitArgumentRegisters);
       boolean canUseArgumentRegister = true;
       boolean couldUseArgumentRegister = true;
       for (LiveIntervals child : intervals.getSplitChildren()) {
@@ -787,7 +869,6 @@
           child.setSpilled(false);
         }
       }
-      current = current.getNextConsecutive();
     }
     return argumentRegisterUnsplit;
   }
@@ -826,7 +907,7 @@
     unhandled.clear();
     moveExceptionIntervals.clear();
     for (LiveIntervals intervals : liveIntervals) {
-      if (mode == ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U16BIT) {
+      if (mode.is8BitRefinement() || mode.is8BitRetry() || mode.is16Bit()) {
         intervals.undoSplits();
         intervals.setSpilled(false);
       }
@@ -877,6 +958,10 @@
       if (!mode.hasRegisterConstraint(argumentInterval)) {
         // All the argument intervals are active in the beginning and have preallocated registers.
         active.add(argumentInterval);
+      } else if (mode.is8BitRefinement()
+          && argumentInterval.getRegister() + argumentValue.requiredRegisters()
+              <= numberOf4BitArgumentRegisters) {
+        active.add(argumentInterval);
       } else {
         // Treat the argument interval as spilled which will require a load to a different
         // register for all register-constrained usages.
@@ -915,8 +1000,7 @@
     // When we allow argument reuse we do not allow any splitting, therefore we cannot get into
     // trouble with move exception registers. When argument reuse is disallowed we block a fixed
     // register to be used only by move exception instructions.
-    if (mode == ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U8BIT
-        || mode == ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U16BIT) {
+    if (mode.is8Bit() || mode.is16Bit()) {
       // Force all move exception ranges to start out with the exception in a fixed register. Split
       // their live ranges which will force another register if used.
       boolean overlappingMoveExceptionIntervals = false;
@@ -1036,8 +1120,7 @@
             computedFreeRegisters.remove(register);
           });
     }
-    if (mode == ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U8BIT
-        || mode == ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U16BIT) {
+    if (mode.is8Bit() || mode.is16Bit()) {
       // Each time an argument interval is active, we currently require that it is present in its
       // original, incoming argument register.
       for (LiveIntervals activeIntervals : active) {
@@ -1084,7 +1167,7 @@
     return true;
   }
 
-  private boolean registerAssignmentNotConflictingWithArgument(LiveIntervals interval) {
+  private boolean verifyRegisterAssignmentNotConflictingWithArgument(LiveIntervals interval) {
     assert interval.getRegister() != NO_REGISTER;
     for (Value argumentValue = firstArgumentValue;
         argumentValue != null;
@@ -1183,8 +1266,7 @@
           // Allocate the argument intervals.
           unhandled.remove(destIntervals);
           boolean excludeUnhandledOverlappingArgumentIntervals = false;
-          if (mode == ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U8BIT
-              || mode == ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U16BIT) {
+          if (mode.is8Bit() || mode.is16Bit()) {
             // Since we are going to do a look-ahead, there may be argument live interval splits,
             // which are currently unhandled, but would be inactive at the invoke-range instruction.
             // Thus, the implementation of allocateLinkedIntervals needs to exclude the argument
@@ -1254,7 +1336,7 @@
     int nextRegister = getFreeConsecutiveRegisters(numberOfRegisters);
     for (LiveIntervals current = start; current != null; current = current.getNextConsecutive()) {
       current.setRegister(nextRegister);
-      assert registerAssignmentNotConflictingWithArgument(current);
+      assert verifyRegisterAssignmentNotConflictingWithArgument(current);
       // Propagate hints to the move sources.
       Value value = current.getValue();
       if (!value.isPhi() && value.definition.isMove()) {
@@ -1664,17 +1746,14 @@
     // avoid move generation for the argument.
     if (unhandledInterval.isArgumentInterval()) {
       if (registerConstraint == Constants.U16BIT_MAX
-          || (mode == ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U8BIT
-              && registerConstraint == Constants.U8BIT_MAX)) {
+          || (mode.is8Bit() && registerConstraint == Constants.U8BIT_MAX)) {
         int argumentRegister = unhandledInterval.getSplitParent().getRegister();
         assignFreeRegisterToUnhandledInterval(unhandledInterval, argumentRegister);
         return true;
       }
     }
 
-    if (registerConstraint < Constants.U16BIT_MAX
-        && (mode == ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U8BIT
-            || mode == ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U16BIT)) {
+    if (!mode.is4Bit() && registerConstraint < Constants.U16BIT_MAX) {
       // Since we swap the argument registers and the temporary registers after register allocation,
       // we can allow the use of number of arguments more registers.
       registerConstraint += numberOfArgumentRegisters;
@@ -1766,21 +1845,47 @@
     // Set all free positions for possible registers to max integer.
     RegisterPositions freePositions = new RegisterPositionsImpl(registerConstraint + 1);
 
-    if ((options().debug || code.context().isReachabilitySensitive())
-        && !code.method().accessFlags.isStatic()) {
-      // If we are generating debug information or if the method is reachability sensitive,
-      // we pin the this value register. The debugger expects to always be able to find it in
-      // the input register.
+    if (options().shouldCompileMethodInDebugMode(code.context())
+        && !code.context().getAccessFlags().isStatic()) {
+      // When compiling the method in debug mode we pin the this value register. The debugger
+      // expects to be able to find it in the input register.
       assert numberOfArgumentRegisters > 0;
       assert firstArgumentValue != null && firstArgumentValue.requiredRegisters() == 1;
       freePositions.setBlocked(0);
     }
 
-    if (mode == ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U8BIT
-        || mode == ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U16BIT) {
-      // Argument reuse is not allowed and we block all the argument registers so that
+    if (!mode.is4Bit()) {
+      // Generally argument reuse is not allowed and we block all the argument registers so that
       // arguments are never free.
-      for (int i = 0; i < numberOfArgumentRegisters && i <= registerConstraint; i++) {
+      //
+      // When mode=ALLOW_ARGUMENT_REUSE_U8BIT_REFINEMENT we assume that some argument registers are
+      // in 4 bits. If the current live intervals do not overlap with the live intervals of a 4 bit
+      // argument then we allow using that argument register for the current value.
+      int i = 0;
+      if (mode.is8BitRefinement()) {
+        assert numberOf4BitArgumentRegisters > 0;
+        int remainingNumberOf4BitArgumentRegisters = numberOf4BitArgumentRegisters;
+        for (Value argumentValue = firstArgumentValue;
+            argumentValue != null;
+            argumentValue = argumentValue.getNextConsecutive()) {
+          int requiredRegisters = argumentValue.requiredRegisters();
+          remainingNumberOf4BitArgumentRegisters -= requiredRegisters;
+          if (remainingNumberOf4BitArgumentRegisters < 0) {
+            // Block all subsequent argument registers.
+            break;
+          }
+          // Block this argument register if there is any overlap between the two live intervals.
+          // TODO(b/374266460): Allow using the argument register even when there are overlapping
+          //  live intervals.
+          if (argumentValue.getLiveIntervals().anySplitOverlaps(unhandledInterval)) {
+            for (int j = 0; j < requiredRegisters; j++) {
+              freePositions.setBlocked(i + j);
+            }
+          }
+          i += requiredRegisters;
+        }
+      }
+      for (; i < numberOfArgumentRegisters && i <= registerConstraint; i++) {
         freePositions.setBlocked(i);
       }
     }
@@ -2373,7 +2478,7 @@
     boolean isSpillingToArgumentRegister =
         (spilled.isArgumentInterval() || registerNumber < numberOfArgumentRegisters);
     if (isSpillingToArgumentRegister) {
-      if (mode == ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U8BIT) {
+      if (mode.is8Bit()) {
         registerNumber = Constants.U8BIT_MAX;
       } else {
         registerNumber = Constants.U16BIT_MAX;
@@ -2568,11 +2673,20 @@
   }
 
   boolean isPinnedArgumentRegister(LiveIntervals intervals) {
-    if (intervals.isArgumentInterval()) {
-      assert intervals.getRegister() != NO_REGISTER;
-      return intervals.getRegister() < numberOfArgumentRegisters;
+    if (!intervals.isArgumentInterval()) {
+      return false;
     }
-    return false;
+    assert intervals.getRegister() != NO_REGISTER;
+    if (intervals.getRegister() >= numberOfArgumentRegisters) {
+      return false;
+    }
+    if (mode.is8BitRefinement()) {
+      // An 8 bit argument register could be moved to a 4 bit argument register.
+      if (intervals.getRegister() != intervals.getSplitParent().getRegister()) {
+        return false;
+      }
+    }
+    return true;
   }
 
   private static void addLiveRange(
diff --git a/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervals.java b/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervals.java
index 87ca105..9597249 100644
--- a/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervals.java
+++ b/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervals.java
@@ -277,6 +277,10 @@
     return register;
   }
 
+  public int getRegisterEnd() {
+    return register + requiredRegisters() - 1;
+  }
+
   public int getRegisterLimit() {
     return registerLimit;
   }
diff --git a/src/main/java/com/android/tools/r8/ir/regalloc/SpillMoveSet.java b/src/main/java/com/android/tools/r8/ir/regalloc/SpillMoveSet.java
index fe7c2fc..87c059a 100644
--- a/src/main/java/com/android/tools/r8/ir/regalloc/SpillMoveSet.java
+++ b/src/main/java/com/android/tools/r8/ir/regalloc/SpillMoveSet.java
@@ -300,10 +300,7 @@
     // the arguments are not live, so it is insufficient to check that the destination register
     // is in the argument register range.
     for (SpillMove move : moves) {
-      boolean isArgumentRestore =
-          move.to.getRegister() < allocator.numberOfArgumentRegisters
-              && move.to.isArgumentInterval();
-      assert !isArgumentRestore;
+      assert !allocator.isPinnedArgumentRegister(move.to);
     }
     return true;
   }
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index cbd5675..17ea3af 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -2481,6 +2481,11 @@
     public boolean allowUnusedDontWarnRules = true;
     public boolean alwaysUseExistingAccessInfoCollectionsInMemberRebinding = true;
     public boolean alwaysUsePessimisticRegisterAllocation = false;
+    // TODO(b/374266460): Investigate why enabling this leads to more moves, for example, in
+    //  JetNews. Also investigate the impact on performance and how often the refinement pass is
+    //  successful (i.e., how often the assumed 4 bit argument registers actually end up being 4
+    //  bit). If the failure rate is too high, maybe add some buffer.
+    public boolean enableRegisterAllocation8BitRefinement = false;
     public boolean enableKeepInfoCanonicalizer = true;
     public boolean enableBridgeHoistingToSharedSyntheticSuperclass = false;
     public boolean enableCheckCastAndInstanceOfRemoval = true;
diff --git a/src/test/java/com/android/tools/r8/ir/regalloc/ArgumentIn4BitRegisterTest.java b/src/test/java/com/android/tools/r8/ir/regalloc/ArgumentIn4BitRegisterTest.java
index d2b9b6e..197fa98 100644
--- a/src/test/java/com/android/tools/r8/ir/regalloc/ArgumentIn4BitRegisterTest.java
+++ b/src/test/java/com/android/tools/r8/ir/regalloc/ArgumentIn4BitRegisterTest.java
@@ -33,6 +33,8 @@
   public void test() throws Exception {
     testForD8()
         .addInnerClasses(getClass())
+        .addOptionsModification(
+            options -> options.getTestingOptions().enableRegisterAllocation8BitRefinement = true)
         .release()
         .setMinApi(parameters)
         .compile()
@@ -41,9 +43,8 @@
               MethodSubject testMethodSubject =
                   inspector.clazz(Main.class).uniqueMethodWithOriginalName("test");
               assertThat(testMethodSubject, isPresent());
-              // TODO(b/374266460): Should be 0.
               assertEquals(
-                  1,
+                  0,
                   testMethodSubject
                       .streamInstructions()
                       .filter(InstructionSubject::isMove)
@@ -62,6 +63,7 @@
       use(d);
       use(e);
       use(f);
+      use(g);
       use(h);
       use(i);
     }
diff --git a/src/test/java/com/android/tools/r8/ir/regalloc/ArgumentInLowRegisterWithMoreThan16RegistersTest.java b/src/test/java/com/android/tools/r8/ir/regalloc/ArgumentInLowRegisterWithMoreThan16RegistersTest.java
index 8885519..24fbe6d 100644
--- a/src/test/java/com/android/tools/r8/ir/regalloc/ArgumentInLowRegisterWithMoreThan16RegistersTest.java
+++ b/src/test/java/com/android/tools/r8/ir/regalloc/ArgumentInLowRegisterWithMoreThan16RegistersTest.java
@@ -33,6 +33,8 @@
   public void testD8() throws Exception {
     testForD8()
         .addInnerClasses(getClass())
+        .addOptionsModification(
+            options -> options.getTestingOptions().enableRegisterAllocation8BitRefinement = true)
         .release()
         .setMinApi(parameters)
         .compile()
@@ -41,15 +43,15 @@
               MethodSubject testMethodSubject =
                   inspector.clazz(Main.class).uniqueInstanceInitializer();
               assertThat(testMethodSubject, isPresent());
-              // TODO(b/374266460): Leverage that most arguments are in 4 bit registers from the
-              //  beginning.
               assertEquals(
-                  9,
+                  2,
                   testMethodSubject
                       .streamInstructions()
                       .filter(InstructionSubject::isMove)
                       .count());
-            });
+            })
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithEmptyOutput();
   }
 
   static class Main {
@@ -73,5 +75,9 @@
       this.g = g;
       this.h = h;
     }
+
+    public static void main(String[] args) {
+      new Main(1, 2, 3, 4, 5, 6, 7, 8);
+    }
   }
 }