Split branches on known boolean

Change-Id: Ie4e4a19aca82702a9d5f62f5edc86a727450ad18
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 0acb465..17beac4 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -28,6 +28,7 @@
 import com.android.tools.r8.ir.code.InstructionIterator;
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.conversion.passes.ParentConstructorHoistingCodeRewriter;
+import com.android.tools.r8.ir.conversion.passes.SplitBranchOnKnownBoolean;
 import com.android.tools.r8.ir.desugar.CfInstructionDesugaringCollection;
 import com.android.tools.r8.ir.desugar.CovariantReturnTypeAnnotationTransformer;
 import com.android.tools.r8.ir.optimize.AssertionErrorTwoArgsConstructorRewriter;
@@ -111,6 +112,7 @@
   private final ClassInliner classInliner;
   protected final InternalOptions options;
   public final CodeRewriter codeRewriter;
+  private final SplitBranchOnKnownBoolean splitBranchOnKnownBoolean;
   public final AssertionErrorTwoArgsConstructorRewriter assertionErrorTwoArgsConstructorRewriter;
   private final NaturalIntLoopRemover naturalIntLoopRemover = new NaturalIntLoopRemover();
   public final MemberValuePropagation<?> memberValuePropagation;
@@ -160,6 +162,7 @@
     this.appView = appView;
     this.options = appView.options();
     this.codeRewriter = new CodeRewriter(appView);
+    this.splitBranchOnKnownBoolean = new SplitBranchOnKnownBoolean(appView);
     this.assertionErrorTwoArgsConstructorRewriter =
         appView.options().desugarState.isOn()
             ? new AssertionErrorTwoArgsConstructorRewriter(appView)
@@ -765,6 +768,7 @@
       timing.end();
     }
     timing.end();
