Split try-catch ranges in case of u2 overflow

Bug: b/297320921
Change-Id: I46c16ec116ba4ba70d309ebb34bc9ae4a75740bc
diff --git a/src/main/java/com/android/tools/r8/dex/JumboStringRewriter.java b/src/main/java/com/android/tools/r8/dex/JumboStringRewriter.java
index 785158d..8c11532 100644
--- a/src/main/java/com/android/tools/r8/dex/JumboStringRewriter.java
+++ b/src/main/java/com/android/tools/r8/dex/JumboStringRewriter.java
@@ -29,6 +29,7 @@
 import com.android.tools.r8.dex.code.DexInstruction;
 import com.android.tools.r8.dex.code.DexNop;
 import com.android.tools.r8.dex.code.DexSwitchPayload;
+import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.DexCode;
 import com.android.tools.r8.graph.DexCode.Try;
 import com.android.tools.r8.graph.DexCode.TryHandler;
@@ -41,6 +42,7 @@
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexString;
+import com.android.tools.r8.lightir.ByteUtils;
 import com.google.common.collect.Lists;
 import it.unimi.dsi.fastutil.ints.Int2ReferenceMap;
 import it.unimi.dsi.fastutil.ints.Int2ReferenceOpenHashMap;
@@ -190,12 +192,56 @@
     for (int i = 0; i < code.tries.length; i++) {
       Try theTry = code.tries[i];
       TryTargets targets = tryTargets.get(theTry);
-      result[i] = new Try(targets.getStartOffset(), targets.getStartToEndDelta(), -1);
+      int startToEndDelta = targets.getStartToEndDelta();
+      if (startToEndDelta > ByteUtils.MAX_U2) {
+        return rewriteSplitTryOffsets(code);
+      }
+      result[i] = new Try(targets.getStartOffset(), startToEndDelta, -1);
       result[i].handlerIndex = theTry.handlerIndex;
     }
     return result;
   }
 
+  // Note: this algorithm should be aligned with DexBuilder.splitOverflowingRanges.
+  private Try[] rewriteSplitTryOffsets(DexCode code) {
+    // It is unlikely we have 10 overflows (unlikely we have any to begin with).
+    int tentativeCapacity = code.tries.length + 10;
+    List<Try> result = new ArrayList<>(tentativeCapacity);
+    for (Try theTry : code.tries) {
+      TryTargets targets = tryTargets.get(theTry);
+      int startToEndDelta = targets.getStartToEndDelta();
+      int start = targets.getStartOffset();
+      while (startToEndDelta > ByteUtils.MAX_U2) {
+        // Find instruction offset under limit.
+        int maxOffset = start + ByteUtils.MAX_U2;
+        int intermediateEnd = -1;
+        for (int i = code.instructions.length - 1; i >= 0; i--) {
+          DexInstruction instruction = code.instructions[i];
+          // Note that the instructions have been expanded, so getOffset is the rewritten offset.
+          if (instruction.getOffset() <= maxOffset) {
+            intermediateEnd = instruction.getOffset();
+            break;
+          }
+        }
+        if (intermediateEnd <= start) {
+          throw new Unreachable("Unexpected try-catch handler end point: " + intermediateEnd);
+        }
+        int intermediateDelta = intermediateEnd - start;
+        Try splitTry = new Try(start, intermediateDelta, -1);
+        splitTry.handlerIndex = theTry.handlerIndex;
+        result.add(splitTry);
+        start = intermediateEnd;
+        startToEndDelta -= intermediateDelta;
+      }
+      assert startToEndDelta > 0;
+      Try rewrittenTry = new Try(start, startToEndDelta, -1);
+      rewrittenTry.handlerIndex = theTry.handlerIndex;
+      result.add(rewrittenTry);
+    }
+    assert result.size() > code.tries.length;
+    return result.toArray(Try.EMPTY_ARRAY);
+  }
+
   private TryHandler[] rewriteHandlerOffsets() {
     DexCode code = method.getCode().asDexCode();
     TryHandler[] result = new TryHandler[code.handlers.length];
diff --git a/src/main/java/com/android/tools/r8/graph/DexCode.java b/src/main/java/com/android/tools/r8/graph/DexCode.java
index 1869581..1904080 100644
--- a/src/main/java/com/android/tools/r8/graph/DexCode.java
+++ b/src/main/java/com/android/tools/r8/graph/DexCode.java
@@ -30,6 +30,7 @@
 import com.android.tools.r8.ir.conversion.MethodConversionOptions.MutableMethodConversionOptions;
 import com.android.tools.r8.ir.conversion.MethodConversionOptions.ThrowingMethodConversionOptions;
 import com.android.tools.r8.naming.ClassNameMapper;
+import com.android.tools.r8.lightir.ByteUtils;
 import com.android.tools.r8.origin.Origin;
 import com.android.tools.r8.utils.ArrayUtils;
 import com.android.tools.r8.utils.StringUtils;
@@ -782,6 +783,7 @@
       this.instructionCount = instructionCount;
       this.handlerOffset = handlerOffset;
       this.handlerIndex = NO_INDEX;
+      assert ByteUtils.isU2(instructionCount);
     }
 
     @Override
