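Optimize register allocator

Avoid O(n) removals from the unhandled queue during linear scan by
marking live intervals as handled and skipping them when they are
polled. Reduce the number of overlap checks performed when allocating
registers for invoke-range instructions, and return early from
LiveIntervals.overlapsPosition for positions outside the intervals.

Also pass a Timing instance to LinearScanRegisterAllocator and add
timing sections around the phases of the allocation.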

Change-Id: Ibf334f9c48f60bf12820e81b00d94053e83dd4ab
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRToDexFinalizer.java b/src/main/java/com/android/tools/r8/ir/conversion/IRToDexFinalizer.java
index c971967..c7c4533 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRToDexFinalizer.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRToDexFinalizer.java
@@ -71,7 +71,8 @@
     // does not allow dead code (to make sure that we do not waste registers for unneeded values).
     assert deadCodeRemover.verifyNoDeadCode(code);
     timing.begin("Allocate registers");
-    LinearScanRegisterAllocator registerAllocator = new LinearScanRegisterAllocator(appView, code);
+    LinearScanRegisterAllocator registerAllocator =
+        new LinearScanRegisterAllocator(appView, code, timing);
     registerAllocator.allocateRegisters();
     timing.end();
     TrivialGotosCollapser trivialGotosCollapser = new TrivialGotosCollapser(appView);
diff --git a/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java b/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
index 7b45b03..c66b0eb 100644
--- a/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
+++ b/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
@@ -50,6 +50,7 @@
 import com.android.tools.r8.utils.ListUtils;
 import com.android.tools.r8.utils.SetUtils;
 import com.android.tools.r8.utils.StringUtils;
+import com.android.tools.r8.utils.Timing;
 import com.google.common.collect.HashMultiset;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Iterables;
@@ -253,6 +254,8 @@
   // because their values can be rematerialized.
   private int[] unusedRegisters = null;
 
+  private final Timing timing;
+
   // Whether or not the code has a move exception instruction. Used to pin the move exception
   // register.
   private boolean hasDedicatedMoveExceptionRegister() {
@@ -291,7 +294,7 @@
     return !isDedicatedMoveExceptionRegisterInFirstLocalRegister();
   }
 
-  public LinearScanRegisterAllocator(AppView<?> appView, IRCode code) {
+  public LinearScanRegisterAllocator(AppView<?> appView, IRCode code, Timing timing) {
     this.appView = appView;
     this.code = code;
     int argumentRegisters = 0;
@@ -303,6 +306,7 @@
       }
     }
     numberOfArgumentRegisters = argumentRegisters;
+    this.timing = timing;
   }
 
   private boolean retry8BitAllocationWith4BitArgumentRegisters() {
@@ -353,11 +357,15 @@
     if (this.code.method().accessFlags.isBridge() && implementationIsBridge(this.code)) {
       transformBridgeMethod();
     }
+    timing.begin("Setup");
     computeNeedsRegister();
     constrainArgumentIntervals();
     insertRangeInvokeMoves();
     ImmutableList<BasicBlock> blocks = computeLivenessInformation();
+    timing.end();
+    timing.begin("Allocate");
     performAllocation();
+    timing.end();
     assert code.isConsistentGraph(appView);
     assert mode.is4Bit() || registersUsed() == 0 || unusedRegisters != null;
     // Even if the method is reachability sensitive, we do not compute debug information after
@@ -871,14 +879,18 @@
     assert numberOf4BitArgumentRegisters == 0 || mode.is8BitRefinement();
     ArgumentReuseMode result = mode;
     this.mode = mode;
+    timing.begin(mode.toString());
+    timing.begin("Prepare");
     if (retry) {
       clearRegisterAssignments();
       removeSpillAndPhiMoves();
     }
 
     pinArgumentRegisters();
+    timing.end();
 
     boolean succeeded = performLinearScan(mode);
+    timing.end();
     if (succeeded) {
       InsertMovesResult insertMovesResult = insertMoves();
 
@@ -1126,21 +1138,31 @@
   private boolean performLinearScan(ArgumentReuseMode mode) {
     unhandled.addAll(liveIntervals);
 
+    timing.begin("Prelude");
     processArgumentLiveIntervals();
     boolean hasInvokeRangeLiveIntervals = splitLiveIntervalsForInvokeRange();
     allocateRegistersForMoveExceptionIntervals(hasInvokeRangeLiveIntervals);
+    timing.end();
 
+    timing.begin("Argument linked");
     for (Value argumentValue = firstArgumentValue;
         argumentValue != null;
         argumentValue = argumentValue.getNextConsecutive()) {
       allocateRegistersForInvokeRangeSplits(argumentValue.getLiveIntervals());
     }
+    timing.end();
 
     // Go through each unhandled live interval and find a register for it.
+    timing.begin("Process all unhandled");
     while (!unhandled.isEmpty()) {
       assert invariantsHold(mode);
 
       LiveIntervals unhandledInterval = unhandled.poll();
+      if (unhandledInterval.isHandled()) {
+        assert unhandledInterval.hasRegister();
+        continue;
+      }
+
       setHintForDestRegOfCheckCast(unhandledInterval);
       setHintToPromote2AddrInstruction(unhandledInterval);
 
@@ -1148,21 +1170,30 @@
       // consecutive arguments now and add hints to the live intervals leading up to this
       // invoke/range. This looks forward and propagate hints backwards to avoid many moves in
       // connection with ranged invokes.
+      timing.begin("Linked");
       allocateRegistersForInvokeRangeSplits(unhandledInterval);
+      timing.end();
       if (unhandledInterval.hasRegister()) {
         // The value itself is in the chain that has now gotten registers allocated.
         continue;
       }
 
+      timing.begin("Advance state");
       advanceStateToLiveIntervals(unhandledInterval);
+      timing.end();
 
       // Perform the actual allocation.
+      timing.begin("Alloc single");
       if (!allocateSingleInterval(unhandledInterval)
           || maxRegisterNumber > mode.getMaxRegisterNumber()) {
+        timing.end();
+        timing.end();
         return false;
       }
+      timing.end();
       expiredHere.clear();
     }
+    timing.end();
     assert invariantsHold(mode);
     return true;
   }
@@ -1173,7 +1204,7 @@
         argumentValue = argumentValue.getNextConsecutive()) {
       LiveIntervals argumentInterval = argumentValue.getLiveIntervals();
       assert argumentInterval.hasRegister();
-      unhandled.remove(argumentInterval);
+      argumentInterval.setHandled();
       if (!mode.hasRegisterConstraint(argumentInterval)) {
         // All the argument intervals are active in the beginning and have preallocated registers.
         active.add(argumentInterval);
@@ -1235,7 +1266,7 @@
       if (intervals.getStart() < moveException.getNumber()) {
         intervals = intervals.splitBefore(moveException, mode);
       } else {
-        unhandled.remove(intervals);
+        intervals.setHandled();
       }
       moveExceptionIntervals.add(intervals);
       intervals.setRegister(getMoveExceptionRegister());
@@ -1454,11 +1485,15 @@
     if (!unhandledIntervals.isSplitParent()) {
       return;
     }
+    timing.begin("Extract splits");
     List<LiveIntervals> invokeRangeIntervals =
         ListUtils.filter(
             unhandledIntervals.getSplitChildren(),
             split -> split.isInvokeRangeIntervals() && !split.hasRegister());
+    timing.end();
+    timing.begin("Process splits");
     for (LiveIntervals split : invokeRangeIntervals) {
+      timing.begin("Extract list");
       Invoke invoke = split.getIsInvokeRangeIntervals();
       List<LiveIntervals> intervalsList =
           ListUtils.map(
@@ -1472,13 +1507,18 @@
                     || overlappingInvokeArgumentIntervals.getEnd() == invoke.getNumber() + 1;
                 return overlappingInvokeArgumentIntervals;
               });
+      timing.end();
+      timing.begin("Prelude");
 
       // Save the current register allocation state so we can restore it at the end.
+      timing.begin("Copy free registers");
       IntSortedSet savedFreeRegisters = new IntRBTreeSet(freeRegisters);
       int savedMaxRegisterNumber = maxRegisterNumber;
+      timing.end();
 
       // Simulate adding all the active intervals to the inactive set by blocking their register if
       // they overlap with any of the invoke/range intervals.
+      timing.begin("Overlaps active");
       for (LiveIntervals active : active) {
         // We could allow the use of all the currently active registers for the ranged invoke (by
         // adding the registers for all the active intervals to freeRegisters here). That could lead
@@ -1494,10 +1534,16 @@
           freeOccupiedRegistersForIntervals(active);
         }
       }
+      timing.end();
 
-      unhandled.removeAll(intervalsList);
+      timing.begin("Remove intervals from unhandled");
+      intervalsList.forEach(LiveIntervals::setHandled);
+      timing.end();
+      timing.end();
+      timing.begin("Allocate");
       allocateLinkedIntervals(intervalsList, invoke);
-
+      timing.end();
+      timing.begin("Postlude");
       // Restore the register allocation state.
       freeRegisters = savedFreeRegisters;
       // In case maxRegisterNumber has changed, update freeRegisters.
@@ -1506,12 +1552,15 @@
       }
       // Move all the argument intervals to the inactive set.
       inactive.addAll(intervalsList);
+      timing.end();
     }
+    timing.end();
   }
 
   private void allocateLinkedIntervals(List<LiveIntervals> intervalsList, Invoke invoke) {
     LiveIntervals start = ListUtils.first(intervalsList);
 
+    timing.begin("Prelude");
     boolean consecutiveArguments =
         IterableUtils.allWithPrevious(
             intervalsList,
@@ -1521,6 +1570,7 @@
                         == previous.getSplitParent());
     boolean consecutivePinnedArguments =
         consecutiveArguments && Iterables.all(intervalsList, this::isPinnedArgumentRegister);
+    timing.end();
 
     int nextRegister;
     if (consecutivePinnedArguments) {
@@ -1529,6 +1579,7 @@
     } else {
       // Ensure that there is a free register for the out value (or two consecutive registers if
       // wide).
+      timing.begin("Not consecutive pinned args");
       int numberOfRegisters = getNumberOfRequiredRegisters(intervalsList);
       int numberOfOutRegisters = invoke.hasOutValue() ? invoke.outValue().requiredRegisters() : 0;
       if (numberOfOutRegisters > 0
@@ -1546,20 +1597,32 @@
 
       // Exclude the registers that overlap the start of one of the live ranges we are going to
       // assign registers to now.
+      timing.begin("Overlaps inactive");
       for (LiveIntervals inactiveIntervals : inactive) {
-        if (Iterables.any(intervalsList, inactiveIntervals::overlaps)) {
+        if (inactiveIntervals.isInvokeRangeIntervals()) {
+          // These are the live intervals of another invoke-range; they can never overlap.
+          assert !Iterables.any(intervalsList, inactiveIntervals::overlaps);
+          continue;
+        }
+        // All of the invoke-range live intervals usually start at the same instruction number.
+        if (inactiveIntervals.overlapsAnyInvokeRangeIntervals(intervalsList)) {
           excludeRegistersForInterval(inactiveIntervals);
         }
       }
+      timing.end();
 
+      timing.begin("Register range is free");
       if (consecutiveArguments
           && registerRangeIsFree(start.getSplitParent().getRegister(), numberOfRegisters)) {
         // For consecutive arguments we always to use the input argument registers, if they are
         // free.
+        timing.end();
         nextRegister = start.getSplitParent().getRegister();
       } else {
+        timing.end();
         // Exclude the pinned argument registers for which there exists a split that overlaps with
         // one of the inputs to the invoke-range instruction.
+        timing.begin("Exclude pinned args");
         for (Value argument = firstArgumentValue;
             argument != null;
             argument = argument.getNextConsecutive()) {
@@ -1569,9 +1632,11 @@
             excludeRegistersForInterval(argumentLiveIntervals);
           }
         }
+        timing.end();
         // Exclude move exception register if the first interval overlaps a move exception interval.
         // It is not necessary to check the remaining consecutive intervals, since we always use
         // register 0 (after remapping) for the argument register.
+        timing.begin("Exclude move exc");
         if (hasDedicatedMoveExceptionRegister()) {
           boolean canUseMoveExceptionRegisterForLinkedIntervals =
               isDedicatedMoveExceptionRegisterInFirstLocalRegister()
@@ -1580,17 +1645,22 @@
             freeRegisters.remove(getMoveExceptionRegister());
           }
         }
+        timing.end();
+
         // Select registers.
         nextRegister = getFreeConsecutiveRegisters(numberOfRegisters);
       }
+      timing.end();
     }
 
     // Assign registers.
+    timing.begin("Assign regs");
     for (LiveIntervals current : intervalsList) {
       current.setRegister(nextRegister);
       assert verifyRegisterAssignmentNotConflictingWithArgument(current);
       nextRegister += current.requiredRegisters();
     }
+    timing.end();
   }
 
   private int getNumberOfRequiredRegisters(List<LiveIntervals> intervalsList) {
@@ -1604,10 +1674,15 @@
   // Returns true if intervals has a split, which overlaps with any of the live intervals in the
   // given list.
   private boolean liveIntervalsOverlappingAnyOf(
-      LiveIntervals intervals, List<LiveIntervals> intervalsList) {
-    assert intervals == intervals.getSplitParent();
-    for (LiveIntervals split : intervals.getSplitChildren()) {
-      if (Iterables.any(intervalsList, split::overlaps)) {
+      LiveIntervals argumentLiveIntervals, List<LiveIntervals> intervalsList) {
+    assert argumentLiveIntervals == argumentLiveIntervals.getSplitParent();
+    for (LiveIntervals intervals : intervalsList) {
+      if (intervals.getValue() == argumentLiveIntervals.getValue()) {
+        return true;
+      }
+    }
+    for (LiveIntervals split : argumentLiveIntervals.getSplitChildren()) {
+      if (split.overlapsAnyInvokeRangeIntervals(intervalsList)) {
         return true;
       }
     }
@@ -1999,9 +2074,12 @@
     assert freePositionsAreConsistentWithFreeRegisters(freePositions, registerConstraint);
 
     // Attempt to use register hints.
+    timing.begin("Try hint");
     if (useRegisterHint(unhandledInterval, registerConstraint, freePositions, needsRegisterPair)) {
+      timing.end();
       return true;
     }
+    timing.end();
 
     // Get the register (pair) that is free the longest. That is the register with the largest
     // free position.
diff --git a/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervals.java b/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervals.java
index 172b77c..f423296 100644
--- a/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervals.java
+++ b/src/main/java/com/android/tools/r8/ir/regalloc/LiveIntervals.java
@@ -37,7 +37,7 @@
   private final List<LiveIntervals> splitChildren = new ArrayList<>();
   private final IntArrayList sortedSplitChildrenEnds = new IntArrayList();
   private boolean sortedChildren = false;
-  private List<LiveRange> ranges = new ArrayList<>();
+  private ArrayList<LiveRange> ranges = new ArrayList<>();
   private final TreeSet<LiveIntervalsUse> uses = new TreeSet<>();
   private int register = NO_REGISTER;
   private int hint = NO_REGISTER;
@@ -45,6 +45,7 @@
   private Invoke isInvokeRangeIntervals = null;
   private boolean usedInMonitorOperations = false;
   private boolean liveAtMoveExceptionEntry = false;
+  private boolean handled = false;
 
   // Only registers up to and including the registerLimit are allowed for this interval.
   private int registerLimit = U16BIT_MAX;
@@ -111,6 +112,15 @@
     return hint;
   }
 
+  public boolean isHandled() {
+    return handled;
+  }
+
+  // Flags this as handled; the scan loop then skips it on poll instead of an O(n) queue removal.
+  public void setHandled() {
+    handled = true;
+  }
+
   public void setSpilled(boolean value) {
     // Check that we always spill arguments to their original register.
     assert getRegister() != NO_REGISTER;
@@ -377,6 +387,9 @@
   }
 
   public boolean overlapsPosition(int position) {
+    if (position < getStart() || position >= getEnd()) {
+      return false;
+    }
     for (LiveRange range : ranges) {
       if (range.start > position) {
         // Ranges are sorted. When a range starts after position there is no overlap.
@@ -393,6 +406,25 @@
     return nextOverlap(other) != -1;
   }
 
+  // Returns true if this overlaps an entry of intervalsList. Every entry whose value is defined
+  // by a move is tested; among the remaining entries only the first one is tested.
+  // NOTE(review): presumably the skipped entries span the same positions as the first; confirm.
+  public boolean overlapsAnyInvokeRangeIntervals(List<LiveIntervals> intervalsList) {
+    boolean testedNonMove = false;
+    for (LiveIntervals candidate : intervalsList) {
+      if (!candidate.getValue().isDefinedByInstructionSatisfying(Instruction::isMove)) {
+        if (testedNonMove) {
+          continue;
+        }
+        testedNonMove = true;
+      }
+      if (overlaps(candidate)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
   public boolean anySplitOverlaps(LiveIntervals other) {
     LiveIntervals parent = getSplitParent();
     if (parent.overlaps(other)) {
@@ -491,8 +523,8 @@
     LiveIntervals splitChild = new LiveIntervals(splitParent);
     splitParent.splitChildren.add(splitChild);
     splitParent.sortedChildren = splitParent.splitChildren.size() == 1;
-    List<LiveRange> beforeSplit = new ArrayList<>();
-    List<LiveRange> afterSplit = new ArrayList<>();
+    ArrayList<LiveRange> beforeSplit = new ArrayList<>();
+    ArrayList<LiveRange> afterSplit = new ArrayList<>();
     if (start == getEnd()) {
       beforeSplit = ranges;
       afterSplit.add(new LiveRange(start, start));
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/ConstantRemovalTest.java b/src/test/java/com/android/tools/r8/ir/optimize/ConstantRemovalTest.java
index adc7fb7..850cc4b 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/ConstantRemovalTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/ConstantRemovalTest.java
@@ -29,6 +29,7 @@
 import com.android.tools.r8.ir.regalloc.LiveIntervals;
 import com.android.tools.r8.synthesis.SyntheticItems.GlobalSyntheticsStrategy;
 import com.android.tools.r8.utils.InternalOptions;
+import com.android.tools.r8.utils.Timing;
 import java.util.LinkedList;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -49,7 +50,7 @@
 
   private static class MockLinearScanRegisterAllocator extends LinearScanRegisterAllocator {
     MockLinearScanRegisterAllocator(AppView<?> appView, IRCode code) {
-      super(appView, code);
+      super(appView, code, Timing.empty());
     }
 
     @Override
diff --git a/src/test/java/com/android/tools/r8/ir/regalloc/Regress68656641.java b/src/test/java/com/android/tools/r8/ir/regalloc/Regress68656641.java
index fd91973..3394057 100644
--- a/src/test/java/com/android/tools/r8/ir/regalloc/Regress68656641.java
+++ b/src/test/java/com/android/tools/r8/ir/regalloc/Regress68656641.java
@@ -16,6 +16,7 @@
 import com.android.tools.r8.synthesis.SyntheticItems.GlobalSyntheticsStrategy;
 import com.android.tools.r8.utils.AndroidApp;
 import com.android.tools.r8.utils.InternalOptions;
+import com.android.tools.r8.utils.Timing;
 import com.android.tools.r8.utils.codeinspector.MethodSubject;
 import com.google.common.collect.ImmutableList;
 import java.util.PriorityQueue;
@@ -25,7 +26,7 @@
 
   private static class MyRegisterAllocator extends LinearScanRegisterAllocator {
     public MyRegisterAllocator(AppView<?> appView, IRCode code) {
-      super(appView, code);
+      super(appView, code, Timing.empty());
     }
 
     public void addInactiveIntervals(LiveIntervals intervals) {