+    splitBranchOnKnownBoolean.run(code.context(), code, timing);
     if (options.enableRedundantConstNumberOptimization) {
       timing.begin("Remove const numbers");
       codeRewriter.redundantConstNumberRemoval(code);
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranchOnKnownBoolean.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranchOnKnownBoolean.java
new file mode 100644
index 0000000..942499a
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranchOnKnownBoolean.java
@@ -0,0 +1,178 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.conversion.passes;
+
+import com.android.tools.r8.graph.AppInfo;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.analysis.type.TypeAnalysis;
+import com.android.tools.r8.ir.code.BasicBlock;
+import com.android.tools.r8.ir.code.Goto;
+import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.If;
+import com.android.tools.r8.ir.code.Phi;
+import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.utils.ListUtils;
+import com.android.tools.r8.utils.WorkList;
+import com.google.common.collect.Sets;
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+public class SplitBranchOnKnownBoolean extends CodeRewriterPass<AppInfo> {
+
+  private static final boolean ALLOW_PARTIAL_REWRITE = true;
+
+  public SplitBranchOnKnownBoolean(AppView<?> appView) {
+    super(appView);
+  }
+
+  @Override
+  String getTimingId() {
+    return "SplitBranchOnKnownBoolean";
+  }
+
+  @Override
+  boolean shouldRewriteCode(ProgramMethod method, IRCode code) {
+    return true;
+  }
+
+  /**
+   * Simplify Boolean branches for example: <code>
+   * boolean b = i == j; if (b) { ... } else { ... }
+   * </code> ends up first creating a branch for the boolean b, then a second branch on b. D8/R8
+   * rewrites to: <code>
+   * if (i == j) { ... } else { ... }
+   * </code> More complex control flow are also supported to some extent, including cases where the
+   * input of the second branch comes from a set of dependent phis, and a subset of the inputs are
+   * known boolean values.
+   */
+  @Override
+  void rewriteCode(ProgramMethod method, IRCode code) {
+    List<BasicBlock> candidates = computeCandidates(code);
+    if (candidates.isEmpty()) {
+      return;
+    }
+    Map<Goto, BasicBlock> newTargets = findGotosToRetarget(candidates);
+    if (newTargets.isEmpty()) {
+      return;
+    }
+    retargetGotos(newTargets);
+    Set<Value> affectedValues = Sets.newIdentityHashSet();
+    affectedValues.addAll(code.removeUnreachableBlocks());
+    code.removeAllDeadAndTrivialPhis(affectedValues);
+    if (!affectedValues.isEmpty()) {
+      new TypeAnalysis(appView).narrowing(affectedValues);
+    }
+    if (ALLOW_PARTIAL_REWRITE) {
+      code.splitCriticalEdges();
+    }
+    assert code.isConsistentSSA(appView);
+  }
+
+  private void retargetGotos(Map<Goto, BasicBlock> newTargets) {
+    newTargets.forEach(
+        (goTo, newTarget) -> {
+          BasicBlock initialTarget = goTo.getTarget();
+          for (Phi phi : initialTarget.getPhis()) {
+            int index = initialTarget.getPredecessors().indexOf(goTo.getBlock());
+            phi.removeOperand(index);
+          }
+          goTo.setTarget(newTarget);
+        });
+  }
+
+  private Map<Goto, BasicBlock> findGotosToRetarget(List<BasicBlock> candidates) {
+    Map<Goto, BasicBlock> newTargets = new LinkedHashMap<>();
+    for (BasicBlock block : candidates) {
+      // We need to verify any instruction in between the if and the chain of phis is empty (we
+      // could duplicate instruction, but the common case is empty).
+      // Then we can redirect any known value. This can lead to dead code.
+      If theIf = block.exit().asIf();
+      Set<Phi> allowedPhis = getAllowedPhis(theIf.lhs().asPhi());
+      Set<Phi> foundPhis = Sets.newIdentityHashSet();
+      WorkList.newIdentityWorkList(block)
+          .process(
+              (current, workList) -> {
+                if (current.getInstructions().size() > 1) {
+                  return;
+                }
+                if (current != block && !current.exit().isGoto()) {
+                  return;
+                }
+                if (allowedPhis.containsAll(current.getPhis())) {
+                  foundPhis.addAll(current.getPhis());
+                } else {
+                  return;
+                }
+                workList.addIfNotSeen(current.getPredecessors());
+              });
+      if (!ALLOW_PARTIAL_REWRITE) {
+        for (Phi phi : foundPhis) {
+          for (Value value : phi.getOperands()) {
+            if (!value.isConstant() && !(value.isPhi() && foundPhis.contains(value.asPhi()))) {
+              return newTargets;
+            }
+          }
+        }
+      }
+      for (Phi phi : foundPhis) {
+        BasicBlock phiBlock = phi.getBlock();
+        for (int i = 0; i < phi.getOperands().size(); i++) {
+          Value value = phi.getOperand(i);
+          if (value.isConstant()) {
+            recordNewTargetForGoto(value, phiBlock.getPredecessors().get(i), theIf, newTargets);
+          }
+        }
+      }
+    }
+    return newTargets;
+  }
+
+  private List<BasicBlock> computeCandidates(IRCode code) {
+    List<BasicBlock> candidates = new ArrayList<>();
+    for (BasicBlock block : ListUtils.filter(code.blocks, block -> block.entry().isIf())) {
+      If theIf = block.exit().asIf();
+      if (theIf.isZeroTest()
+          && theIf.lhs().getType().isInt()
+          && theIf.lhs().isPhi()
+          && theIf.lhs().hasSingleUniqueUser()
+          && !theIf.lhs().hasPhiUsers()) {
+        candidates.add(block);
+      }
+    }
+    return candidates;
+  }
+
+  private void recordNewTargetForGoto(
+      Value value, BasicBlock basicBlock, If theIf, Map<Goto, BasicBlock> newTargets) {
+    // The GoTo at the end of basicBlock should target the phiBlock, and should target instead
+    // the correct if destination.
+    assert basicBlock.exit().isGoto();
+    assert value.isConstant();
+    assert value.getType().isInt();
+    assert theIf.isZeroTest();
+    BasicBlock newTarget = theIf.targetFromCondition(value.getConstInstruction().asConstNumber());
+    Goto aGoto = basicBlock.exit().asGoto();
+    newTargets.put(aGoto, newTarget);
+  }
+
+  private Set<Phi> getAllowedPhis(Phi initialPhi) {
+    WorkList<Phi> workList = WorkList.newIdentityWorkList(initialPhi);
+    while (workList.hasNext()) {
+      Phi phi = workList.next();
+      for (Value operand : phi.getOperands()) {
+        if (operand.isPhi()
+            && (operand.uniqueUsers().isEmpty() || phi == initialPhi)
+            && workList.getSeenSet().containsAll(operand.uniquePhiUsers())) {
+          workList.addIfNotSeen(operand.asPhi());
+        }
+      }
+    }
+    return workList.getSeenSet();
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index 692eeee..d8e62cb 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -499,6 +499,7 @@
 
   // TODO(sgjesse); Move this somewhere else, and reuse it for some of the other switch rewritings.
   public abstract static class InstructionBuilder<T> {
+
     protected int blockNumber;
     protected final Position position;
 
@@ -515,6 +516,7 @@
   }
 
   public static class SwitchBuilder extends InstructionBuilder<SwitchBuilder> {
+
     private Value value;
     private final Int2ReferenceSortedMap<BasicBlock> keyToTarget = new Int2ReferenceAVLTreeMap<>();
     private BasicBlock fallthrough;
@@ -530,7 +532,7 @@
 
     public SwitchBuilder setValue(Value value) {
       this.value = value;
-      return  this;
+      return this;
     }
 
     public SwitchBuilder addKeyAndTarget(int key, BasicBlock target) {
@@ -575,6 +577,7 @@
   }
 
   public static class IfBuilder extends InstructionBuilder<IfBuilder> {
+
     private final IRCode code;
     private Value left;
     private int right;
@@ -593,12 +596,12 @@
 
     public IfBuilder setLeft(Value left) {
       this.left = left;
-      return  this;
+      return this;
     }
 
     public IfBuilder setRight(int right) {
       this.right = right;
-      return  this;
+      return this;
     }
 
     public IfBuilder setTarget(BasicBlock target) {
@@ -2242,6 +2245,7 @@
   }
 
   private static class FilledArrayCandidate {
+
     final NewArrayEmpty newArrayEmpty;
     final int size;
     final boolean encodeAsFilledNewArray;
@@ -2707,6 +2711,7 @@
   }
 
   static class ControlFlowSimplificationResult {
+
     private boolean anyAffectedValues;
     private boolean anySimplifications;
 
diff --git a/src/test/java/com/android/tools/r8/internal/opensourceapps/TiviTest.java b/src/test/java/com/android/tools/r8/internal/opensourceapps/TiviTest.java
index e220e0a..f0a5fd9 100644
--- a/src/test/java/com/android/tools/r8/internal/opensourceapps/TiviTest.java
+++ b/src/test/java/com/android/tools/r8/internal/opensourceapps/TiviTest.java
@@ -4,15 +4,13 @@
 
 package com.android.tools.r8.internal.opensourceapps;
 
-import static org.junit.Assume.assumeTrue;
-
 import com.android.tools.r8.LibraryDesugaringTestConfiguration;
 import com.android.tools.r8.R8TestBuilder;
+import com.android.tools.r8.R8TestCompileResult;
 import com.android.tools.r8.StringResource;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
-import com.android.tools.r8.ToolHelper;
 import com.android.tools.r8.utils.AndroidApiLevel;
 import com.android.tools.r8.utils.ZipUtils;
 import java.io.IOException;
@@ -40,17 +38,19 @@
 
   @BeforeClass
   public static void setup() throws IOException {
-    assumeTrue(ToolHelper.isLocalDevelopment());
+    // assumeTrue(ToolHelper.isLocalDevelopment());
     outDirectory = getStaticTemp().newFolder().toPath();
     ZipUtils.unzip(Paths.get("third_party/opensource-apps/tivi/dump_app.zip"), outDirectory);
   }
 
   @Test
   public void testR8() throws Exception {
-    testForR8(Backend.DEX)
-        .addProgramFiles(outDirectory.resolve("program.jar"))
-        .apply(this::configure)
-        .compile();
+    R8TestCompileResult compile =
+        testForR8(Backend.DEX)
+            .addProgramFiles(outDirectory.resolve("program.jar"))
+            .apply(this::configure)
+            .compile();
+    System.out.println(compile.app.applicationSize());
   }
 
   @Test
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/ifs/DoubleDiamondTest.java b/src/test/java/com/android/tools/r8/ir/optimize/ifs/DoubleDiamondTest.java
new file mode 100644
index 0000000..cf928e6
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/ifs/DoubleDiamondTest.java
@@ -0,0 +1,188 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.ifs;
+
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.AlwaysInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.FoundMethodSubject;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class DoubleDiamondTest extends TestBase {
+
+  private final TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  public DoubleDiamondTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .enableInliningAnnotations()
+        .enableAlwaysInliningAnnotations()
+        .setMinApi(parameters)
+        .compile()
+        .inspect(this::inspect)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines(
+            "5", "1", "1", "5", "1", "5", "5", "1", "5", "5", "1", "1", "1", "5", "5", "5", "1",
+            "1", "1", "5");
+  }
+
+  private void inspect(CodeInspector inspector) {
+    for (FoundMethodSubject method : inspector.clazz(Main.class).allMethods()) {
+      if (!method.getOriginalName().equals("main")) {
+        long count = method.streamInstructions().filter(InstructionSubject::isIf).count();
+        assertEquals(method.getOriginalName().contains("Double") ? 2 : 1, count);
+      }
+    }
+  }
+
+  public static class Main {
+
+    public static void main(String[] args) {
+      System.out.println(indirectEquals(2, 6));
+      System.out.println(indirectEquals(3, 3));
+
+      System.out.println(indirectEqualsNegated(2, 6));
+      System.out.println(indirectEqualsNegated(3, 3));
+
+      System.out.println(indirectLessThan(2, 6));
+      System.out.println(indirectLessThan(7, 3));
+
+      System.out.println(indirectLessThanNegated(2, 6));
+      System.out.println(indirectLessThanNegated(7, 3));
+
+      System.out.println(indirectDoubleEquals(2, 6, 6));
+      System.out.println(indirectDoubleEquals(7, 7, 3));
+      System.out.println(indirectDoubleEquals(1, 1, 1));
+
+      System.out.println(indirectDoubleEqualsNegated(2, 6, 6));
+      System.out.println(indirectDoubleEqualsNegated(2, 2, 6));
+      System.out.println(indirectDoubleEqualsNegated(7, 7, 7));
+
+      System.out.println(indirectDoubleEqualsSplit(2, 6, 6));
+      System.out.println(indirectDoubleEqualsSplit(7, 7, 3));
+      System.out.println(indirectDoubleEqualsSplit(1, 1, 1));
+
+      System.out.println(indirectDoubleEqualsSplitNegated(2, 6, 6));
+      System.out.println(indirectDoubleEqualsSplitNegated(2, 2, 6));
+      System.out.println(indirectDoubleEqualsSplitNegated(7, 7, 7));
+    }
+
+    @AlwaysInline
+    public static boolean doubleEqualsSplit(int i, int j, int k) {
+      if (i != j) {
+        return false;
+      }
+      return j == k;
+    }
+
+    @NeverInline
+    public static int indirectDoubleEqualsSplit(int i, int j, int k) {
+      if (doubleEqualsSplit(i, j, k)) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+
+    @NeverInline
+    public static int indirectDoubleEqualsSplitNegated(int i, int j, int k) {
+      if (!doubleEqualsSplit(i, j, k)) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+
+    @AlwaysInline
+    public static boolean doubleEquals(int i, int j, int k) {
+      return i == j && j == k;
+    }
+
+    @NeverInline
+    public static int indirectDoubleEquals(int i, int j, int k) {
+      if (doubleEquals(i, j, k)) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+
+    @NeverInline
+    public static int indirectDoubleEqualsNegated(int i, int j, int k) {
+      if (!doubleEquals(i, j, k)) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+
+    @AlwaysInline
+    public static boolean equals(int i, int j) {
+      return i == j;
+    }
+
+    @NeverInline
+    public static int indirectEquals(int i, int j) {
+      if (equals(i, j)) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+
+    @NeverInline
+    public static int indirectEqualsNegated(int i, int j) {
+      if (!equals(i, j)) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+
+    @AlwaysInline
+    public static boolean lessThan(int i, int j) {
+      return i <= j;
+    }
+
+    @NeverInline
+    public static int indirectLessThan(int i, int j) {
+      if (lessThan(i, j)) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+
+    @NeverInline
+    public static int indirectLessThanNegated(int i, int j) {
+      if (!lessThan(i, j)) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+  }
+}