User LinearFlowIterator in outliner to outline inline positions

Consider the following code:

public static void foo() {
  inline();
  bar();
}

public static void inlinee() {
  baz();
}

public static void bar() {
  ...
}

public static void baz() {
  ...
}

When we inline inlinee into foo we build we add a simple goto block for
the instruction to invoke baz(), which is perfectly fine to outline.

This CL changes the view of an outline from a single block to a linear
flow possibly spanning multiple blocks to allow inlined instructions to
be included in outlines.


Bug: 204749490
Change-Id: Ibfaa5931beee63e38485de163970bcf6d249e7d0
diff --git a/src/main/java/com/android/tools/r8/ir/code/LinearFlowInstructionListIterator.java b/src/main/java/com/android/tools/r8/ir/code/LinearFlowInstructionListIterator.java
index b5aa980..cb20fd4 100644
--- a/src/main/java/com/android/tools/r8/ir/code/LinearFlowInstructionListIterator.java
+++ b/src/main/java/com/android/tools/r8/ir/code/LinearFlowInstructionListIterator.java
@@ -48,6 +48,10 @@
     return seenBlocks.contains(basicBlock);
   }
 
+  public Set<BasicBlock> getSeenBlocks() {
+    return seenBlocks;
+  }
+
   @Override
   public void replaceCurrentInstruction(Instruction newInstruction, Set<Value> affectedValues) {
     currentBlockIterator.replaceCurrentInstruction(newInstruction, affectedValues);
@@ -199,6 +203,7 @@
       if (!isLinearEdge(target, candidate)) {
         break;
       }
+      seenBlocks.add(target);
       target = candidate;
     }
     currentBlock = target;
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index b79ca64..3aefd80 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -1412,7 +1412,7 @@
 
     previous = printMethod(code, "IR after interface method rewriting (SSA)", previous);
 
-    // TODO(b/140766440): an ideal solution would be puttting CodeOptimization for this into
+    // TODO(b/140766440): an ideal solution would be putting CodeOptimization for this into
     //  the list for primary processing only.
     outliner.collectOutlineSites(code, timing);
 
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/OutlinerImpl.java b/src/main/java/com/android/tools/r8/ir/optimize/OutlinerImpl.java
index 0dc196e..68111d4 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/OutlinerImpl.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/OutlinerImpl.java
@@ -43,6 +43,7 @@
 import com.android.tools.r8.ir.code.Invoke.Type;
 import com.android.tools.r8.ir.code.InvokeMethod;
 import com.android.tools.r8.ir.code.InvokeStatic;
+import com.android.tools.r8.ir.code.LinearFlowInstructionListIterator;
 import com.android.tools.r8.ir.code.Mul;
 import com.android.tools.r8.ir.code.NewInstance;
 import com.android.tools.r8.ir.code.NumericType;
@@ -76,7 +77,9 @@
 import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.Timing;
 import com.android.tools.r8.utils.collections.ProgramMethodSet;
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Iterables;
+import com.google.common.collect.Sets;
 import java.util.ArrayList;
 import java.util.Comparator;
 import java.util.HashMap;
