Split try-catch ranges in case of u2 overflow
Bug: b/297320921
Change-Id: I46c16ec116ba4ba70d309ebb34bc9ae4a75740bc
diff --git a/src/main/java/com/android/tools/r8/dex/JumboStringRewriter.java b/src/main/java/com/android/tools/r8/dex/JumboStringRewriter.java
index 785158d..8c11532 100644
--- a/src/main/java/com/android/tools/r8/dex/JumboStringRewriter.java
+++ b/src/main/java/com/android/tools/r8/dex/JumboStringRewriter.java
@@ -29,6 +29,7 @@
import com.android.tools.r8.dex.code.DexInstruction;
import com.android.tools.r8.dex.code.DexNop;
import com.android.tools.r8.dex.code.DexSwitchPayload;
+import com.android.tools.r8.errors.Unreachable;
import com.android.tools.r8.graph.DexCode;
import com.android.tools.r8.graph.DexCode.Try;
import com.android.tools.r8.graph.DexCode.TryHandler;
@@ -41,6 +42,7 @@
import com.android.tools.r8.graph.DexEncodedMethod;
import com.android.tools.r8.graph.DexItemFactory;
import com.android.tools.r8.graph.DexString;
+import com.android.tools.r8.lightir.ByteUtils;
import com.google.common.collect.Lists;
import it.unimi.dsi.fastutil.ints.Int2ReferenceMap;
import it.unimi.dsi.fastutil.ints.Int2ReferenceOpenHashMap;
@@ -190,12 +192,56 @@
for (int i = 0; i < code.tries.length; i++) {
Try theTry = code.tries[i];
TryTargets targets = tryTargets.get(theTry);
- result[i] = new Try(targets.getStartOffset(), targets.getStartToEndDelta(), -1);
+ int startToEndDelta = targets.getStartToEndDelta();
+ if (startToEndDelta > ByteUtils.MAX_U2) {
+ return rewriteSplitTryOffsets(code);
+ }
+ result[i] = new Try(targets.getStartOffset(), startToEndDelta, -1);
result[i].handlerIndex = theTry.handlerIndex;
}
return result;
}
+ // Note: this algorithm should be aligned with DexBuilder.splitOverflowingRanges.
+ private Try[] rewriteSplitTryOffsets(DexCode code) {
+ // It is unlikely we have 10 overflows (unlikely we have any to begin with).
+ int tentativeCapacity = code.tries.length + 10;
+ List<Try> result = new ArrayList<>(tentativeCapacity);
+ for (Try theTry : code.tries) {
+ TryTargets targets = tryTargets.get(theTry);
+ int startToEndDelta = targets.getStartToEndDelta();
+ int start = targets.getStartOffset();
+ while (startToEndDelta > ByteUtils.MAX_U2) {
+ // Find instruction offset under limit.
+ int maxOffset = start + ByteUtils.MAX_U2;
+ int intermediateEnd = -1;
+ for (int i = code.instructions.length - 1; i >= 0; i--) {
+ DexInstruction instruction = code.instructions[i];
+ // Note that the instructions have been expanded, so getOffset is the rewritten offset.
+ if (instruction.getOffset() <= maxOffset) {
+ intermediateEnd = instruction.getOffset();
+ break;
+ }
+ }
+ if (intermediateEnd <= start) {
+ throw new Unreachable("Unexpected try-catch handler end point: " + intermediateEnd);
+ }
+ int intermediateDelta = intermediateEnd - start;
+ Try splitTry = new Try(start, intermediateDelta, -1);
+ splitTry.handlerIndex = theTry.handlerIndex;
+ result.add(splitTry);
+ start = intermediateEnd;
+ startToEndDelta -= intermediateDelta;
+ }
+ assert startToEndDelta > 0;
+ Try rewrittenTry = new Try(start, startToEndDelta, -1);
+ rewrittenTry.handlerIndex = theTry.handlerIndex;
+ result.add(rewrittenTry);
+ }
+ assert result.size() > code.tries.length;
+ return result.toArray(Try.EMPTY_ARRAY);
+ }
+
private TryHandler[] rewriteHandlerOffsets() {
DexCode code = method.getCode().asDexCode();
TryHandler[] result = new TryHandler[code.handlers.length];
diff --git a/src/main/java/com/android/tools/r8/graph/DexCode.java b/src/main/java/com/android/tools/r8/graph/DexCode.java
index 1869581..1904080 100644
--- a/src/main/java/com/android/tools/r8/graph/DexCode.java
+++ b/src/main/java/com/android/tools/r8/graph/DexCode.java
@@ -30,6 +30,7 @@
import com.android.tools.r8.ir.conversion.MethodConversionOptions.MutableMethodConversionOptions;
import com.android.tools.r8.ir.conversion.MethodConversionOptions.ThrowingMethodConversionOptions;
import com.android.tools.r8.naming.ClassNameMapper;
+import com.android.tools.r8.lightir.ByteUtils;
import com.android.tools.r8.origin.Origin;
import com.android.tools.r8.utils.ArrayUtils;
import com.android.tools.r8.utils.StringUtils;
@@ -782,6 +783,7 @@
this.instructionCount = instructionCount;
this.handlerOffset = handlerOffset;
this.handlerIndex = NO_INDEX;
+ assert ByteUtils.isU2(instructionCount);
}
@Override
diff --git a/src/main/java/com/android/tools/r8/ir/code/Value.java b/src/main/java/com/android/tools/r8/ir/code/Value.java
index b0eb2bd..0f0e749 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Value.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Value.java
@@ -708,7 +708,10 @@
public boolean needsRegister() {
assert needsRegister >= 0;
- assert !hasUsersInfo() || (needsRegister > 0) == internalComputeNeedsRegister();
+ // This has quadratic behavior so don't check for large user sets.
+ assert !hasUsersInfo()
+ || numberOfAllUsers() > 100
+ || (needsRegister > 0) == internalComputeNeedsRegister();
return needsRegister > 0;
}
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/DexBuilder.java b/src/main/java/com/android/tools/r8/ir/conversion/DexBuilder.java
index ec93ef2..9c648a7 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/DexBuilder.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/DexBuilder.java
@@ -63,6 +63,7 @@
import com.android.tools.r8.ir.code.Value;
import com.android.tools.r8.ir.optimize.CodeRewriter;
import com.android.tools.r8.ir.regalloc.RegisterAllocator;
+import com.android.tools.r8.lightir.ByteUtils;
import com.android.tools.r8.utils.InternalOptions;
import com.android.tools.r8.utils.InternalOutputMode;
import com.google.common.collect.BiMap;
@@ -328,7 +329,7 @@
}
// Construct try-catch info.
- TryInfo tryInfo = computeTryInfo();
+ TryInfo tryInfo = computeTryInfo(dexInstructions);
// Return the dex code.
DexCode code =
@@ -814,11 +815,13 @@
// Helpers for computing the try items and handlers.
- private TryInfo computeTryInfo() {
+ private TryInfo computeTryInfo(List<DexInstruction> dexInstructions) {
// Canonical map of handlers.
BiMap<CatchHandlers<BasicBlock>, Integer> canonicalHandlers = HashBiMap.create();
// Compute the list of try items and their handlers.
List<TryItem> tryItems = computeTryItems(canonicalHandlers);
+ // Split the try items if they overflow the range limit.
+ tryItems = splitOverflowingRanges(tryItems, dexInstructions);
// Compute handler sets before dex items which depend on the handler index.
Try[] tries = getDexTryItems(tryItems, canonicalHandlers);
TryHandler[] handlers = getDexTryHandlers(canonicalHandlers.inverse());
@@ -905,6 +908,65 @@
return coalescedTryItems;
}
+ private static int numberOfOverflowingRanges(List<TryItem> tryItems) {
+ int numberOfOverflows = 0;
+ for (TryItem tryItem : tryItems) {
+ int instructionCount = tryItem.end - tryItem.start;
+ while (instructionCount > ByteUtils.MAX_U2) {
+ ++numberOfOverflows;
+ instructionCount -= ByteUtils.MAX_U2;
+ }
+ }
+ return numberOfOverflows;
+ }
+
+ // Note: this algorithm should be aligned with JumboStringRewriter.rewriteSplitTryOffsets.
+ private List<TryItem> splitOverflowingRanges(
+ List<TryItem> tryItems, List<DexInstruction> dexInstructions) {
+ // The fast path is that there will not be any overflows.
+ int overflows = numberOfOverflowingRanges(tryItems);
+ if (overflows == 0) {
+ return tryItems;
+ }
+ // The overflow may not fall on an instruction header, so we add a single entry just in case.
+ // Multiple try items overflowing is unlikely so that just causes reallocating the backing.
+ int tentativeCapacity = tryItems.size() + overflows + 1;
+ ArrayList<TryItem> splitTryItems = new ArrayList<>(tentativeCapacity);
+ for (TryItem tryItem : tryItems) {
+ if (tryItem.end - tryItem.start <= ByteUtils.MAX_U2) {
+ splitTryItems.add(tryItem);
+ continue;
+ }
+ final CatchHandlers<BasicBlock> handlers = tryItem.handlers;
+ final int end = tryItem.end;
+ // The iteration is based on the start offset advancing on each split.
+ int start = tryItem.start;
+ while (end - start > ByteUtils.MAX_U2) {
+ // Find a new end that does not overflow the U2 limit on the delta.
+ // It must be on an instruction offset so scan backwards in the block to find one.
+ int maxOffset = start + ByteUtils.MAX_U2;
+ assert maxOffset < end;
+ int intermediateEnd = -1;
+ for (int i = dexInstructions.size() - 1; i >= 0; i--) {
+ DexInstruction instruction = dexInstructions.get(i);
+ if (instruction.getOffset() <= maxOffset) {
+ intermediateEnd = instruction.getOffset();
+ break;
+ }
+ }
+ if (intermediateEnd <= start) {
+ throw new Unreachable("Unexpected try-catch handler end point: " + intermediateEnd);
+ }
+ splitTryItems.add(new TryItem(handlers, start, intermediateEnd));
+ start = intermediateEnd;
+ }
+ assert start < end;
+ splitTryItems.add(new TryItem(handlers, start, end));
+ }
+ assert splitTryItems.size() >= tryItems.size() + overflows;
+ return splitTryItems;
+ }
+
private int trimEnd(BasicBlock block) {
// Trim the range end for non-throwing instructions when end has been computed.
List<com.android.tools.r8.ir.code.Instruction> instructions = block.getInstructions();
diff --git a/src/main/java/com/android/tools/r8/lightir/ByteUtils.java b/src/main/java/com/android/tools/r8/lightir/ByteUtils.java
index 5c51959..ad1169d 100644
--- a/src/main/java/com/android/tools/r8/lightir/ByteUtils.java
+++ b/src/main/java/com/android/tools/r8/lightir/ByteUtils.java
@@ -8,6 +8,9 @@
/** Simple utilities for byte encodings. */
public class ByteUtils {
+ public static final int MAX_U1 = 0xFF;
+ public static final int MAX_U2 = 0xFFFF;
+
public static boolean isU1(int value) {
return (0 <= value) && (value <= 0xFF);
}
diff --git a/src/test/java/com/android/tools/r8/dex/TryCatchRangeOverflowTest.java b/src/test/java/com/android/tools/r8/dex/TryCatchRangeOverflowTest.java
new file mode 100644
index 0000000..fb5f1b1
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/dex/TryCatchRangeOverflowTest.java
@@ -0,0 +1,195 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.dex;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+import com.android.tools.r8.D8TestBuilder;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.graph.DebugLocalInfo;
+import com.android.tools.r8.graph.DexCode.Try;
+import com.android.tools.r8.ir.analysis.type.TypeElement;
+import com.android.tools.r8.ir.code.Add;
+import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.Instruction;
+import com.android.tools.r8.ir.code.InstructionListIterator;
+import com.android.tools.r8.ir.code.NumericType;
+import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.utils.AndroidApiLevel;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import java.util.Arrays;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+// Regression test for b/297320921
+@RunWith(Parameterized.class)
+public class TryCatchRangeOverflowTest extends TestBase {
+
+ private final TestParameters parameters;
+
+ @Parameterized.Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withDefaultDexRuntime().withApiLevel(AndroidApiLevel.B).build();
+ }
+
+ public TryCatchRangeOverflowTest(TestParameters parameters) {
+ this.parameters = parameters;
+ }
+
+ // Each add/2addr instruction has size 1, so we add have as many instruction minus some padding
+ // to make room for the instructions before and after but still in the same block.
+ // Notice that this value may change if the generated code by the compiler changes. It must then
+ // be updated to the precise limit again so that the test for jumbo-string exactly hits the
+ // crossing point.
+ private final int PADDING = 33;
+ private final int UNSPLIT_LIMIT = 0xFFFF - PADDING;
+ private final int SPLIT_2_LIMIT = 0xFFFF * 2 - PADDING;
+
+ @Test
+ public void testWithinU2() throws Exception {
+ parameters.assumeDexRuntime();
+ int addCount = UNSPLIT_LIMIT;
+ compile(addCount)
+ .run(parameters.getRuntime(), TestClass.class)
+ .assertSuccessWithOutputLines("" + addCount)
+ .inspect(inspector -> checkTryCatchHandlers(1, inspector));
+ }
+
+ @Test
+ public void testJumboExceedsU2() throws Exception {
+ parameters.assumeDexRuntime();
+ int addCount = UNSPLIT_LIMIT;
+ compile(addCount)
+ .addOptionsModification(o -> o.testing.forceJumboStringProcessing = true)
+ .run(parameters.getRuntime(), TestClass.class)
+ .assertSuccessWithOutputLines("" + addCount)
+ .inspect(inspector -> checkTryCatchHandlers(2, inspector));
+ }
+
+ @Test
+ public void testExceedsU2() throws Exception {
+ parameters.assumeDexRuntime();
+ // Test with a few values above the limit.
+ for (int addCount : Arrays.asList(UNSPLIT_LIMIT + 1, UNSPLIT_LIMIT + 2, UNSPLIT_LIMIT + 100)) {
+ compile(addCount)
+ .run(parameters.getRuntime(), TestClass.class)
+ .assertSuccessWithOutputLines("" + addCount)
+ .inspect(inspector -> checkTryCatchHandlers(2, inspector));
+ }
+ }
+
+ @Test
+ public void testWithinU2x2() throws Exception {
+ parameters.assumeDexRuntime();
+ int addCount = SPLIT_2_LIMIT;
+ compile(addCount)
+ .run(parameters.getRuntime(), TestClass.class)
+ .assertSuccessWithOutputLines("" + addCount)
+ .inspect(inspector -> checkTryCatchHandlers(2, inspector));
+ }
+
+ @Test
+ public void testJumboExceedsU2x2() throws Exception {
+ parameters.assumeDexRuntime();
+ int addCount = SPLIT_2_LIMIT;
+ compile(addCount)
+ .addOptionsModification(o -> o.testing.forceJumboStringProcessing = true)
+ .run(parameters.getRuntime(), TestClass.class)
+ .assertSuccessWithOutputLines("" + addCount)
+ .inspect(inspector -> checkTryCatchHandlers(3, inspector));
+ }
+
+ @Test
+ public void testExceedsU2x2() throws Exception {
+ parameters.assumeDexRuntime();
+ int addCount = SPLIT_2_LIMIT + 1;
+ compile(addCount)
+ .run(parameters.getRuntime(), TestClass.class)
+ .assertSuccessWithOutputLines("" + addCount)
+ .inspect(inspector -> checkTryCatchHandlers(3, inspector));
+ }
+
+ private D8TestBuilder compile(int addCount) throws Exception {
+ return testForD8(Backend.DEX)
+ .addProgramClasses(TestClass.class)
+ .addOptionsModification(
+ o ->
+ o.testing.irModifier =
+ (code, appView) -> amendCodeWithAddInstructions(addCount, code))
+ .setMinApi(parameters);
+ }
+
+ private static void amendCodeWithAddInstructions(int addCount, IRCode code) {
+ if (!code.context().getReference().qualifiedName().endsWith("main")) {
+ return;
+ }
+ InstructionListIterator it = code.instructionListIterator();
+ while (it.hasNext()) {
+ Instruction instruction = it.next();
+ if (instruction.isAdd()) {
+ TypeElement outType = instruction.getOutType();
+ DebugLocalInfo localInfo = instruction.getLocalInfo();
+ // Create the last value which will replace the users of the original value in the
+ // continuations.
+ Value newLastValue = code.createValue(outType, localInfo);
+ instruction.outValue().replaceUsers(newLastValue);
+
+ Add add = instruction.asAdd();
+ NumericType numericType = add.getNumericType();
+ assert add.rightValue().isConstNumber();
+ for (int i = 1; i < addCount; i++) {
+ Value dest = i == addCount - 1 ? newLastValue : code.createValue(outType, localInfo);
+ Add newAdd = new Add(numericType, dest, add.outValue(), add.rightValue());
+ add.outValue().addDebugLocalEnd(newAdd);
+ newAdd.setPosition(add.getPosition());
+ it.add(newAdd);
+ add = newAdd;
+ }
+ return;
+ }
+ }
+ fail("Expected to find an Add instruction.");
+ }
+
+ private static void checkTryCatchHandlers(int tryCount, CodeInspector inspector)
+ throws NoSuchMethodException {
+
+ MethodSubject main = inspector.method(TestClass.class.getMethod("main", String[].class));
+ Try[] tries = main.getMethod().getCode().asDexCode().tries;
+ assertEquals(Arrays.toString(tries), tryCount, tries.length);
+ }
+
+ static class TestClass {
+
+ public static void main(String[] args) {
+ int i = 0;
+ try {
+ String str;
+ int len = args.length;
+ if (len == 0) {
+ str = "";
+ } else if (len == 1 /* Using a constant 1 here causes the add to be an add/2addr */) {
+ str = "Strings might become jumbos";
+ } else if (len % 2 == 0) {
+ str = "We need 4";
+ } else {
+ str = "to ensure overflow.";
+ }
+ i = str.length();
+ ++i; // repeated count number of times.
+ i += args[0].length();
+ } catch (Throwable e) {
+ System.out.println(i);
+ return;
+ }
+ System.out.println("unexpected i " + i);
+ }
+ }
+}