Fix invalid rewriting in outliner

When splitting blocks after inserting an invoke to an outline the block
after the split was not fully processed leaving some of the outlined
instructions.

Bug: b/353279141
Change-Id: Ib74460c93b1bf46ce2d1cad602b804613b9b1205
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/outliner/OutlinerImpl.java b/src/main/java/com/android/tools/r8/ir/optimize/outliner/OutlinerImpl.java
index 55a301c..3221105 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/outliner/OutlinerImpl.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/outliner/OutlinerImpl.java
@@ -35,6 +35,7 @@
 import com.android.tools.r8.ir.code.Assume;
 import com.android.tools.r8.ir.code.BasicBlock;
 import com.android.tools.r8.ir.code.BasicBlock.ThrowingInfo;
+import com.android.tools.r8.ir.code.BasicBlockIterator;
 import com.android.tools.r8.ir.code.Binop;
 import com.android.tools.r8.ir.code.CatchHandlers;
 import com.android.tools.r8.ir.code.Div;
@@ -90,7 +91,6 @@
 import java.util.HashMap;
 import java.util.IdentityHashMap;
 import java.util.List;
-import java.util.ListIterator;
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Objects;
@@ -1659,24 +1659,26 @@
     if (!toRemove.isEmpty()) {
       assert !invokesToOutlineMethods.isEmpty();
       // Scan over the entire code to remove outline instructions.
-      ListIterator<BasicBlock> blocksIterator = code.listIterator();
+      BasicBlockIterator blocksIterator = code.listIterator();
       while (blocksIterator.hasNext()) {
         BasicBlock block = blocksIterator.next();
         InstructionListIterator instructionListIterator = block.listIterator(code);
         instructionListIterator.forEachRemaining(
             instruction -> {
-              if (toRemove.contains(instruction)) {
+              if (toRemove.remove(instruction)) {
                 instructionListIterator.removeInstructionIgnoreOutValue();
               } else if (invokesToOutlineMethods.contains(instruction)
                   && block.hasCatchHandlers()) {
                 // If the inserted invoke is inserted in a block with handlers, split the block
                 // after the inserted invoke.
-                instructionListIterator.split(code, blocksIterator);
+                instructionListIterator.splitCopyCatchHandlers(
+                    code, blocksIterator, appView.options(), ignored -> block);
               }
             });
       }
       code.removeRedundantBlocks();
     }
+    assert toRemove.isEmpty();
     code.removeRedundantBlocks();
     assert code.isConsistentSSA(appView);
   }
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/outliner/MultipleOutlinesInMethodWithExceptionHandlerTest.java b/src/test/java/com/android/tools/r8/ir/optimize/outliner/MultipleOutlinesInMethodWithExceptionHandlerTest.java
index 0ea2792..8dc6316 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/outliner/MultipleOutlinesInMethodWithExceptionHandlerTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/outliner/MultipleOutlinesInMethodWithExceptionHandlerTest.java
@@ -4,12 +4,32 @@
 
 package com.android.tools.r8.ir.optimize.outliner;
 
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
 import com.android.tools.r8.NeverInline;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.dex.code.DexAddInt;
+import com.android.tools.r8.dex.code.DexAddInt2Addr;
+import com.android.tools.r8.dex.code.DexMulInt;
+import com.android.tools.r8.dex.code.DexMulInt2Addr;
+import com.android.tools.r8.dex.code.DexReturn;
+import com.android.tools.r8.synthesis.SyntheticItemsTestUtils;
 import com.android.tools.r8.utils.StringUtils;
-import com.android.tools.r8.utils.codeinspector.AssertUtils;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.CodeMatchers;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -29,24 +49,71 @@
 
   private static final String EXPECTED_OUTPUT = StringUtils.lines("18");
 
+  private void validateOutlining(CodeInspector inspector) {
+    // Validate that an outline of mul, mul, add has been created and called twice in m.
+    ClassSubject outlineClass =
+        inspector.clazz(SyntheticItemsTestUtils.syntheticOutlineClass(TestClass.class, 0));
+    assertThat(outlineClass, isPresent());
+    MethodSubject outline0Method =
+        outlineClass.method(
+            "int",
+            SyntheticItemsTestUtils.syntheticMethodName(),
+            ImmutableList.of("int", "int", "int", "int"));
+    assertThat(outline0Method, isPresent());
+    // Only check the content if instructions fo DEX.
+    if (parameters.isDexRuntime()) {
+      Map<Class<?>, Class<?>> map =
+          ImmutableMap.of(
+              DexMulInt2Addr.class, DexMulInt.class, DexAddInt2Addr.class, DexAddInt.class);
+      List<Class<?>> instructionClasses =
+          outline0Method
+              .streamInstructions()
+              .map(instruction -> instruction.asDexInstruction().getInstruction().getClass())
+              .map(instructionClass -> map.getOrDefault(instructionClass, instructionClass))
+              .collect(Collectors.toList());
+      assertEquals(
+          ImmutableList.of(DexMulInt.class, DexMulInt.class, DexAddInt.class, DexReturn.class),
+          instructionClasses);
+    }
+    ClassSubject classSubject = inspector.clazz(TestClass.class);
+    MethodSubject methodSubject = classSubject.uniqueMethodWithOriginalName("m");
+    List<InstructionSubject> outlineInvokes =
+        methodSubject
+            .streamInstructions()
+            .filter(CodeMatchers.isInvokeWithTarget(outline0Method))
+            .collect(Collectors.toList());
+    assertEquals(2, outlineInvokes.size());
+    // Check that both outlines invoked are covered by the catch handler.
+    methodSubject
+        .iterateTryCatches()
+        .forEachRemaining(
+            tryCatchSubject -> {
+              outlineInvokes.removeIf(
+                  instructionSubject ->
+                      tryCatchSubject
+                          .getRange()
+                          .includes(instructionSubject.getOffset(methodSubject)));
+            });
+    assertTrue(outlineInvokes.isEmpty());
+  }
+
   @Test
   public void testR8() throws Exception {
-    AssertUtils.assertFailsCompilation(
-        () ->
-            testForR8(parameters.getBackend())
-                .addInnerClasses(getClass())
-                .addKeepMainRule(TestClass.class)
-                .setMinApi(parameters)
-                .enableInliningAnnotations()
-                .addOptionsModification(
-                    options -> {
-                      // To trigger outlining.
-                      options.outline.threshold = 2;
-                      options.outline.maxSize = 3;
-                    })
-                .compile()
-                .run(parameters.getRuntime(), TestClass.class)
-                .assertSuccessWithOutput(EXPECTED_OUTPUT));
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(TestClass.class)
+        .setMinApi(parameters)
+        .enableInliningAnnotations()
+        .addOptionsModification(
+            options -> {
+              // To trigger outlining.
+              options.outline.threshold = 2;
+              options.outline.maxSize = 3;
+            })
+        .compile()
+        .inspect(this::validateOutlining)
+        .run(parameters.getRuntime(), TestClass.class)
+        .assertSuccessWithOutput(EXPECTED_OUTPUT);
   }
 
   static class TestClass {