diff --git a/src/main/java/com/android/tools/r8/ir/code/Value.java b/src/main/java/com/android/tools/r8/ir/code/Value.java
index b0eb2bd..0f0e749 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Value.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Value.java
@@ -708,7 +708,10 @@
 
   public boolean needsRegister() {
     assert needsRegister >= 0;
-    assert !hasUsersInfo() || (needsRegister > 0) == internalComputeNeedsRegister();
+    // This has quadratic behavior so don't check for large user sets.
+    assert !hasUsersInfo()
+        || numberOfAllUsers() > 100
+        || (needsRegister > 0) == internalComputeNeedsRegister();
     return needsRegister > 0;
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/DexBuilder.java b/src/main/java/com/android/tools/r8/ir/conversion/DexBuilder.java
index ec93ef2..9c648a7 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/DexBuilder.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/DexBuilder.java
@@ -63,6 +63,7 @@
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.optimize.CodeRewriter;
 import com.android.tools.r8.ir.regalloc.RegisterAllocator;
+import com.android.tools.r8.lightir.ByteUtils;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.InternalOutputMode;
 import com.google.common.collect.BiMap;
@@ -328,7 +329,7 @@
     }
 
     // Construct try-catch info.
-    TryInfo tryInfo = computeTryInfo();
+    TryInfo tryInfo = computeTryInfo(dexInstructions);
 
     // Return the dex code.
     DexCode code =
@@ -814,11 +815,13 @@
 
   // Helpers for computing the try items and handlers.
 
