Use fastutil collections in register allocator
Fixes: b/374702812
Change-Id: Ie1277570c4f0ac228571b9d3f3d154f3167ca935
diff --git a/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java b/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
index e40f0d7..7216583 100644
--- a/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
+++ b/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
@@ -43,6 +43,7 @@
import com.android.tools.r8.ir.regalloc.RegisterPositions.RegisterType;
import com.android.tools.r8.utils.ArrayUtils;
import com.android.tools.r8.utils.BooleanUtils;
+import com.android.tools.r8.utils.IntObjPredicate;
import com.android.tools.r8.utils.InternalOptions;
import com.android.tools.r8.utils.IterableUtils;
import com.android.tools.r8.utils.LinkedHashSetUtils;
@@ -63,7 +64,9 @@
import it.unimi.dsi.fastutil.ints.IntIterator;
import it.unimi.dsi.fastutil.ints.IntList;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
+import it.unimi.dsi.fastutil.ints.IntRBTreeSet;
import it.unimi.dsi.fastutil.ints.IntSet;
+import it.unimi.dsi.fastutil.ints.IntSortedSet;
import it.unimi.dsi.fastutil.objects.Reference2IntArrayMap;
import it.unimi.dsi.fastutil.objects.Reference2IntMap;
import it.unimi.dsi.fastutil.objects.Reference2IntOpenHashMap;
@@ -80,8 +83,6 @@
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
-import java.util.TreeSet;
-import java.util.function.BiPredicate;
import java.util.function.Predicate;
/**
@@ -204,7 +205,7 @@
// The current register allocation mode.
private ArgumentReuseMode mode;
// The set of registers that are free for allocation.
- private TreeSet<Integer> freeRegisters = new TreeSet<>();
+ private IntSortedSet freeRegisters = new IntRBTreeSet();
// The max register number used.
private int maxRegisterNumber = -1;
@@ -1309,7 +1310,7 @@
}
private boolean invariantsHold(ArgumentReuseMode mode) {
- TreeSet<Integer> computedFreeRegisters = new TreeSet<>();
+ IntSortedSet computedFreeRegisters = new IntRBTreeSet();
for (int register = 0; register <= maxRegisterNumber; ++register) {
computedFreeRegisters.add(register);
}
@@ -1459,7 +1460,7 @@
});
// Save the current register allocation state so we can restore it at the end.
- TreeSet<Integer> savedFreeRegisters = new TreeSet<>(freeRegisters);
+ IntSortedSet savedFreeRegisters = new IntRBTreeSet(freeRegisters);
int savedMaxRegisterNumber = maxRegisterNumber;
// Simulate adding all the active intervals to the inactive set by blocking their register if
@@ -1634,7 +1635,7 @@
return intervals.getSplitParent().getRegister();
}
- TreeSet<Integer> previousFreeRegisters = new TreeSet<>(freeRegisters);
+ IntSortedSet previousFreeRegisters = new IntRBTreeSet(freeRegisters);
int previousMaxRegisterNumber = maxRegisterNumber;
freeRegisters.removeAll(expiredHere);
if (excludedRegisters != null) {
@@ -1793,7 +1794,7 @@
// Is the array-get array register the same as the first register we are
// allocating for the result?
- private boolean isArrayGetArrayRegister(LiveIntervals intervals, int register) {
+ private boolean isArrayGetArrayRegister(int register, LiveIntervals intervals) {
assert needsArrayGetWideWorkaround(intervals);
Value array = intervals.getValue().definition.asArrayGet().array();
int arrayReg =
@@ -1833,7 +1834,7 @@
// Is one of the cmp-long argument registers the same as the register we are
// allocating for the result?
- private boolean isSingleResultOverlappingLongOperands(LiveIntervals intervals, int register) {
+ private boolean isSingleResultOverlappingLongOperands(int register, LiveIntervals intervals) {
assert needsSingleResultOverlappingLongOperandsWorkaround(intervals);
if (intervals.getValue().definition.isCmp()) {
Value left = intervals.getValue().definition.asCmp().leftValue();
@@ -1899,7 +1900,7 @@
}
private boolean isLongResultOverlappingLongOperands(
- LiveIntervals unhandledInterval, int register) {
+ int register, LiveIntervals unhandledInterval) {
assert needsLongResultOverlappingLongOperandsWorkaround(unhandledInterval);
Value left = unhandledInterval.getValue().definition.asBinop().leftValue();
Value right = unhandledInterval.getValue().definition.asBinop().rightValue();
@@ -2303,12 +2304,12 @@
}
// Check for overlapping long registers issue.
if (needsLongResultOverlappingLongOperandsWorkaround(unhandledInterval)
- && isLongResultOverlappingLongOperands(unhandledInterval, register)) {
+ && isLongResultOverlappingLongOperands(register, unhandledInterval)) {
return false;
}
// Check for aget-wide bug in recent Art VMs.
if (needsArrayGetWideWorkaround(unhandledInterval)
- && isArrayGetArrayRegister(unhandledInterval, register)) {
+ && isArrayGetArrayRegister(register, unhandledInterval)) {
return false;
}
assignFreeRegisterToUnhandledInterval(unhandledInterval, register);
@@ -2522,7 +2523,7 @@
private int handleWorkaround(
Predicate<LiveIntervals> workaroundNeeded,
- BiPredicate<LiveIntervals, Integer> workaroundNeededForCandidate,
+ IntObjPredicate<LiveIntervals> workaroundNeededForCandidate,
int candidate,
LiveIntervals unhandledInterval,
int registerConstraint,
@@ -2531,7 +2532,7 @@
RegisterType type) {
if (workaroundNeeded.test(unhandledInterval)) {
int lastCandidate = candidate;
- while (workaroundNeededForCandidate.test(unhandledInterval, candidate)) {
+ while (workaroundNeededForCandidate.test(candidate, unhandledInterval)) {
// Make the unusable register unavailable for allocation and try again.
freePositions.setBlockedTemporarily(candidate);
candidate =
@@ -3602,10 +3603,10 @@
private int getFreeConsecutiveRegisters(int numberOfRegisters, boolean prioritizeSmallRegisters) {
int oldMaxRegisterNumber = maxRegisterNumber;
- TreeSet<Integer> freeRegistersWithDesiredOrdering = freeRegisters;
+ IntSortedSet freeRegistersWithDesiredOrdering = freeRegisters;
if (prioritizeSmallRegisters) {
freeRegistersWithDesiredOrdering =
- new TreeSet<>(
+ new IntRBTreeSet(
(Integer x, Integer y) -> {
boolean xIsArgument = x < numberOfArgumentRegisters;
boolean yIsArgument = y < numberOfArgumentRegisters;
@@ -3623,7 +3624,7 @@
freeRegistersWithDesiredOrdering.addAll(freeRegisters);
}
- Iterator<Integer> freeRegistersIterator = freeRegistersWithDesiredOrdering.iterator();
+ IntIterator freeRegistersIterator = freeRegistersWithDesiredOrdering.iterator();
int first = getNextFreeRegister(freeRegistersIterator);
int current = first;
while (current - first + 1 != numberOfRegisters) {
@@ -3668,9 +3669,9 @@
return true;
}
- private int getNextFreeRegister(Iterator<Integer> freeRegistersIterator) {
+ private int getNextFreeRegister(IntIterator freeRegistersIterator) {
if (freeRegistersIterator.hasNext()) {
- return freeRegistersIterator.next();
+ return freeRegistersIterator.nextInt();
}
return ++maxRegisterNumber;
}