Fix the loop unroller
Bug: b/225857834
Change-Id: If34b82e67262c19d438d35d4e9e333610181e503
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/NaturalIntLoopRemover.java b/src/main/java/com/android/tools/r8/ir/optimize/NaturalIntLoopRemover.java
index c82a122..0974cdd 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/NaturalIntLoopRemover.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/NaturalIntLoopRemover.java
@@ -87,6 +87,9 @@
if (!analyzeLoopExit(loopBody, comparison, builder)) {
return false;
}
+ if (!analyzePhiUses(loopBody, comparison, builder)) {
+ return false;
+ }
NaturalIntLoopWithKnowIterations loop = builder.build();
@@ -98,6 +101,42 @@
}
/**
+ * The loop unroller removes phis corresponding to the loop backjump. There are three scenarios:
+ * (1) The loop has a single exit point analyzed, phis used outside the loop are replaced by the
+ * value at the end of the loop body.
+ * (2) The phis are unused outside the loop, and they are simply removed.
+ * (3) The loop has multiple exits and the phis are used outside the loop, this would require
+ * dealing with complex merge point and postponing phis after the loop, we bail out.
+ */
+ private boolean analyzePhiUses(
+ Set<BasicBlock> loopBody, If comparison, NaturalIntLoopWithKnowIterations.Builder builder) {
+ // Check for single exit scenario.
+ Set<BasicBlock> successors = Sets.newIdentityHashSet();
+ for (BasicBlock basicBlock : loopBody) {
+ successors.addAll(basicBlock.getSuccessors());
+ }
+ successors.removeAll(loopBody);
+ if (successors.size() == 1) {
+ assert successors.iterator().next() == builder.getLoopExit();
+ return true;
+ }
+ // Check phis are unused outside the loop.
+ for (Phi phi : comparison.getBlock().getPhis()) {
+ for (Instruction use : phi.uniqueUsers()) {
+ if (!loopBody.contains(use.getBlock())) {
+ return false;
+ }
+ }
+ for (Phi phiUse : phi.uniquePhiUsers()) {
+ if (!loopBody.contains(phiUse.getBlock())) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
+ /**
* Verifies the loop is well formed: the comparison on the int iterator should jump to a loop exit
* on one side and to the loop body on the other side.
*/
@@ -305,6 +344,10 @@
this.loopBodyEntry = loopBodyEntry;
}
+ public BasicBlock getLoopExit() {
+ return loopExit;
+ }
+
public void setLoopBody(Set<BasicBlock> loopBody) {
this.loopBody = loopBody;
}
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/loops/LoopWith1IterationsEscape.java b/src/test/java/com/android/tools/r8/ir/optimize/loops/LoopWith1IterationsEscape.java
new file mode 100644
index 0000000..2f54f61
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/loops/LoopWith1IterationsEscape.java
@@ -0,0 +1,76 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.loops;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class LoopWith1IterationsEscape extends TestBase {
+
+ private final TestParameters parameters;
+
+ @Parameterized.Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withAllRuntimesAndApiLevels().build();
+ }
+
+ public LoopWith1IterationsEscape(TestParameters parameters) {
+ this.parameters = parameters;
+ }
+
+ @Test
+ public void testLoopRemoved() throws Exception {
+ testForR8(parameters.getBackend())
+ .setMinApi(parameters.getApiLevel())
+ .addProgramClasses(Main.class)
+ .addKeepMainRule(Main.class)
+ .addOptionsModification(options -> options.testing.enableExperimentalLoopUnrolling = true)
+ .enableInliningAnnotations()
+ .noMinification()
+ .compile()
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutputLines("end 0", "iteration", "end 1");
+ }
+
+ public static class Main {
+
+ public static void main(String[] args) {
+ loopExit();
+ loopNoExit();
+ }
+
+ @NeverInline
+ public static void loopNoExit() {
+ Object[] objects = new Object[1];
+ int i;
+ for (i = 0; i < objects.length; i++) {
+ if (System.currentTimeMillis() < 0) {
+ break;
+ }
+ System.out.println("iteration");
+ }
+ System.out.println("end " + i);
+ }
+
+ @NeverInline
+ public static void loopExit() {
+ Object[] objects = new Object[1];
+ int i;
+ for (i = 0; i < objects.length; i++) {
+ if (System.currentTimeMillis() > 0) {
+ break;
+ }
+ System.out.println("iteration");
+ }
+ System.out.println("end " + i);
+ }
+ }
+}