@@ -86,6 +89,7 @@
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Objects;
+import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.function.Consumer;
@@ -764,9 +768,8 @@
   abstract private class OutlineSpotter {
 
     final ProgramMethod method;
-    final BasicBlock block;
-    // instructionArrayCache is block.getInstructions() copied to an ArrayList.
-    private List<Instruction> instructionArrayCache = null;
+    final IRCode irCode;
+    final List<Instruction> currentCandidateInstructions;
 
     int start;
     int index;
@@ -780,32 +783,17 @@
     int returnValueUniqueUsersLeft;
     int pendingNewInstanceIndex = -1;
 
-    OutlineSpotter(ProgramMethod method, BasicBlock block) {
+    OutlineSpotter(
+        ProgramMethod method, IRCode irCode, List<Instruction> currentCandidateInstructions) {
       this.method = method;
-      this.block = block;
+      this.irCode = irCode;
+      this.currentCandidateInstructions = currentCandidateInstructions;
       reset(0);
     }
 
-    protected List<Instruction> getInstructionArray() {
-      if (instructionArrayCache == null) {
-        instructionArrayCache = new ArrayList<>(block.getInstructions());
-      }
-      return instructionArrayCache;
-    }
-
-    // Call this before modifying block.getInstructions().
-    protected void invalidateInstructionArray() {
-      instructionArrayCache = null;
-    }
-
     protected void process() {
-      List<Instruction> instructions;
-      for (;;) {
-        instructions = getInstructionArray(); // ProcessInstruction may have invalidated it.
-        if (index >= instructions.size()) {
-          break;
-        }
-        processInstruction(instructions.get(index));
+      while (index < currentCandidateInstructions.size()) {
+        processInstruction(currentCandidateInstructions.get(index));
       }
     }
 
@@ -944,10 +932,9 @@
         assert index > 0;
         int offset = 0;
         Instruction previous;
-        List<Instruction> instructions = getInstructionArray();
         do {
           offset++;
-          previous = instructions.get(index - offset);
+          previous = currentCandidateInstructions.get(index - offset);
         } while (previous.isConstInstruction());
         if (!previous.isNewInstance()
             || invoke != previous.asNewInstance().getUniqueConstructorInvoke(dexItemFactory)) {
@@ -1137,8 +1124,7 @@
     protected abstract void handle(int start, int end, Outline outline);
 
     private void candidate(int start, int index) {
-      List<Instruction> instructions = getInstructionArray();
-      assert !instructions.get(start).isConstInstruction();
+      assert !currentCandidateInstructions.get(start).isConstInstruction();
 
       if (pendingNewInstanceIndex != -1) {
         if (pendingNewInstanceIndex == start) {
@@ -1151,7 +1137,7 @@
 
       // Back out of any const instructions ending this candidate.
       int end = index;
-      while (instructions.get(end - 1).isConstInstruction()) {
+      while (currentCandidateInstructions.get(end - 1).isConstInstruction()) {
         end--;
       }
 
@@ -1162,7 +1148,8 @@
       }
 
       Outline outline =
-          new Outline(instructions, argumentTypes, argumentsMap, returnType, start, end);
+          new Outline(
+              currentCandidateInstructions, argumentTypes, argumentsMap, returnType, start, end);
       handle(start, end, outline);
 
       // Start a new candidate search from the next instruction after this outline.
@@ -1192,8 +1179,11 @@
     private final List<Outline> outlinesForMethod;
 
     OutlineMethodIdentifier(
-        ProgramMethod method, BasicBlock block, List<Outline> outlinesForMethod) {
-      super(method, block);
+        ProgramMethod method,
+        IRCode irCode,
+        List<Instruction> currentCandidateInstructions,
+        List<Outline> outlinesForMethod) {
+      super(method, irCode, currentCandidateInstructions);
       this.outlinesForMethod = outlinesForMethod;
     }
 
@@ -1205,8 +1195,9 @@
 
   private class OutlineSiteIdentifier extends OutlineSpotter {
 
-    OutlineSiteIdentifier(ProgramMethod method, BasicBlock block) {
-      super(method, block);
+    OutlineSiteIdentifier(
+        ProgramMethod method, IRCode irCode, List<Instruction> currentCandidateInstructions) {
+      super(method, irCode, currentCandidateInstructions);
     }
 
     @Override
@@ -1221,19 +1212,19 @@
   private class OutlineRewriter extends OutlineSpotter {
 
     private final IRCode code;
-    private final ListIterator<BasicBlock> blocksIterator;
-    private final List<Integer> toRemove;
+    private final Set<Instruction> toRemove;
+    private final Set<Instruction> invokesToOutlineMethods;
     int argumentsMapIndex;
 
     OutlineRewriter(
         IRCode code,
-        ListIterator<BasicBlock> blocksIterator,
-        BasicBlock block,
-        List<Integer> toRemove) {
-      super(code.context(), block);
+        List<Instruction> currentCandidateInstructions,
+        Set<Instruction> toRemove,
+        Set<Instruction> invokesToOutlineMethods) {
+      super(code.context(), code, currentCandidateInstructions);
       this.code = code;
-      this.blocksIterator = blocksIterator;
       this.toRemove = toRemove;
+      this.invokesToOutlineMethods = invokesToOutlineMethods;
     }
 
     @Override
@@ -1251,12 +1242,12 @@
                 // We set the line number to 0 here and rely on the LineNumberOptimizer to
                 // set a new disjoint line.
                 .setLine(0);
-
+        Instruction lastInstruction = null;
+        Position position = Position.none();
         { // Scope for 'instructions'.
-          List<Instruction> instructions = getInstructionArray();
           int outlinePositionIndex = 0;
           for (int i = start; i < end; i++) {
-            Instruction current = instructions.get(i);
+            Instruction current = currentCandidateInstructions.get(i);
             if (current.isConstInstruction()) {
               // Leave any const instructions.
               continue;
@@ -1281,28 +1272,25 @@
             // The invoke of the outline method will be placed at the last instruction index,
             // so don't mark that for removal.
             if (i < end - 1) {
-              toRemove.add(i);
+              toRemove.add(current);
             }
+            lastInstruction = current;
           }
         }
+        assert lastInstruction != null;
         assert outlineMethod.proto.shorty.toString().length() - 1 == in.size();
         if (returnValue != null && !returnValue.isUsed()) {
           returnValue = null;
         }
-        Position newPosition = positionBuilder.build();
         Invoke outlineInvoke = new InvokeStatic(outlineMethod, returnValue, in);
-        outlineInvoke.setBlock(block);
-        outlineInvoke.setPosition(newPosition);
-        InstructionListIterator endIterator = block.listIterator(code, end - 1);
-        Instruction instructionBeforeEnd = endIterator.next();
-        invalidateInstructionArray(); // Because we're about to modify the original linked list.
-        instructionBeforeEnd.clearBlock();
+        outlineInvoke.setBlock(lastInstruction.getBlock());
+        outlineInvoke.setPosition(positionBuilder.build());
+        InstructionListIterator endIterator =
+            lastInstruction.getBlock().listIterator(code, lastInstruction);
+        Instruction instructionBeforeEnd = endIterator.previous();
+        assert instructionBeforeEnd == lastInstruction;
         endIterator.set(outlineInvoke); // Replaces instructionBeforeEnd.
-        if (block.hasCatchHandlers()) {
-          // If the inserted invoke is inserted in a block with handlers, split the block after
-          // the inserted invoke.
-          endIterator.split(code, blocksIterator);
-        }
+        invokesToOutlineMethods.add(outlineInvoke);
       }
     }
 
@@ -1426,20 +1414,71 @@
 
     timing.begin("Collect outlines");
     List<Outline> outlinesForMethod = new ArrayList<>();
-    for (BasicBlock block : code.blocks) {
-      new OutlineMethodIdentifier(context, block, outlinesForMethod).process();
-    }
+    getInstructions(
+        appView,
+        code,
+        instructions ->
+            new OutlineMethodIdentifier(context, code, instructions, outlinesForMethod).process());
     outlineCollection.set(appView, context, outlinesForMethod);
     timing.end();
   }
 
+  public static void getInstructions(
+      AppView<?> appView, IRCode code, Consumer<List<Instruction>> consumer) {
+    int maxNumberOfInstructionsToBeConsidered =
+        appView.options().outline.maxNumberOfInstructionsToBeConsidered;
+    int minSize = appView.options().outline.minSize;
+    Set<BasicBlock> seenBlocks = Sets.newIdentityHashSet();
+    for (BasicBlock block : code.blocks) {
+      if (seenBlocks.add(block)) {
+        ImmutableList.Builder<Instruction> builder = ImmutableList.builder();
+        LinearFlowInstructionListIterator instructionIterator =
+            new LinearFlowInstructionListIterator(code, block);
+        // Maintaining the last seen block ensure that we always consider all instructions in a
+        // block before adding it to the seen set.
+        BasicBlock lastSeenBlock = block;
+        int counter = 0;
+        boolean sawLinearFlowWithCatchHandlers = false;
+        while (instructionIterator.hasNext()) {
+          Instruction instruction = instructionIterator.next();
+          // Disregard linear flow when there are catch handlers
+          if (instruction.getBlock() != block
+              && (block.hasCatchHandlers() || instruction.getBlock().hasCatchHandlers())) {
+            lastSeenBlock = instruction.getBlock();
+            sawLinearFlowWithCatchHandlers = true;
+            break;
+          }
+          builder.add(instruction);
+          counter++;
+          if (counter > maxNumberOfInstructionsToBeConsidered
+              && instruction.getBlock() != lastSeenBlock) {
+            // Ensure we only break on whole blocks.
+            break;
+          }
+          lastSeenBlock = instruction.getBlock();
+        }
+        seenBlocks.addAll(instructionIterator.getSeenBlocks());
+        if (sawLinearFlowWithCatchHandlers) {
+          assert lastSeenBlock != block;
+          // Remove the last seen block since we just visited the first instruction in that block
+          // and terminated without adding it.
+          seenBlocks.remove(lastSeenBlock);
+        }
+        if (counter >= minSize) {
+          consumer.accept(builder.build());
+        }
+      }
+    }
+  }
+
   public void identifyOutlineSites(IRCode code) {
     ProgramMethod context = code.context();
     assert !context.getDefinition().getCode().isOutlineCode();
     assert !ClassToFeatureSplitMap.isInFeature(context.getHolder(), appView);
-    for (BasicBlock block : code.blocks) {
-      new OutlineSiteIdentifier(context, block).process();
-    }
+    getInstructions(
+        appView,
+        code,
+        instructions -> new OutlineSiteIdentifier(context, code, instructions).process());
   }
 
   public ProgramMethodSet selectMethodsForOutlining() {
@@ -1519,13 +1558,33 @@
   }
 
   public void applyOutliningCandidate(IRCode code) {
-    assert !code.method().getCode().isOutlineCode();
-    ListIterator<BasicBlock> blocksIterator = code.listIterator();
-    while (blocksIterator.hasNext()) {
-      BasicBlock block = blocksIterator.next();
-      List<Integer> toRemove = new ArrayList<>();
-      new OutlineRewriter(code, blocksIterator, block, toRemove).process();
-      block.removeInstructions(toRemove);
+    assert !code.context().getDefinition().getCode().isOutlineCode();
+    Set<Instruction> toRemove = Sets.newIdentityHashSet();
+    Set<Instruction> invokesToOutlineMethods = Sets.newIdentityHashSet();
+    getInstructions(
+        appView,
+        code,
+        instructions ->
+            new OutlineRewriter(code, instructions, toRemove, invokesToOutlineMethods).process());
+    if (!toRemove.isEmpty()) {
+      assert !invokesToOutlineMethods.isEmpty();
+      // Scan over the entire code to remove outline instructions.
+      ListIterator<BasicBlock> blocksIterator = code.listIterator();
+      while (blocksIterator.hasNext()) {
+        BasicBlock block = blocksIterator.next();
+        InstructionListIterator instructionListIterator = block.listIterator(code);
+        instructionListIterator.forEachRemaining(
+            instruction -> {
+              if (toRemove.contains(instruction)) {
+                instructionListIterator.removeInstructionIgnoreOutValue();
+              } else if (invokesToOutlineMethods.contains(instruction)
+                  && block.hasCatchHandlers()) {
+                // If the inserted invoke is inserted in a block with handlers, split the block
+                // after the inserted invoke.
+                instructionListIterator.split(code, blocksIterator);
+              }
+            });
+      }
     }
   }
 
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index 4f8efb2..031e055 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -1206,6 +1206,7 @@
     public int minSize = 3;
     public int maxSize = 99;
     public int threshold = 20;
+    public int maxNumberOfInstructionsToBeConsidered = 100;
   }
 
   public static class KotlinOptimizationOptions {
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/outliner/OutlineWithInlineMappingInformationTest.java b/src/test/java/com/android/tools/r8/ir/optimize/outliner/OutlineWithInlineMappingInformationTest.java
new file mode 100644
index 0000000..6f343e9
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/outliner/OutlineWithInlineMappingInformationTest.java
@@ -0,0 +1,158 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.optimize.outliner;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NoHorizontalClassMerging;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.naming.retrace.StackTrace;
+import com.android.tools.r8.utils.BooleanUtils;
+import com.android.tools.r8.utils.codeinspector.HorizontallyMergedClassesInspector;
+import java.util.List;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+
+@RunWith(Parameterized.class)
+public class OutlineWithInlineMappingInformationTest extends TestBase {
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  @Parameter(1)
+  public boolean throwInFirstOutline;
+
+  @Parameter(2)
+  public boolean throwOnFirstCall;
+
+  @Parameterized.Parameters(name = "{0}, throwInFirstOutline: {1}, throwOnFirstCall: {2}")
+  public static List<Object[]> data() {
+    return buildParameters(
+        getTestParameters().withAllRuntimesAndApiLevels().build(),
+        BooleanUtils.values(),
+        BooleanUtils.values());
+  }
+
+  StackTrace expectedStackTrace;
+
+  @Before
+  public void setup() throws Exception {
+    expectedStackTrace =
+        testForRuntime(parameters)
+            .addProgramClasses(TestClass.class, TestClass2.class, Greeter.class)
+            .run(
+                parameters.getRuntime(),
+                TestClass.class,
+                throwInFirstOutline ? "0" : "1",
+                throwOnFirstCall ? "0" : "1")
+            .assertFailureWithErrorThatThrows(ArrayStoreException.class)
+            .getStackTrace();
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForR8(parameters.getBackend())
+        .addProgramClasses(TestClass.class, TestClass2.class, Greeter.class)
+        .addKeepMainRule(TestClass.class)
+        .addOptionsModification(
+            options -> {
+              options.outline.threshold = 2;
+              options.outline.minSize = 2;
+            })
+        .addKeepAttributeLineNumberTable()
+        .addKeepAttributeSourceFile()
+        .enableNoHorizontalClassMergingAnnotations()
+        .noHorizontalClassMergingOfSynthetics()
+        .addHorizontallyMergedClassesInspector(
+            HorizontallyMergedClassesInspector::assertNoClassesMerged)
+        .enableInliningAnnotations()
+        .setMinApi(parameters.getApiLevel())
+        .enableExperimentalMapFileVersion()
+        .compile()
+        .run(
+            parameters.getRuntime(),
+            TestClass.class,
+            throwInFirstOutline ? "0" : "1",
+            throwOnFirstCall ? "0" : "1")
+        .assertFailureWithErrorThatThrows(ArrayStoreException.class)
+        .inspectStackTrace(
+            (stackTrace, inspector) -> {
+              // Two outlines are created, one for
+              //   Greeter.throwExceptionFirst();
+              //   Greeter.throwExceptionSecond();
+              // and one for
+              //   new ArrayStoreException("Foo")
+              assertEquals(5, inspector.allClasses().size());
+              if (throwInFirstOutline ^ throwOnFirstCall) {
+                // TODO(b/204643407): Should always be equal.
+                assertNotEquals(expectedStackTrace, stackTrace);
+              } else {
+                assertEquals(expectedStackTrace, stackTrace);
+              }
+            });
+  }
+
+  @NoHorizontalClassMerging
+  static class TestClass {
+
+    public static boolean shouldThrowInGreeter;
+    public static boolean throwOnFirst;
+
+    public static void main(String... args) {
+      shouldThrowInGreeter = args[0].equals("0");
+      throwOnFirst = args[1].equals("0");
+      greet();
+      shouldThrowInGreeter = true;
+      TestClass2.greet();
+    }
+
+    @NeverInline
+    static void greet() {
+      Greeter.throwExceptionFirst();
+      inlinee();
+    }
+
+    static void inlinee() {
+      Greeter.throwExceptionSecond();
+    }
+  }
+
+  @NoHorizontalClassMerging
+  static class TestClass2 {
+
+    @NeverInline
+    static void greet() {
+      // Keep on same line
+      inlinee(); Greeter.throwExceptionSecond();
+    }
+
+    static void inlinee() {
+      Greeter.throwExceptionFirst();
+    }
+  }
+
+  @NoHorizontalClassMerging
+  public static class Greeter {
+
+    @NeverInline
+    public static void throwExceptionFirst() {
+      if (TestClass.shouldThrowInGreeter && TestClass.throwOnFirst) {
+        throw new ArrayStoreException("Foo");
+      }
+    }
+
+    @NeverInline
+    public static void throwExceptionSecond() {
+      if (TestClass.shouldThrowInGreeter && !TestClass.throwOnFirst) {
+        throw new ArrayStoreException("Foo");
+      }
+    }
+  }
+}