Reuse temporary registers in move scheduler

Bug: b/375147902
Change-Id: I51c6995ceccf7303e8a9b4c0a4d99a34daa9d9cc
diff --git a/src/main/java/com/android/tools/r8/ir/regalloc/RegisterMoveScheduler.java b/src/main/java/com/android/tools/r8/ir/regalloc/RegisterMoveScheduler.java
index 4c60ed6..242f2dc 100644
--- a/src/main/java/com/android/tools/r8/ir/regalloc/RegisterMoveScheduler.java
+++ b/src/main/java/com/android/tools/r8/ir/regalloc/RegisterMoveScheduler.java
@@ -19,7 +19,10 @@
 import it.unimi.dsi.fastutil.ints.Int2IntMap;
 import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
 import it.unimi.dsi.fastutil.ints.IntArraySet;
+import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
+import it.unimi.dsi.fastutil.ints.IntRBTreeSet;
 import it.unimi.dsi.fastutil.ints.IntSet;
+import it.unimi.dsi.fastutil.ints.IntSortedSet;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Deque;
@@ -34,25 +37,27 @@
   // Mapping to keep track of which values currently corresponds to each other.
   // This is initially an identity map but changes as we insert moves.
   private final Int2IntMap valueMap = new Int2IntOpenHashMap();
-  // Number of temp registers used to schedule the moves.
-  private int usedTempRegisters = 0;
   // Location at which to insert the scheduled moves.
   private final InstructionListIterator insertAt;
   // Debug position associated with insertion point.
   private final Position position;
   // The first available temporary register.
-  private final int tempRegister;
+  private final int firstTempRegister;
+  private int nextTempRegister;
+  // Free registers.
+  private final IntSortedSet freeRegisters = new IntRBTreeSet();
 
   public RegisterMoveScheduler(
-      InstructionListIterator insertAt, int tempRegister, Position position) {
+      InstructionListIterator insertAt, int firstTempRegister, Position position) {
     this.insertAt = insertAt;
-    this.tempRegister = tempRegister;
+    this.firstTempRegister = firstTempRegister;
+    this.nextTempRegister = firstTempRegister;
     this.position = position;
     this.valueMap.defaultReturnValue(NO_REGISTER);
   }
 
-  public RegisterMoveScheduler(InstructionListIterator insertAt, int tempRegister) {
-    this(insertAt, tempRegister, Position.none());
+  public RegisterMoveScheduler(InstructionListIterator insertAt, int firstTempRegister) {
+    this(insertAt, firstTempRegister, Position.none());
   }
 
   public void addMove(RegisterMove move) {
@@ -117,7 +122,7 @@
   }
 
   public int getUsedTempRegisters() {
-    return usedTempRegisters;
+    return nextTempRegister - firstTempRegister;
   }
 
   private List<RegisterMove> findMovesWithSrc(int src, TypeElement type) {
@@ -159,14 +164,15 @@
         }
       }
     } else {
+      int mappedSrc = valueMap.get(move.src);
       Value to = new FixedRegisterValue(move.type, move.dst);
-      Value from = new FixedRegisterValue(move.type, valueMap.get(move.src));
+      Value from = new FixedRegisterValue(move.type, mappedSrc);
       instruction = new Move(to, from);
+      returnTemporaryRegister(mappedSrc, move.isWide());
     }
     instruction.setPosition(position);
     insertAt.add(instruction);
     return move.dst;
-
   }
 
   private void createMoveDestToTemp(RegisterMove move) {
@@ -174,24 +180,70 @@
     // registers if we are unlucky with the overlap for values that use two registers.
     List<RegisterMove> movesWithSrc = findMovesWithSrc(move.dst, move.type);
     assert movesWithSrc.size() > 0;
+    assert verifyMovesHaveDifferentSources(movesWithSrc);
     for (RegisterMove moveWithSrc : movesWithSrc) {
-      // TODO(b/375147902): For now we always use a new temporary register whenever we have to
-      //  unblock a move. The move scheduler can have multiple unblocking temps live at the same
-      //  time and therefore we cannot have just one tempRegister (pair). However, we could check
-      //  here if the previously used tempRegisters is still needed by any of the moves in the move
-      //  set (taking the value map into account). If not, we can reuse the temp register instead
-      //  of generating a new one.
-      int register = tempRegister + usedTempRegisters;
+      // TODO(b/375147902): Maybe seed the move scheduler with a set of registers known to be free
+      //  at this point.
+      int register = takeFreeRegister(moveWithSrc.isWide());
       Value to = new FixedRegisterValue(moveWithSrc.type, register);
       Value from = new FixedRegisterValue(moveWithSrc.type, valueMap.get(moveWithSrc.src));
       Move instruction = new Move(to, from);
       instruction.setPosition(position);
       insertAt.add(instruction);
       valueMap.put(moveWithSrc.src, register);
-      usedTempRegisters += moveWithSrc.type.requiredRegisters();
     }
   }
 
