Ensure register allocation never restores pinned arguments

Fixes: b/374860913
Change-Id: I2cb1735b072bf88bd46212437f17c062d4ca3d21
diff --git a/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java b/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
index 549930b..6cb2c2b 100644
--- a/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
+++ b/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
@@ -10,6 +10,7 @@
 import com.android.tools.r8.cf.FixedLocalValue;
 import com.android.tools.r8.dex.Constants;
 import com.android.tools.r8.errors.CompilationError;
+import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DebugLocalInfo;
 import com.android.tools.r8.graph.ProgramMethod;
@@ -92,10 +93,31 @@
   public static final int MIN_CONSTANT_FREE_FOR_POSITIONS = 5;
   public static final int EXCEPTION_INTERVALS_OVERLAP_CUTOFF = 500;
 
-  private enum ArgumentReuseMode {
+  public enum ArgumentReuseMode {
     ALLOW_ARGUMENT_REUSE_U4BIT,
     ALLOW_ARGUMENT_REUSE_U8BIT,
-    ALLOW_ARGUMENT_REUSE_U16BIT
+    ALLOW_ARGUMENT_REUSE_U16BIT;
+
+    boolean hasRegisterConstraint(LiveIntervals intervals) {
+      return hasRegisterConstraint(intervals.getRegisterLimit());
+    }
+
+    boolean hasRegisterConstraint(LiveIntervalsUse use) {
+      return hasRegisterConstraint(use.getLimit());
+    }
+
+    private boolean hasRegisterConstraint(int constraint) {
+      switch (this) {
+        case ALLOW_ARGUMENT_REUSE_U4BIT:
+          return false;
+        case ALLOW_ARGUMENT_REUSE_U8BIT:
+          return constraint == Constants.U4BIT_MAX;
+        case ALLOW_ARGUMENT_REUSE_U16BIT:
+          return constraint != Constants.U16BIT_MAX;
+        default:
+          throw new Unreachable();
+      }
+    }
   }
 
   private static class LocalRange implements Comparable<LocalRange> {
@@ -741,7 +763,7 @@
     Value current = firstArgumentValue;
     while (current != null) {
       LiveIntervals intervals = current.getLiveIntervals();
-      assert intervals.getRegisterLimit() == Constants.U16BIT_MAX;
+      assert !mode.hasRegisterConstraint(intervals);
       boolean canUseArgumentRegister = true;
       boolean couldUseArgumentRegister = true;
       for (LiveIntervals child : intervals.getSplitChildren()) {
@@ -852,7 +874,7 @@
       LiveIntervals argumentInterval = argumentValue.getLiveIntervals();
       assert argumentInterval.getRegister() != NO_REGISTER;
       unhandled.remove(argumentInterval);
-      if (mode == ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U4BIT) {
+      if (!mode.hasRegisterConstraint(argumentInterval)) {
         // All the argument intervals are active in the beginning and have preallocated registers.
         active.add(argumentInterval);
       } else {
@@ -1694,8 +1716,9 @@
       }
       // If the first use for these intervals is unconstrained, just spill this interval instead
       // of finding another candidate to spill via allocateBlockedRegister.
-      if (!unhandledInterval.hasUses() || !unhandledInterval.getUses().first().hasConstraint()) {
-        int nextConstrainedPosition = unhandledInterval.firstUseWithConstraint().getPosition();
+      assert unhandledInterval.hasUses();
+      if (!unhandledInterval.getUses().first().hasConstraint()) {
+        int nextConstrainedPosition = unhandledInterval.firstUseWithConstraint(mode).getPosition();
         int register = getSpillRegister(unhandledInterval, null);
         LiveIntervals split = unhandledInterval.splitBefore(nextConstrainedPosition);
         assignFreeRegisterToUnhandledInterval(unhandledInterval, register);
@@ -1707,25 +1730,32 @@
       // We will use the candidate register(s) for unhandledInterval, and therefore potentially
       // need to adjust maxRegisterNumber.
       int candidateEnd = candidate + unhandledInterval.requiredRegisters() - 1;
-      if (candidateEnd > maxRegisterNumber) {
-        increaseCapacity(candidateEnd);
-      }
-
       if (largestFreePosition >= unhandledInterval.getEnd()) {
         // Free for the entire interval. Allocate the register.
+        ensureCapacity(candidateEnd);
         assignFreeRegisterToUnhandledInterval(unhandledInterval, candidate);
+      } else if (mode == ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U4BIT) {
+        // No splitting is allowed when we allow argument reuse. Bailout and start over with
+        // argument reuse disallowed.
+        return false;
       } else {
-        if (mode == ArgumentReuseMode.ALLOW_ARGUMENT_REUSE_U4BIT) {
-          // No splitting is allowed when we allow argument reuse. Bailout and start over with
-          // argument reuse disallowed.
-          return false;
-        }
         // The candidate is free for the beginning of an interval. We split the interval
         // and use the register for as long as we can.
+        int registerConstraintBeforeSplit = unhandledInterval.getRegisterLimit();
         LiveIntervals split = unhandledInterval.splitBefore(largestFreePosition);
         assert split != unhandledInterval;
-        assignFreeRegisterToUnhandledInterval(unhandledInterval, candidate);
         unhandled.add(split);
+
+        // After splitting the live intervals we may be able to find a more appropriate register
+        // than the current candidate register. This is especially true if this is an argument that
+        // is pinned in its incoming register, since if the live intervals is now unconstrained we
+        // avoid a redundant move to a low register.
+        if (unhandledInterval.getRegisterLimit() != registerConstraintBeforeSplit) {
+          return allocateSingleInterval(unhandledInterval, mode);
+        }
+
+        ensureCapacity(candidateEnd);
+        assignFreeRegisterToUnhandledInterval(unhandledInterval, candidate);
       }
     }
     return true;
@@ -2453,7 +2483,9 @@
             split != null;
             split = sortedChildren.poll()) {
           int position = split.getStart();
-          spillMoves.addSpillOrRestoreMove(toGapPosition(position), split, current);
+          if (!isPinnedArgumentRegister(split)) {
+            spillMoves.addSpillOrRestoreMove(toGapPosition(position), split, current);
+          }
           current = split;
         }
       }
@@ -2506,6 +2538,10 @@
           LiveIntervals parentInterval = value.getLiveIntervals();
           LiveIntervals fromIntervals = parentInterval.getSplitCovering(fromInstruction);
           LiveIntervals toIntervals = parentInterval.getSplitCovering(toInstruction);
+          if (isPinnedArgumentRegister(toIntervals)) {
+            // No need to add resolution moves to pinned argument registers.
+            continue;
+          }
           if (fromIntervals != toIntervals) {
             if (block.exit().isGoto() && !isCatch) {
               spillMoves.addOutResolutionMove(fromInstruction - 1, toIntervals, fromIntervals);
@@ -2531,6 +2567,14 @@
     }
   }
 
+  boolean isPinnedArgumentRegister(LiveIntervals intervals) {
+    if (intervals.isArgumentInterval()) {
+      assert intervals.getRegister() != NO_REGISTER;
+      return intervals.getRegister() < numberOfArgumentRegisters;
+    }
+    return false;
+  }
+
   private static void addLiveRange(
       Value value, BasicBlock block, int end, List<LiveIntervals> liveIntervals, IRCode code) {
     int firstInstructionInBlock = block.entry().getNumber();
@@ -3049,6 +3093,12 @@
     }
   }
 
+  private void ensureCapacity(int newMaxRegisterNumber) {
+    if (newMaxRegisterNumber > maxRegisterNumber) {
+      increaseCapacity(newMaxRegisterNumber);
+    }
+  }
+
   private void increaseCapacity(int newMaxRegisterNumber) {
     increaseCapacity(newMaxRegisterNumber, false);
   }
diff --git a/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervals.java b/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervals.java
index 175ee30..87ca105 100644
--- a/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervals.java
+++ b/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervals.java
@@ -10,6 +10,8 @@
 import com.android.tools.r8.ir.code.Phi;
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.code.ValueType;
+import com.android.tools.r8.ir.regalloc.LinearScanRegisterAllocator.ArgumentReuseMode;
+import com.google.common.collect.Iterables;
 import it.unimi.dsi.fastutil.ints.IntArrayList;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -397,12 +399,11 @@
   }
 
   public LiveIntervalsUse firstUseWithConstraint() {
-    for (LiveIntervalsUse use : uses) {
-      if (use.hasConstraint()) {
-        return use;
-      }
-    }
-    return null;
+    return Iterables.find(uses, LiveIntervalsUse::hasConstraint, null);
+  }
+
+  public LiveIntervalsUse firstUseWithConstraint(ArgumentReuseMode mode) {
+    return Iterables.find(uses, use -> use.hasConstraint(mode), null);
   }
 
   public void forEachRegister(IntConsumer consumer) {
diff --git a/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervalsUse.java b/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervalsUse.java
index db8e48d..f016b66 100644
--- a/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervalsUse.java
+++ b/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervalsUse.java
@@ -5,6 +5,8 @@
 
 import static com.android.tools.r8.dex.Constants.U16BIT_MAX;
 
+import com.android.tools.r8.ir.regalloc.LinearScanRegisterAllocator.ArgumentReuseMode;
+
 public class LiveIntervalsUse implements Comparable<LiveIntervalsUse> {
   private final int position;
   private final int limit;
@@ -47,4 +49,8 @@
   public boolean hasConstraint() {
     return limit < U16BIT_MAX;
   }
+
+  public boolean hasConstraint(ArgumentReuseMode mode) {
+    return mode.hasRegisterConstraint(this);
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/regalloc/SpillMoveSet.java b/src/main/java/com/android/tools/r8/ir/regalloc/SpillMoveSet.java
index bace7c1..fe7c2fc 100644
--- a/src/main/java/com/android/tools/r8/ir/regalloc/SpillMoveSet.java
+++ b/src/main/java/com/android/tools/r8/ir/regalloc/SpillMoveSet.java
@@ -264,12 +264,12 @@
     // Spill and restore moves for the incoming edge.
     Set<SpillMove> inMoves =
         MapUtils.removeOrDefault(instructionToInMoves, number - 1, Collections.emptySet());
-    removeArgumentRestores(inMoves);
+    assert verifyNoArgumentRestores(inMoves);
 
     // Spill and restore moves for the outgoing edge.
     Set<SpillMove> outMoves =
         MapUtils.removeOrDefault(instructionToOutMoves, number - 1, Collections.emptySet());
-    removeArgumentRestores(outMoves);
+    assert verifyNoArgumentRestores(outMoves);
 
     // Get the phi moves for this instruction and schedule them with the out going spill moves.
     Set<SpillMove> phiMoves =
@@ -293,22 +293,19 @@
     assert !needsMovesBeforeInstruction(instruction);
   }
 
-  // Remove restore moves that restore arguments. Since argument register reuse is
-  // disallowed at this point we know that argument registers do not change value and
-  // therefore we don't have to perform spill moves. Performing spill moves will also
-  // make art reject the code because it loses type information for the argument.
-  //
-  // TODO(ager): We are dealing with some of these moves as rematerialization. However,
-  // we are still generating actual moves back to the original argument register.
-  // We should get rid of this method and avoid generating the moves in the first place.
-  private void removeArgumentRestores(Set<SpillMove> moves) {
+  // Since argument register reuse is disallowed at this point we know that argument registers do
+  // not change value and therefore we don't have to perform spill moves.
+  private boolean verifyNoArgumentRestores(Set<SpillMove> moves) {
     // The argument registers can be used for other values than the arguments in intervals where
     // the arguments are not live, so it is insufficient to check that the destination register
     // is in the argument register range.
-    moves.removeIf(
-        move ->
-            move.to.getRegister() < allocator.numberOfArgumentRegisters
-                && move.to.isArgumentInterval());
+    for (SpillMove move : moves) {
+      boolean isArgumentRestore =
+          move.to.getRegister() < allocator.numberOfArgumentRegisters
+              && move.to.isArgumentInterval();
+      assert !isArgumentRestore;
+    }
+    return true;
   }
 
   private void scheduleMoves(