Split branches on known boolean
Change-Id: Ie4e4a19aca82702a9d5f62f5edc86a727450ad18
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 0acb465..17beac4 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -28,6 +28,7 @@
import com.android.tools.r8.ir.code.InstructionIterator;
import com.android.tools.r8.ir.code.Value;
import com.android.tools.r8.ir.conversion.passes.ParentConstructorHoistingCodeRewriter;
+import com.android.tools.r8.ir.conversion.passes.SplitBranchOnKnownBoolean;
import com.android.tools.r8.ir.desugar.CfInstructionDesugaringCollection;
import com.android.tools.r8.ir.desugar.CovariantReturnTypeAnnotationTransformer;
import com.android.tools.r8.ir.optimize.AssertionErrorTwoArgsConstructorRewriter;
@@ -111,6 +112,7 @@
private final ClassInliner classInliner;
protected final InternalOptions options;
public final CodeRewriter codeRewriter;
+ private final SplitBranchOnKnownBoolean splitBranchOnKnownBoolean;
public final AssertionErrorTwoArgsConstructorRewriter assertionErrorTwoArgsConstructorRewriter;
private final NaturalIntLoopRemover naturalIntLoopRemover = new NaturalIntLoopRemover();
public final MemberValuePropagation<?> memberValuePropagation;
@@ -160,6 +162,7 @@
this.appView = appView;
this.options = appView.options();
this.codeRewriter = new CodeRewriter(appView);
+ this.splitBranchOnKnownBoolean = new SplitBranchOnKnownBoolean(appView);
this.assertionErrorTwoArgsConstructorRewriter =
appView.options().desugarState.isOn()
? new AssertionErrorTwoArgsConstructorRewriter(appView)
@@ -765,6 +768,7 @@
timing.end();
}
timing.end();
+ splitBranchOnKnownBoolean.run(code.context(), code, timing);
if (options.enableRedundantConstNumberOptimization) {
timing.begin("Remove const numbers");
codeRewriter.redundantConstNumberRemoval(code);
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranchOnKnownBoolean.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranchOnKnownBoolean.java
new file mode 100644
index 0000000..942499a
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranchOnKnownBoolean.java
@@ -0,0 +1,178 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.conversion.passes;
+
+import com.android.tools.r8.graph.AppInfo;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.analysis.type.TypeAnalysis;
+import com.android.tools.r8.ir.code.BasicBlock;
+import com.android.tools.r8.ir.code.Goto;
+import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.If;
+import com.android.tools.r8.ir.code.Phi;
+import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.utils.ListUtils;
+import com.android.tools.r8.utils.WorkList;
+import com.google.common.collect.Sets;
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+public class SplitBranchOnKnownBoolean extends CodeRewriterPass<AppInfo> {
+
+ private static final boolean ALLOW_PARTIAL_REWRITE = true;
+
+ public SplitBranchOnKnownBoolean(AppView<?> appView) {
+ super(appView);
+ }
+
+ @Override
+ String getTimingId() {
+ return "SplitBranchOnKnownBoolean";
+ }
+
+ @Override
+ boolean shouldRewriteCode(ProgramMethod method, IRCode code) {
+ return true;
+ }
+
+ /**
+ * Simplify Boolean branches for example: <code>
+ * boolean b = i == j; if (b) { ... } else { ... }
+ * </code> ends up first creating a branch for the boolean b, then a second branch on b. D8/R8
+ * rewrites to: <code>
+ * if (i == j) { ... } else { ... }
+ * </code> More complex control flow are also supported to some extent, including cases where the
+ * input of the second branch comes from a set of dependent phis, and a subset of the inputs are
+ * known boolean values.
+ */
+ @Override
+ void rewriteCode(ProgramMethod method, IRCode code) {
+ List<BasicBlock> candidates = computeCandidates(code);
+ if (candidates.isEmpty()) {
+ return;
+ }
+ Map<Goto, BasicBlock> newTargets = findGotosToRetarget(candidates);
+ if (newTargets.isEmpty()) {
+ return;
+ }
+ retargetGotos(newTargets);
+ Set<Value> affectedValues = Sets.newIdentityHashSet();
+ affectedValues.addAll(code.removeUnreachableBlocks());
+ code.removeAllDeadAndTrivialPhis(affectedValues);
+ if (!affectedValues.isEmpty()) {
+ new TypeAnalysis(appView).narrowing(affectedValues);
+ }
+ if (ALLOW_PARTIAL_REWRITE) {
+ code.splitCriticalEdges();
+ }
+ assert code.isConsistentSSA(appView);
+ }
+
+ private void retargetGotos(Map<Goto, BasicBlock> newTargets) {
+ newTargets.forEach(
+ (goTo, newTarget) -> {
+ BasicBlock initialTarget = goTo.getTarget();
+ for (Phi phi : initialTarget.getPhis()) {
+ int index = initialTarget.getPredecessors().indexOf(goTo.getBlock());
+ phi.removeOperand(index);
+ }
+ goTo.setTarget(newTarget);
+ });
+ }
+
+ private Map<Goto, BasicBlock> findGotosToRetarget(List<BasicBlock> candidates) {
+ Map<Goto, BasicBlock> newTargets = new LinkedHashMap<>();
+ for (BasicBlock block : candidates) {
+ // We need to verify any instruction in between the if and the chain of phis is empty (we
+ // could duplicate instruction, but the common case is empty).
+ // Then we can redirect any known value. This can lead to dead code.
+ If theIf = block.exit().asIf();
+ Set<Phi> allowedPhis = getAllowedPhis(theIf.lhs().asPhi());
+ Set<Phi> foundPhis = Sets.newIdentityHashSet();
+ WorkList.newIdentityWorkList(block)
+ .process(
+ (current, workList) -> {
+ if (current.getInstructions().size() > 1) {
+ return;
+ }
+ if (current != block && !current.exit().isGoto()) {
+ return;
+ }
+ if (allowedPhis.containsAll(current.getPhis())) {
+ foundPhis.addAll(current.getPhis());
+ } else {
+ return;
+ }
+ workList.addIfNotSeen(current.getPredecessors());
+ });
+ if (!ALLOW_PARTIAL_REWRITE) {
+ for (Phi phi : foundPhis) {
+ for (Value value : phi.getOperands()) {
+ if (!value.isConstant() && !(value.isPhi() && foundPhis.contains(value.asPhi()))) {
+ return newTargets;
+ }
+ }
+ }
+ }
+ for (Phi phi : foundPhis) {
+ BasicBlock phiBlock = phi.getBlock();
+ for (int i = 0; i < phi.getOperands().size(); i++) {
+ Value value = phi.getOperand(i);
+ if (value.isConstant()) {
+ recordNewTargetForGoto(value, phiBlock.getPredecessors().get(i), theIf, newTargets);
+ }
+ }
+ }
+ }
+ return newTargets;
+ }
+
+ private List<BasicBlock> computeCandidates(IRCode code) {
+ List<BasicBlock> candidates = new ArrayList<>();
+ for (BasicBlock block : ListUtils.filter(code.blocks, block -> block.entry().isIf())) {
+ If theIf = block.exit().asIf();
+ if (theIf.isZeroTest()
+ && theIf.lhs().getType().isInt()
+ && theIf.lhs().isPhi()
+ && theIf.lhs().hasSingleUniqueUser()
+ && !theIf.lhs().hasPhiUsers()) {
+ candidates.add(block);
+ }
+ }
+ return candidates;
+ }
+
+ private void recordNewTargetForGoto(
+ Value value, BasicBlock basicBlock, If theIf, Map<Goto, BasicBlock> newTargets) {
+ // The GoTo at the end of basicBlock should target the phiBlock, and should target instead
+ // the correct if destination.
+ assert basicBlock.exit().isGoto();
+ assert value.isConstant();
+ assert value.getType().isInt();
+ assert theIf.isZeroTest();
+ BasicBlock newTarget = theIf.targetFromCondition(value.getConstInstruction().asConstNumber());
+ Goto aGoto = basicBlock.exit().asGoto();
+ newTargets.put(aGoto, newTarget);
+ }
+
+ private Set<Phi> getAllowedPhis(Phi initialPhi) {
+ WorkList<Phi> workList = WorkList.newIdentityWorkList(initialPhi);
+ while (workList.hasNext()) {
+ Phi phi = workList.next();
+ for (Value operand : phi.getOperands()) {
+ if (operand.isPhi()
+ && (operand.uniqueUsers().isEmpty() || phi == initialPhi)
+ && workList.getSeenSet().containsAll(operand.uniquePhiUsers())) {
+ workList.addIfNotSeen(operand.asPhi());
+ }
+ }
+ }
+ return workList.getSeenSet();
+ }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index 692eeee..d8e62cb 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -499,6 +499,7 @@
// TODO(sgjesse); Move this somewhere else, and reuse it for some of the other switch rewritings.
public abstract static class InstructionBuilder<T> {
+
protected int blockNumber;
protected final Position position;
@@ -515,6 +516,7 @@
}
public static class SwitchBuilder extends InstructionBuilder<SwitchBuilder> {
+
private Value value;
private final Int2ReferenceSortedMap<BasicBlock> keyToTarget = new Int2ReferenceAVLTreeMap<>();
private BasicBlock fallthrough;
@@ -530,7 +532,7 @@
public SwitchBuilder setValue(Value value) {
this.value = value;
- return this;
+ return this;
}
public SwitchBuilder addKeyAndTarget(int key, BasicBlock target) {
@@ -575,6 +577,7 @@
}
public static class IfBuilder extends InstructionBuilder<IfBuilder> {
+
private final IRCode code;
private Value left;
private int right;
@@ -593,12 +596,12 @@
public IfBuilder setLeft(Value left) {
this.left = left;
- return this;
+ return this;
}
public IfBuilder setRight(int right) {
this.right = right;
- return this;
+ return this;
}
public IfBuilder setTarget(BasicBlock target) {
@@ -2242,6 +2245,7 @@
}
private static class FilledArrayCandidate {
+
final NewArrayEmpty newArrayEmpty;
final int size;
final boolean encodeAsFilledNewArray;
@@ -2707,6 +2711,7 @@
}
static class ControlFlowSimplificationResult {
+
private boolean anyAffectedValues;
private boolean anySimplifications;
diff --git a/src/test/java/com/android/tools/r8/internal/opensourceapps/TiviTest.java b/src/test/java/com/android/tools/r8/internal/opensourceapps/TiviTest.java
index e220e0a..f0a5fd9 100644
--- a/src/test/java/com/android/tools/r8/internal/opensourceapps/TiviTest.java
+++ b/src/test/java/com/android/tools/r8/internal/opensourceapps/TiviTest.java
@@ -4,15 +4,13 @@
package com.android.tools.r8.internal.opensourceapps;
-import static org.junit.Assume.assumeTrue;
-
import com.android.tools.r8.LibraryDesugaringTestConfiguration;
import com.android.tools.r8.R8TestBuilder;
+import com.android.tools.r8.R8TestCompileResult;
import com.android.tools.r8.StringResource;
import com.android.tools.r8.TestBase;
import com.android.tools.r8.TestParameters;
import com.android.tools.r8.TestParametersCollection;
-import com.android.tools.r8.ToolHelper;
import com.android.tools.r8.utils.AndroidApiLevel;
import com.android.tools.r8.utils.ZipUtils;
import java.io.IOException;
@@ -40,17 +38,19 @@
@BeforeClass
public static void setup() throws IOException {
- assumeTrue(ToolHelper.isLocalDevelopment());
+ // assumeTrue(ToolHelper.isLocalDevelopment());
outDirectory = getStaticTemp().newFolder().toPath();
ZipUtils.unzip(Paths.get("third_party/opensource-apps/tivi/dump_app.zip"), outDirectory);
}
@Test
public void testR8() throws Exception {
- testForR8(Backend.DEX)
- .addProgramFiles(outDirectory.resolve("program.jar"))
- .apply(this::configure)
- .compile();
+ R8TestCompileResult compile =
+ testForR8(Backend.DEX)
+ .addProgramFiles(outDirectory.resolve("program.jar"))
+ .apply(this::configure)
+ .compile();
+ System.out.println(compile.app.applicationSize());
}
@Test
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/ifs/DoubleDiamondTest.java b/src/test/java/com/android/tools/r8/ir/optimize/ifs/DoubleDiamondTest.java
new file mode 100644
index 0000000..cf928e6
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/ifs/DoubleDiamondTest.java
@@ -0,0 +1,188 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.ifs;
+
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.AlwaysInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.FoundMethodSubject;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class DoubleDiamondTest extends TestBase {
+
+ private final TestParameters parameters;
+
+ @Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withAllRuntimesAndApiLevels().build();
+ }
+
+ public DoubleDiamondTest(TestParameters parameters) {
+ this.parameters = parameters;
+ }
+
+ @Test
+ public void test() throws Exception {
+ testForR8(parameters.getBackend())
+ .addInnerClasses(getClass())
+ .addKeepMainRule(Main.class)
+ .enableInliningAnnotations()
+ .enableAlwaysInliningAnnotations()
+ .setMinApi(parameters)
+ .compile()
+ .inspect(this::inspect)
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutputLines(
+ "5", "1", "1", "5", "1", "5", "5", "1", "5", "5", "1", "1", "1", "5", "5", "5", "1",
+ "1", "1", "5");
+ }
+
+ private void inspect(CodeInspector inspector) {
+ for (FoundMethodSubject method : inspector.clazz(Main.class).allMethods()) {
+ if (!method.getOriginalName().equals("main")) {
+ long count = method.streamInstructions().filter(InstructionSubject::isIf).count();
+ assertEquals(method.getOriginalName().contains("Double") ? 2 : 1, count);
+ }
+ }
+ }
+
+ public static class Main {
+
+ public static void main(String[] args) {
+ System.out.println(indirectEquals(2, 6));
+ System.out.println(indirectEquals(3, 3));
+
+ System.out.println(indirectEqualsNegated(2, 6));
+ System.out.println(indirectEqualsNegated(3, 3));
+
+ System.out.println(indirectLessThan(2, 6));
+ System.out.println(indirectLessThan(7, 3));
+
+ System.out.println(indirectLessThanNegated(2, 6));
+ System.out.println(indirectLessThanNegated(7, 3));
+
+ System.out.println(indirectDoubleEquals(2, 6, 6));
+ System.out.println(indirectDoubleEquals(7, 7, 3));
+ System.out.println(indirectDoubleEquals(1, 1, 1));
+
+ System.out.println(indirectDoubleEqualsNegated(2, 6, 6));
+ System.out.println(indirectDoubleEqualsNegated(2, 2, 6));
+ System.out.println(indirectDoubleEqualsNegated(7, 7, 7));
+
+ System.out.println(indirectDoubleEqualsSplit(2, 6, 6));
+ System.out.println(indirectDoubleEqualsSplit(7, 7, 3));
+ System.out.println(indirectDoubleEqualsSplit(1, 1, 1));
+
+ System.out.println(indirectDoubleEqualsSplitNegated(2, 6, 6));
+ System.out.println(indirectDoubleEqualsSplitNegated(2, 2, 6));
+ System.out.println(indirectDoubleEqualsSplitNegated(7, 7, 7));
+ }
+
+ @AlwaysInline
+ public static boolean doubleEqualsSplit(int i, int j, int k) {
+ if (i != j) {
+ return false;
+ }
+ return j == k;
+ }
+
+ @NeverInline
+ public static int indirectDoubleEqualsSplit(int i, int j, int k) {
+ if (doubleEqualsSplit(i, j, k)) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+
+ @NeverInline
+ public static int indirectDoubleEqualsSplitNegated(int i, int j, int k) {
+ if (!doubleEqualsSplit(i, j, k)) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+
+ @AlwaysInline
+ public static boolean doubleEquals(int i, int j, int k) {
+ return i == j && j == k;
+ }
+
+ @NeverInline
+ public static int indirectDoubleEquals(int i, int j, int k) {
+ if (doubleEquals(i, j, k)) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+
+ @NeverInline
+ public static int indirectDoubleEqualsNegated(int i, int j, int k) {
+ if (!doubleEquals(i, j, k)) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+
+ @AlwaysInline
+ public static boolean equals(int i, int j) {
+ return i == j;
+ }
+
+ @NeverInline
+ public static int indirectEquals(int i, int j) {
+ if (equals(i, j)) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+
+ @NeverInline
+ public static int indirectEqualsNegated(int i, int j) {
+ if (!equals(i, j)) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+
+ @AlwaysInline
+ public static boolean lessThan(int i, int j) {
+ return i <= j;
+ }
+
+ @NeverInline
+ public static int indirectLessThan(int i, int j) {
+ if (lessThan(i, j)) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+
+ @NeverInline
+ public static int indirectLessThanNegated(int i, int j) {
+ if (!lessThan(i, j)) {
+ return 1;
+ } else {
+ return 5;
+ }
+ }
+ }
+}