-  private TryInfo computeTryInfo() {
+  private TryInfo computeTryInfo(List<DexInstruction> dexInstructions) {
     // Canonical map of handlers.
     BiMap<CatchHandlers<BasicBlock>, Integer> canonicalHandlers = HashBiMap.create();
     // Compute the list of try items and their handlers.
     List<TryItem> tryItems = computeTryItems(canonicalHandlers);
+    // Split the try items if they overflow the range limit.
+    tryItems = splitOverflowingRanges(tryItems, dexInstructions);
     // Compute handler sets before dex items which depend on the handler index.
     Try[] tries = getDexTryItems(tryItems, canonicalHandlers);
     TryHandler[] handlers = getDexTryHandlers(canonicalHandlers.inverse());
@@ -905,6 +908,65 @@
     return coalescedTryItems;
   }
 
+  private static int numberOfOverflowingRanges(List<TryItem> tryItems) {
+    int numberOfOverflows = 0;
+    for (TryItem tryItem : tryItems) {
+      int instructionCount = tryItem.end - tryItem.start;
+      while (instructionCount > ByteUtils.MAX_U2) {
+        ++numberOfOverflows;
+        instructionCount -= ByteUtils.MAX_U2;
+      }
+    }
+    return numberOfOverflows;
+  }
+
+  // Note: this algorithm should be aligned with JumboStringRewriter.rewriteSplitTryOffsets.
+  private List<TryItem> splitOverflowingRanges(
+      List<TryItem> tryItems, List<DexInstruction> dexInstructions) {
+    // The fast path is that there will not be any overflows.
+    int overflows = numberOfOverflowingRanges(tryItems);
+    if (overflows == 0) {
+      return tryItems;
+    }
+    // The overflow may not fall on an instruction header, so we add a single entry just in case.
+    // Multiple try items overflowing is unlikely so that just causes reallocating the backing.
+    int tentativeCapacity = tryItems.size() + overflows + 1;
+    ArrayList<TryItem> splitTryItems = new ArrayList<>(tentativeCapacity);
+    for (TryItem tryItem : tryItems) {
+      if (tryItem.end - tryItem.start <= ByteUtils.MAX_U2) {
+        splitTryItems.add(tryItem);
+        continue;
+      }
+      final CatchHandlers<BasicBlock> handlers = tryItem.handlers;
+      final int end = tryItem.end;
+      // The iteration is based on the start offset advancing on each split.
+      int start = tryItem.start;
+      while (end - start > ByteUtils.MAX_U2) {
+        // Find a new end that does not overflow the U2 limit on the delta.
+        // It must be on an instruction offset so scan backwards in the block to find one.
+        int maxOffset = start + ByteUtils.MAX_U2;
+        assert maxOffset < end;
+        int intermediateEnd = -1;
+        for (int i = dexInstructions.size() - 1; i >= 0; i--) {
+          DexInstruction instruction = dexInstructions.get(i);
+          if (instruction.getOffset() <= maxOffset) {
+            intermediateEnd = instruction.getOffset();
+            break;
+          }
+        }
+        if (intermediateEnd <= start) {
+          throw new Unreachable("Unexpected try-catch handler end point: " + intermediateEnd);
+        }
+        splitTryItems.add(new TryItem(handlers, start, intermediateEnd));
+        start = intermediateEnd;
+      }
+      assert start < end;
+      splitTryItems.add(new TryItem(handlers, start, end));
+    }
+    assert splitTryItems.size() >= tryItems.size() + overflows;
+    return splitTryItems;
+  }
+
   private int trimEnd(BasicBlock block) {
     // Trim the range end for non-throwing instructions when end has been computed.
     List<com.android.tools.r8.ir.code.Instruction> instructions = block.getInstructions();
diff --git a/src/main/java/com/android/tools/r8/lightir/ByteUtils.java b/src/main/java/com/android/tools/r8/lightir/ByteUtils.java
index 5c51959..ad1169d 100644
--- a/src/main/java/com/android/tools/r8/lightir/ByteUtils.java
+++ b/src/main/java/com/android/tools/r8/lightir/ByteUtils.java
@@ -8,6 +8,9 @@
 /** Simple utilities for byte encodings. */
 public class ByteUtils {
 
+  public static final int MAX_U1 = 0xFF;
+  public static final int MAX_U2 = 0xFFFF;
+
   public static boolean isU1(int value) {
     return (0 <= value) && (value <= 0xFF);
   }
diff --git a/src/test/java/com/android/tools/r8/dex/TryCatchRangeOverflowTest.java b/src/test/java/com/android/tools/r8/dex/TryCatchRangeOverflowTest.java
new file mode 100644
index 0000000..fb5f1b1
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/dex/TryCatchRangeOverflowTest.java
@@ -0,0 +1,195 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.dex;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+import com.android.tools.r8.D8TestBuilder;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.graph.DebugLocalInfo;
+import com.android.tools.r8.graph.DexCode.Try;
+import com.android.tools.r8.ir.analysis.type.TypeElement;
+import com.android.tools.r8.ir.code.Add;
+import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.Instruction;
+import com.android.tools.r8.ir.code.InstructionListIterator;
+import com.android.tools.r8.ir.code.NumericType;
+import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.utils.AndroidApiLevel;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import java.util.Arrays;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+// Regression test for b/297320921
+@RunWith(Parameterized.class)
+public class TryCatchRangeOverflowTest extends TestBase {
+
+  private final TestParameters parameters;
+
+  @Parameterized.Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withDefaultDexRuntime().withApiLevel(AndroidApiLevel.B).build();
+  }
+
+  public TryCatchRangeOverflowTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  // Each add/2addr instruction has size 1, so we add have as many instruction minus some padding
+  // to make room for the instructions before and after but still in the same block.
+  // Notice that this value may change if the generated code by the compiler changes. It must then
+  // be updated to the precise limit again so that the test for jumbo-string exactly hits the
+  // crossing point.
+  private final int PADDING = 33;
+  private final int UNSPLIT_LIMIT = 0xFFFF - PADDING;
+  private final int SPLIT_2_LIMIT = 0xFFFF * 2 - PADDING;
+
+  @Test
+  public void testWithinU2() throws Exception {
+    parameters.assumeDexRuntime();
+    int addCount = UNSPLIT_LIMIT;
+    compile(addCount)
+        .run(parameters.getRuntime(), TestClass.class)
+        .assertSuccessWithOutputLines("" + addCount)
+        .inspect(inspector -> checkTryCatchHandlers(1, inspector));
+  }
+
+  @Test
+  public void testJumboExceedsU2() throws Exception {
+    parameters.assumeDexRuntime();
+    int addCount = UNSPLIT_LIMIT;
+    compile(addCount)
+        .addOptionsModification(o -> o.testing.forceJumboStringProcessing = true)
+        .run(parameters.getRuntime(), TestClass.class)
+        .assertSuccessWithOutputLines("" + addCount)
+        .inspect(inspector -> checkTryCatchHandlers(2, inspector));
+  }
+
+  @Test
+  public void testExceedsU2() throws Exception {
+    parameters.assumeDexRuntime();
+    // Test with a few values above the limit.
+    for (int addCount : Arrays.asList(UNSPLIT_LIMIT + 1, UNSPLIT_LIMIT + 2, UNSPLIT_LIMIT + 100)) {
+      compile(addCount)
+          .run(parameters.getRuntime(), TestClass.class)
+          .assertSuccessWithOutputLines("" + addCount)
+          .inspect(inspector -> checkTryCatchHandlers(2, inspector));
+    }
+  }
+
+  @Test
+  public void testWithinU2x2() throws Exception {
+    parameters.assumeDexRuntime();
+    int addCount = SPLIT_2_LIMIT;
+    compile(addCount)
+        .run(parameters.getRuntime(), TestClass.class)
+        .assertSuccessWithOutputLines("" + addCount)
+        .inspect(inspector -> checkTryCatchHandlers(2, inspector));
+  }
+
+  @Test
+  public void testJumboExceedsU2x2() throws Exception {
+    parameters.assumeDexRuntime();
+    int addCount = SPLIT_2_LIMIT;
+    compile(addCount)
+        .addOptionsModification(o -> o.testing.forceJumboStringProcessing = true)
+        .run(parameters.getRuntime(), TestClass.class)
+        .assertSuccessWithOutputLines("" + addCount)
+        .inspect(inspector -> checkTryCatchHandlers(3, inspector));
+  }
+
+  @Test
+  public void testExceedsU2x2() throws Exception {
+    parameters.assumeDexRuntime();
+    int addCount = SPLIT_2_LIMIT + 1;
+    compile(addCount)
+        .run(parameters.getRuntime(), TestClass.class)
+        .assertSuccessWithOutputLines("" + addCount)
+        .inspect(inspector -> checkTryCatchHandlers(3, inspector));
+  }
+
+  private D8TestBuilder compile(int addCount) throws Exception {
+    return testForD8(Backend.DEX)
+        .addProgramClasses(TestClass.class)
+        .addOptionsModification(
+            o ->
+                o.testing.irModifier =
+                    (code, appView) -> amendCodeWithAddInstructions(addCount, code))
+        .setMinApi(parameters);
+  }
+
+  private static void amendCodeWithAddInstructions(int addCount, IRCode code) {
+    if (!code.context().getReference().qualifiedName().endsWith("main")) {
+      return;
+    }
+    InstructionListIterator it = code.instructionListIterator();
+    while (it.hasNext()) {
+      Instruction instruction = it.next();
+      if (instruction.isAdd()) {
+        TypeElement outType = instruction.getOutType();
+        DebugLocalInfo localInfo = instruction.getLocalInfo();
+        // Create the last value which will replace the users of the original value in the
+        // continuations.
+        Value newLastValue = code.createValue(outType, localInfo);
+        instruction.outValue().replaceUsers(newLastValue);
+
+        Add add = instruction.asAdd();
+        NumericType numericType = add.getNumericType();
+        assert add.rightValue().isConstNumber();
+        for (int i = 1; i < addCount; i++) {
+          Value dest = i == addCount - 1 ? newLastValue : code.createValue(outType, localInfo);
+          Add newAdd = new Add(numericType, dest, add.outValue(), add.rightValue());
+          add.outValue().addDebugLocalEnd(newAdd);
+          newAdd.setPosition(add.getPosition());
+          it.add(newAdd);
+          add = newAdd;
+        }
+        return;
+      }
+    }
+    fail("Expected to find an Add instruction.");
+  }
+
+  private static void checkTryCatchHandlers(int tryCount, CodeInspector inspector)
+      throws NoSuchMethodException {
+
+    MethodSubject main = inspector.method(TestClass.class.getMethod("main", String[].class));
+    Try[] tries = main.getMethod().getCode().asDexCode().tries;
+    assertEquals(Arrays.toString(tries), tryCount, tries.length);
+  }
+
+  static class TestClass {
+
+    public static void main(String[] args) {
+      int i = 0;
+      try {
+        String str;
+        int len = args.length;
+        if (len == 0) {
+          str = "";
+        } else if (len == 1 /* Using a constant 1 here causes the add to be an add/2addr */) {
+          str = "Strings might become jumbos";
+        } else if (len % 2 == 0) {
+          str = "We need 4";
+        } else {
+          str = "to ensure overflow.";
+        }
+        i = str.length();
+        ++i; // repeated count number of times.
+        i += args[0].length();
+      } catch (Throwable e) {
+        System.out.println(i);
+        return;
+      }
+      System.out.println("unexpected i " + i);
+    }
+  }
+}