+  private int takeFreeRegister(boolean wide) {
+    for (int freeRegister : freeRegisters) {
+      if (wide && !freeRegisters.remove(freeRegister + 1)) {
+        continue;
+      }
+      freeRegisters.remove(freeRegister);
+      return freeRegister;
+    }
+    // We don't have a free register.
+    int register = allocateExtraRegister();
+    if (!wide) {
+      return register;
+    }
+    if (freeRegisters.remove(register - 1)) {
+      return register - 1;
+    }
+    allocateExtraRegister();
+    return register;
+  }
+
+  private void returnTemporaryRegister(int register, boolean wide) {
+    // TODO(b/375147902): If we seed the move scheduler with a set of free registers, then this
+    //  should also return non-temporary registers that are below firstTempRegister.
+    if (isTemporaryRegister(register)) {
+      freeRegisters.add(register);
+      if (wide) {
+        assert isTemporaryRegister(register + 1);
+        freeRegisters.add(register + 1);
+      }
+    } else if (wide && isTemporaryRegister(register + 1)) {
+      freeRegisters.add(register + 1);
+    }
+  }
+
+  private boolean isTemporaryRegister(int register) {
+    return register >= firstTempRegister;
+  }
+
+  private int allocateExtraRegister() {
+    return nextTempRegister++;
+  }
+
+  private boolean verifyMovesHaveDifferentSources(List<RegisterMove> movesWithSrc) {
+    IntSet seen = new IntOpenHashSet();
+    for (RegisterMove move : movesWithSrc) {
+      assert seen.add(move.src);
+    }
+    return true;
+  }
+
   private RegisterMove pickMoveToUnblock() {
     Iterator<RegisterMove> iterator = moveSet.iterator();
     RegisterMove move = null;
diff --git a/src/test/java/com/android/tools/r8/ir/regalloc/RegisterMoveSchedulerTest.java b/src/test/java/com/android/tools/r8/ir/regalloc/RegisterMoveSchedulerTest.java
index fab0924..ae83eee 100644
--- a/src/test/java/com/android/tools/r8/ir/regalloc/RegisterMoveSchedulerTest.java
+++ b/src/test/java/com/android/tools/r8/ir/regalloc/RegisterMoveSchedulerTest.java
@@ -5,6 +5,9 @@
 
 import static org.junit.Assert.assertEquals;
 
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.errors.Unimplemented;
 import com.android.tools.r8.graph.AppInfo;
 import com.android.tools.r8.graph.AppView;
@@ -39,8 +42,12 @@
 import java.util.function.Consumer;
 import java.util.function.UnaryOperator;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
 
-public class RegisterMoveSchedulerTest {
+@RunWith(Parameterized.class)
+public class RegisterMoveSchedulerTest extends TestBase {
 
   private static class CollectMovesIterator implements InstructionListIterator {
 
@@ -244,6 +251,15 @@
     }
   }
 
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withNoneRuntime().build();
+  }
+
+  public RegisterMoveSchedulerTest(TestParameters parameters) {
+    parameters.assertNoneRuntime();
+  }
+
   @Test
   public void testSingleParallelMove() {
     CollectMovesIterator moves = new CollectMovesIterator();
@@ -486,6 +502,26 @@
     assertEquals(43, moves.get(5).asMove().src().asFixedRegisterValue().getRegister());
   }
 
+  @Test
+  public void reuseTempRegister() {
+    CollectMovesIterator moves = new CollectMovesIterator();
+    int temp = 42;
+    RegisterMoveScheduler scheduler = new RegisterMoveScheduler(moves, temp);
+    scheduler.addMove(new RegisterMove(0, 1, TypeElement.getInt()));
+    scheduler.addMove(new RegisterMove(1, 0, TypeElement.getInt()));
+    scheduler.addMove(new RegisterMove(2, 3, TypeElement.getInt()));
+    scheduler.addMove(new RegisterMove(3, 2, TypeElement.getInt()));
+    scheduler.schedule();
+    // Verify that the temp register has been reused.
+    assertEquals("42 <- 1", toString(moves.get(0)));
+    assertEquals("1 <- 0", toString(moves.get(1)));
+    assertEquals("0 <- 42", toString(moves.get(2)));
+    assertEquals("42 <- 3", toString(moves.get(3)));
+    assertEquals("3 <- 2", toString(moves.get(4)));
+    assertEquals("2 <- 42", toString(moves.get(5)));
+    assertEquals(1, scheduler.getUsedTempRegisters());
+  }
+
   // Debugging aid.
   private void printMoves(List<Instruction> moves) {
     System.out.println("Generated moves:");
@@ -496,4 +532,10 @@
     }
     System.out.println("----------------");
   }
+
+  private String toString(Instruction move) {
+    return move.outValue().asFixedRegisterValue().getRegister()
+        + " <- "
+        + move.getFirstOperand().asFixedRegisterValue().getRegister();
+  }
 }