Extend split branch
- Support comparison against constants
- Support float comparison
Change-Id: I568ba24a0bbd66f95f1486b52355a51e973bd4e1
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 8f9378e..ebbecb0 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -30,7 +30,7 @@
import com.android.tools.r8.ir.conversion.passes.BinopRewriter;
import com.android.tools.r8.ir.conversion.passes.CommonSubexpressionElimination;
import com.android.tools.r8.ir.conversion.passes.ParentConstructorHoistingCodeRewriter;
-import com.android.tools.r8.ir.conversion.passes.SplitBranchOnKnownBoolean;
+import com.android.tools.r8.ir.conversion.passes.SplitBranch;
import com.android.tools.r8.ir.desugar.CfInstructionDesugaringCollection;
import com.android.tools.r8.ir.desugar.CovariantReturnTypeAnnotationTransformer;
import com.android.tools.r8.ir.optimize.AssertionErrorTwoArgsConstructorRewriter;
@@ -115,7 +115,7 @@
protected final InternalOptions options;
public final CodeRewriter codeRewriter;
public final CommonSubexpressionElimination commonSubexpressionElimination;
- private final SplitBranchOnKnownBoolean splitBranchOnKnownBoolean;
+ private final SplitBranch splitBranch;
public final AssertionErrorTwoArgsConstructorRewriter assertionErrorTwoArgsConstructorRewriter;
private final NaturalIntLoopRemover naturalIntLoopRemover = new NaturalIntLoopRemover();
public final MemberValuePropagation<?> memberValuePropagation;
@@ -167,7 +167,7 @@
this.options = appView.options();
this.codeRewriter = new CodeRewriter(appView);
this.commonSubexpressionElimination = new CommonSubexpressionElimination(appView);
- this.splitBranchOnKnownBoolean = new SplitBranchOnKnownBoolean(appView);
+ this.splitBranch = new SplitBranch(appView);
this.assertionErrorTwoArgsConstructorRewriter =
appView.options().desugarState.isOn()
? new AssertionErrorTwoArgsConstructorRewriter(appView)
@@ -777,7 +777,7 @@
timing.end();
}
timing.end();
- splitBranchOnKnownBoolean.run(code.context(), code, timing);
+ splitBranch.run(code.context(), code, timing);
if (options.enableRedundantConstNumberOptimization) {
timing.begin("Remove const numbers");
codeRewriter.redundantConstNumberRemoval(code);
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranchOnKnownBoolean.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranch.java
similarity index 68%
rename from src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranchOnKnownBoolean.java
rename to src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranch.java
index 134d2bc..9a3a024 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranchOnKnownBoolean.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranch.java
@@ -9,6 +9,7 @@
import com.android.tools.r8.graph.ProgramMethod;
import com.android.tools.r8.ir.analysis.type.TypeAnalysis;
import com.android.tools.r8.ir.code.BasicBlock;
+import com.android.tools.r8.ir.code.ConstNumber;
import com.android.tools.r8.ir.code.Goto;
import com.android.tools.r8.ir.code.IRCode;
import com.android.tools.r8.ir.code.If;
@@ -23,17 +24,17 @@
import java.util.Map;
import java.util.Set;
-public class SplitBranchOnKnownBoolean extends CodeRewriterPass<AppInfo> {
+public class SplitBranch extends CodeRewriterPass<AppInfo> {
private static final boolean ALLOW_PARTIAL_REWRITE = true;
- public SplitBranchOnKnownBoolean(AppView<?> appView) {
+ public SplitBranch(AppView<?> appView) {
super(appView);
}
@Override
String getTimingId() {
- return "SplitBranchOnKnownBoolean";
+ return "SplitBranch";
}
@Override
@@ -89,17 +90,31 @@
private Map<Goto, BasicBlock> findGotosToRetarget(List<BasicBlock> candidates) {
Map<Goto, BasicBlock> newTargets = new LinkedHashMap<>();
for (BasicBlock block : candidates) {
- // We need to verify any instruction in between the if and the chain of phis is empty (we
- // could duplicate instruction, but the common case is empty).
+ // We need to verify any instruction in between the if and the chain of phis is empty or just
+ // a constant used in the If instruction (we could duplicate instruction, but the common case
+ // is empty).
// Then we can redirect any known value. This can lead to dead code.
If theIf = block.exit().asIf();
- Set<Phi> allowedPhis = getAllowedPhis(theIf.lhs().asPhi());
+ Set<Phi> allowedPhis = getAllowedPhis(nonConstNumberOperand(theIf).asPhi());
Set<Phi> foundPhis = Sets.newIdentityHashSet();
WorkList.newIdentityWorkList(block)
.process(
(current, workList) -> {
if (current.getInstructions().size() > 1) {
- return;
+ // We allow a single instruction, which is the constant used exclusively in the
+ // if. This is run before constant canonicalization.
+ if (theIf.isZeroTest()
+ || current.getInstructions().size() != 2
+ || !current.entry().isConstNumber()) {
+ return;
+ }
+ Value value = current.entry().outValue();
+ if (value.hasPhiUsers()
+ || value.uniqueUsers().size() > 1
+ || (value.uniqueUsers().size() == 1
+ && value.uniqueUsers().iterator().next() != theIf)) {
+ return;
+ }
}
if (current != block && !current.exit().isGoto()) {
return;
@@ -135,30 +150,61 @@
return newTargets;
}
+ private boolean isNumberAgainstConstNumberIf(If theIf) {
+ if (!(theIf.lhs().getType().isInt() || theIf.lhs().getType().isFloat())) {
+ return false;
+ }
+ if (theIf.isZeroTest()) {
+ return true;
+ }
+ assert theIf.lhs().getType() == theIf.rhs().getType();
+ return theIf.lhs().isConstNumber() || theIf.rhs().isConstNumber();
+ }
+
+ private Value nonConstNumberOperand(If theIf) {
+ return theIf.isZeroTest()
+ ? theIf.lhs()
+ : (theIf.lhs().isConstNumber() ? theIf.rhs() : theIf.lhs());
+ }
+
private List<BasicBlock> computeCandidates(IRCode code) {
List<BasicBlock> candidates = new ArrayList<>();
- for (BasicBlock block : ListUtils.filter(code.blocks, block -> block.entry().isIf())) {
+ for (BasicBlock block : ListUtils.filter(code.blocks, block -> block.exit().isIf())) {
If theIf = block.exit().asIf();
- if (theIf.isZeroTest()
- && theIf.lhs().getType().isInt()
- && theIf.lhs().isPhi()
- && theIf.lhs().hasSingleUniqueUser()
- && !theIf.lhs().hasPhiUsers()) {
+ if (!isNumberAgainstConstNumberIf(theIf)) {
+ continue;
+ }
+ Value nonConstNumberOperand = nonConstNumberOperand(theIf);
+ if (isNumberAgainstConstNumberIf(theIf)
+ && nonConstNumberOperand.isPhi()
+ && nonConstNumberOperand.hasSingleUniqueUser()
+ && !nonConstNumberOperand.hasPhiUsers()) {
candidates.add(block);
}
}
return candidates;
}
+ private BasicBlock targetFromCondition(If theIf, ConstNumber constForPhi) {
+ if (theIf.isZeroTest()) {
+ return theIf.targetFromCondition(constForPhi);
+ }
+ if (theIf.lhs().isConstNumber()) {
+ return theIf.targetFromCondition(
+ theIf.lhs().getConstInstruction().asConstNumber(), constForPhi);
+ }
+ assert theIf.rhs().isConstNumber();
+ return theIf.targetFromCondition(
+ constForPhi, theIf.rhs().getConstInstruction().asConstNumber());
+ }
+
private void recordNewTargetForGoto(
Value value, BasicBlock basicBlock, If theIf, Map<Goto, BasicBlock> newTargets) {
// The GoTo at the end of basicBlock should target the phiBlock, and should target instead
// the correct if destination.
assert basicBlock.exit().isGoto();
assert value.isConstant();
- assert value.getType().isInt();
- assert theIf.isZeroTest();
- BasicBlock newTarget = theIf.targetFromCondition(value.getConstInstruction().asConstNumber());
+ BasicBlock newTarget = targetFromCondition(theIf, value.getConstInstruction().asConstNumber());
Goto aGoto = basicBlock.exit().asGoto();
newTargets.put(aGoto, newTarget);
}
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/ifs/DoubleDiamondCstTest.java b/src/test/java/com/android/tools/r8/ir/optimize/ifs/DoubleDiamondCstTest.java
new file mode 100644
index 0000000..7c5ed88
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/ifs/DoubleDiamondCstTest.java
@@ -0,0 +1,159 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.ifs;
+
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.AlwaysInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.FoundMethodSubject;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class DoubleDiamondCstTest extends TestBase {
+
+ private final TestParameters parameters;
+
+ @Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withAllRuntimesAndApiLevels().build();
+ }
+
+ public DoubleDiamondCstTest(TestParameters parameters) {
+ this.parameters = parameters;
+ }
+
+ @Test
+ public void test() throws Exception {
+ testForR8(parameters.getBackend())
+ .addInnerClasses(getClass())
+ .addKeepMainRule(Main.class)
+ .enableInliningAnnotations()
+ .enableAlwaysInliningAnnotations()
+ .setMinApi(parameters)
+ .compile()
+ .inspect(this::inspect)
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutputLines(
+ "1", "5", "5", "1", "1", "5", "5", "1", "5", "5", "1", "1", "1", "5");
+ }
+
+ private void inspect(CodeInspector inspector) {
+ for (FoundMethodSubject method : inspector.clazz(Main.class).allMethods()) {
+ if (!method.getOriginalName().equals("main")) {
+ long count = method.streamInstructions().filter(InstructionSubject::isIf).count();
+ assertEquals(method.getOriginalName().contains("Double") ? 2 : 1, count);
+ }
+ }
+ }
+
+ public static class Main {
+
+ public static void main(String[] args) {
+ System.out.println(indirectTest(2, 6));
+ System.out.println(indirectTest(3, 3));
+
+ System.out.println(indirectTestNegated(2, 6));
+ System.out.println(indirectTestNegated(3, 3));
+
+ System.out.println(indirectCmp(2, 6));
+ System.out.println(indirectCmp(7, 3));
+
+ System.out.println(indirectCmpNegated(2, 6));
+ System.out.println(indirectCmpNegated(7, 3));
+
+ System.out.println(indirectDoubleTest(2, 6, 6));
+ System.out.println(indirectDoubleTest(7, 7, 3));
+ System.out.println(indirectDoubleTest(1, 1, 1));
+
+ System.out.println(indirectDoubleTestNegated(2, 6, 6));
+ System.out.println(indirectDoubleTestNegated(2, 2, 6));
+ System.out.println(indirectDoubleTestNegated(7, 7, 7));
+ }
+
+ @AlwaysInline
+ public static int doubleTest(int i, int j, int k) {
+ if (i != j) {
+ return 1;
+ }
+ if (j == k) {
+ return 2;
+ }
+ return 3;
+ }
+
+ @NeverInline
+ public static int indirectDoubleTest(int i, int j, int k) {
+ if (doubleTest(i, j, k) == 2) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+
+ @NeverInline
+ public static int indirectDoubleTestNegated(int i, int j, int k) {
+ if (doubleTest(i, j, k) != 2) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+
+ @AlwaysInline
+ public static int test(int i, int j) {
+ return i == j ? 1 : 2;
+ }
+
+ @NeverInline
+ public static int indirectTest(int i, int j) {
+ if (test(i, j) == 2) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+
+ @NeverInline
+ public static int indirectTestNegated(int i, int j) {
+ if (test(i, j) != 2) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+
+ @AlwaysInline
+ public static int cmp(int i, int j) {
+ return i <= j ? 1 : 2;
+ }
+
+ @NeverInline
+ public static int indirectCmp(int i, int j) {
+ if (cmp(i, j) < 2) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+
+ @NeverInline
+ public static int indirectCmpNegated(int i, int j) {
+ if (cmp(i, j) > 1) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+ }
+}
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/ifs/DoubleDiamondFloatTest.java b/src/test/java/com/android/tools/r8/ir/optimize/ifs/DoubleDiamondFloatTest.java
new file mode 100644
index 0000000..b433e5b
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/ifs/DoubleDiamondFloatTest.java
@@ -0,0 +1,158 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.ifs;
+
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.AlwaysInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.FoundMethodSubject;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class DoubleDiamondFloatTest extends TestBase {
+
+ private final TestParameters parameters;
+
+ @Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withAllRuntimesAndApiLevels().build();
+ }
+
+ public DoubleDiamondFloatTest(TestParameters parameters) {
+ this.parameters = parameters;
+ }
+
+ @Test
+ public void test() throws Exception {
+ testForR8(parameters.getBackend())
+ .addInnerClasses(getClass())
+ .addKeepMainRule(Main.class)
+ .enableInliningAnnotations()
+ .enableAlwaysInliningAnnotations()
+ .setMinApi(parameters)
+ .compile()
+ .inspect(this::inspect)
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutputLines(
+ "5", "1", "1", "5", "5", "5", "1", "1", "1", "5", "5", "5", "1", "1", "1", "5");
+ }
+
+ private void inspect(CodeInspector inspector) {
+ for (FoundMethodSubject method : inspector.clazz(Main.class).allMethods()) {
+ if (!method.getOriginalName().equals("main")) {
+ long count = method.streamInstructions().filter(InstructionSubject::isIf).count();
+ assertEquals(method.getOriginalName().contains("Double") ? 2 : 1, count);
+ }
+ }
+ }
+
+ public static class Main {
+
+ public static void main(String[] args) {
+ System.out.println(indirectEquals(2.0f, 6.0f));
+ System.out.println(indirectEquals(3.0f, 3.0f));
+
+ System.out.println(indirectEqualsNegated(2.0f, 6.0f));
+ System.out.println(indirectEqualsNegated(3.0f, 3.0f));
+
+ System.out.println(indirectDoubleEquals(2.0f, 6.0f, 6.0f));
+ System.out.println(indirectDoubleEquals(7.0f, 7.0f, 3.0f));
+ System.out.println(indirectDoubleEquals(1.0f, 1.0f, 1.0f));
+
+ System.out.println(indirectDoubleEqualsNegated(2.0f, 6.0f, 6.0f));
+ System.out.println(indirectDoubleEqualsNegated(2.0f, 2.0f, 6.0f));
+ System.out.println(indirectDoubleEqualsNegated(7.0f, 7.0f, 7.0f));
+
+ System.out.println(indirectDoubleEqualsSplit(2.0f, 6.0f, 6.0f));
+ System.out.println(indirectDoubleEqualsSplit(7.0f, 7.0f, 3.0f));
+ System.out.println(indirectDoubleEqualsSplit(1.0f, 1.0f, 1.0f));
+
+ System.out.println(indirectDoubleEqualsSplitNegated(2.0f, 6.0f, 6.0f));
+ System.out.println(indirectDoubleEqualsSplitNegated(2.0f, 2.0f, 6.0f));
+ System.out.println(indirectDoubleEqualsSplitNegated(7.0f, 7.0f, 7.0f));
+ }
+
+ @AlwaysInline
+ public static boolean doubleEqualsSplit(float i, float j, float k) {
+ if (i != j) {
+ return false;
+ }
+ return j == k;
+ }
+
+ @NeverInline
+ public static int indirectDoubleEqualsSplit(float i, float j, float k) {
+ if (doubleEqualsSplit(i, j, k)) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+
+ @NeverInline
+ public static int indirectDoubleEqualsSplitNegated(float i, float j, float k) {
+ if (!doubleEqualsSplit(i, j, k)) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+
+ @AlwaysInline
+ public static boolean doubleEquals(float i, float j, float k) {
+ return i == j && j == k;
+ }
+
+ @NeverInline
+ public static int indirectDoubleEquals(float i, float j, float k) {
+ if (doubleEquals(i, j, k)) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+
+ @NeverInline
+ public static int indirectDoubleEqualsNegated(float i, float j, float k) {
+ if (!doubleEquals(i, j, k)) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+
+ @AlwaysInline
+ public static boolean equals(float i, float j) {
+ return i == j;
+ }
+
+ @NeverInline
+ public static int indirectEquals(float i, float j) {
+ if (equals(i, j)) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+
+ @NeverInline
+ public static int indirectEqualsNegated(float i, float j) {
+ if (!equals(i, j)) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+ }
+}