Extend split branch

- Support comparison against constants
- Support float comparison

Change-Id: I568ba24a0bbd66f95f1486b52355a51e973bd4e1
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 8f9378e..ebbecb0 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -30,7 +30,7 @@
 import com.android.tools.r8.ir.conversion.passes.BinopRewriter;
 import com.android.tools.r8.ir.conversion.passes.CommonSubexpressionElimination;
 import com.android.tools.r8.ir.conversion.passes.ParentConstructorHoistingCodeRewriter;
-import com.android.tools.r8.ir.conversion.passes.SplitBranchOnKnownBoolean;
+import com.android.tools.r8.ir.conversion.passes.SplitBranch;
 import com.android.tools.r8.ir.desugar.CfInstructionDesugaringCollection;
 import com.android.tools.r8.ir.desugar.CovariantReturnTypeAnnotationTransformer;
 import com.android.tools.r8.ir.optimize.AssertionErrorTwoArgsConstructorRewriter;
@@ -115,7 +115,7 @@
   protected final InternalOptions options;
   public final CodeRewriter codeRewriter;
   public final CommonSubexpressionElimination commonSubexpressionElimination;
-  private final SplitBranchOnKnownBoolean splitBranchOnKnownBoolean;
+  private final SplitBranch splitBranch;
   public final AssertionErrorTwoArgsConstructorRewriter assertionErrorTwoArgsConstructorRewriter;
   private final NaturalIntLoopRemover naturalIntLoopRemover = new NaturalIntLoopRemover();
   public final MemberValuePropagation<?> memberValuePropagation;
@@ -167,7 +167,7 @@
     this.options = appView.options();
     this.codeRewriter = new CodeRewriter(appView);
     this.commonSubexpressionElimination = new CommonSubexpressionElimination(appView);
-    this.splitBranchOnKnownBoolean = new SplitBranchOnKnownBoolean(appView);
+    this.splitBranch = new SplitBranch(appView);
     this.assertionErrorTwoArgsConstructorRewriter =
         appView.options().desugarState.isOn()
             ? new AssertionErrorTwoArgsConstructorRewriter(appView)
@@ -777,7 +777,7 @@
       timing.end();
     }
     timing.end();
-    splitBranchOnKnownBoolean.run(code.context(), code, timing);
+    splitBranch.run(code.context(), code, timing);
     if (options.enableRedundantConstNumberOptimization) {
       timing.begin("Remove const numbers");
       codeRewriter.redundantConstNumberRemoval(code);
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranchOnKnownBoolean.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranch.java
similarity index 68%
rename from src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranchOnKnownBoolean.java
rename to src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranch.java
index 134d2bc..9a3a024 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranchOnKnownBoolean.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranch.java
@@ -9,6 +9,7 @@
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.ir.analysis.type.TypeAnalysis;
 import com.android.tools.r8.ir.code.BasicBlock;
+import com.android.tools.r8.ir.code.ConstNumber;
 import com.android.tools.r8.ir.code.Goto;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.If;
@@ -23,17 +24,17 @@
 import java.util.Map;
 import java.util.Set;
 
-public class SplitBranchOnKnownBoolean extends CodeRewriterPass<AppInfo> {
+public class SplitBranch extends CodeRewriterPass<AppInfo> {
 
   private static final boolean ALLOW_PARTIAL_REWRITE = true;
 
-  public SplitBranchOnKnownBoolean(AppView<?> appView) {
+  public SplitBranch(AppView<?> appView) {
     super(appView);
   }
 
   @Override
   String getTimingId() {
-    return "SplitBranchOnKnownBoolean";
+    return "SplitBranch";
   }
 
   @Override
@@ -89,17 +90,31 @@
   private Map<Goto, BasicBlock> findGotosToRetarget(List<BasicBlock> candidates) {
     Map<Goto, BasicBlock> newTargets = new LinkedHashMap<>();
     for (BasicBlock block : candidates) {
-      // We need to verify any instruction in between the if and the chain of phis is empty (we
-      // could duplicate instruction, but the common case is empty).
+      // We need to verify any instruction in between the if and the chain of phis is empty or just
+      // a constant used in the If instruction (we could duplicate instruction, but the common case
+      // is empty).
       // Then we can redirect any known value. This can lead to dead code.
       If theIf = block.exit().asIf();
-      Set<Phi> allowedPhis = getAllowedPhis(theIf.lhs().asPhi());
+      Set<Phi> allowedPhis = getAllowedPhis(nonConstNumberOperand(theIf).asPhi());
       Set<Phi> foundPhis = Sets.newIdentityHashSet();
       WorkList.newIdentityWorkList(block)
           .process(
               (current, workList) -> {
                 if (current.getInstructions().size() > 1) {
-                  return;
+                  // We allow a single instruction, which is the constant used exclusively in the
+                  // if. This is run before constant canonicalization.
+                  if (theIf.isZeroTest()
+                      || current.getInstructions().size() != 2
+                      || !current.entry().isConstNumber()) {
+                    return;
+                  }
+                  Value value = current.entry().outValue();
+                  if (value.hasPhiUsers()
+                      || value.uniqueUsers().size() > 1
+                      || (value.uniqueUsers().size() == 1
+                          && value.uniqueUsers().iterator().next() != theIf)) {
+                    return;
+                  }
                 }
                 if (current != block && !current.exit().isGoto()) {
                   return;
@@ -135,30 +150,61 @@
     return newTargets;
   }
 
+  private boolean isNumberAgainstConstNumberIf(If theIf) {
+    if (!(theIf.lhs().getType().isInt() || theIf.lhs().getType().isFloat())) {
+      return false;
+    }
+    if (theIf.isZeroTest()) {
+      return true;
+    }
+    assert theIf.lhs().getType() == theIf.rhs().getType();
+    return theIf.lhs().isConstNumber() || theIf.rhs().isConstNumber();
+  }
+
+  private Value nonConstNumberOperand(If theIf) {
+    return theIf.isZeroTest()
+        ? theIf.lhs()
+        : (theIf.lhs().isConstNumber() ? theIf.rhs() : theIf.lhs());
+  }
+
   private List<BasicBlock> computeCandidates(IRCode code) {
     List<BasicBlock> candidates = new ArrayList<>();
-    for (BasicBlock block : ListUtils.filter(code.blocks, block -> block.entry().isIf())) {
+    for (BasicBlock block : ListUtils.filter(code.blocks, block -> block.exit().isIf())) {
       If theIf = block.exit().asIf();
-      if (theIf.isZeroTest()
-          && theIf.lhs().getType().isInt()
-          && theIf.lhs().isPhi()
-          && theIf.lhs().hasSingleUniqueUser()
-          && !theIf.lhs().hasPhiUsers()) {
+      if (!isNumberAgainstConstNumberIf(theIf)) {
+        continue;
+      }
+      Value nonConstNumberOperand = nonConstNumberOperand(theIf);
+      if (isNumberAgainstConstNumberIf(theIf)
+          && nonConstNumberOperand.isPhi()
+          && nonConstNumberOperand.hasSingleUniqueUser()
+          && !nonConstNumberOperand.hasPhiUsers()) {
         candidates.add(block);
       }
     }
     return candidates;
   }
 
+  private BasicBlock targetFromCondition(If theIf, ConstNumber constForPhi) {
+    if (theIf.isZeroTest()) {
+      return theIf.targetFromCondition(constForPhi);
+    }
+    if (theIf.lhs().isConstNumber()) {
+      return theIf.targetFromCondition(
+          theIf.lhs().getConstInstruction().asConstNumber(), constForPhi);
+    }
+    assert theIf.rhs().isConstNumber();
+    return theIf.targetFromCondition(
+        constForPhi, theIf.rhs().getConstInstruction().asConstNumber());
+  }
+
   private void recordNewTargetForGoto(
       Value value, BasicBlock basicBlock, If theIf, Map<Goto, BasicBlock> newTargets) {
     // The GoTo at the end of basicBlock should target the phiBlock, and should target instead
     // the correct if destination.
     assert basicBlock.exit().isGoto();
     assert value.isConstant();
-    assert value.getType().isInt();
-    assert theIf.isZeroTest();
-    BasicBlock newTarget = theIf.targetFromCondition(value.getConstInstruction().asConstNumber());
+    BasicBlock newTarget = targetFromCondition(theIf, value.getConstInstruction().asConstNumber());
     Goto aGoto = basicBlock.exit().asGoto();
     newTargets.put(aGoto, newTarget);
   }
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/ifs/DoubleDiamondCstTest.java b/src/test/java/com/android/tools/r8/ir/optimize/ifs/DoubleDiamondCstTest.java
new file mode 100644
index 0000000..7c5ed88
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/ifs/DoubleDiamondCstTest.java
@@ -0,0 +1,159 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.ifs;
+
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.AlwaysInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.FoundMethodSubject;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class DoubleDiamondCstTest extends TestBase {
+
+  private final TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  public DoubleDiamondCstTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .enableInliningAnnotations()
+        .enableAlwaysInliningAnnotations()
+        .setMinApi(parameters)
+        .compile()
+        .inspect(this::inspect)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines(
+            "1", "5", "5", "1", "1", "5", "5", "1", "5", "5", "1", "1", "1", "5");
+  }
+
+  private void inspect(CodeInspector inspector) {
+    for (FoundMethodSubject method : inspector.clazz(Main.class).allMethods()) {
+      if (!method.getOriginalName().equals("main")) {
+        long count = method.streamInstructions().filter(InstructionSubject::isIf).count();
+        assertEquals(method.getOriginalName().contains("Double") ? 2 : 1, count);
+      }
+    }
+  }
+
+  public static class Main {
+
+    public static void main(String[] args) {
+      System.out.println(indirectTest(2, 6));
+      System.out.println(indirectTest(3, 3));
+
+      System.out.println(indirectTestNegated(2, 6));
+      System.out.println(indirectTestNegated(3, 3));
+
+      System.out.println(indirectCmp(2, 6));
+      System.out.println(indirectCmp(7, 3));
+
+      System.out.println(indirectCmpNegated(2, 6));
+      System.out.println(indirectCmpNegated(7, 3));
+
+      System.out.println(indirectDoubleTest(2, 6, 6));
+      System.out.println(indirectDoubleTest(7, 7, 3));
+      System.out.println(indirectDoubleTest(1, 1, 1));
+
+      System.out.println(indirectDoubleTestNegated(2, 6, 6));
+      System.out.println(indirectDoubleTestNegated(2, 2, 6));
+      System.out.println(indirectDoubleTestNegated(7, 7, 7));
+    }
+
+    @AlwaysInline
+    public static int doubleTest(int i, int j, int k) {
+      if (i != j) {
+        return 1;
+      }
+      if (j == k) {
+        return 2;
+      }
+      return 3;
+    }
+
+    @NeverInline
+    public static int indirectDoubleTest(int i, int j, int k) {
+      if (doubleTest(i, j, k) == 2) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+
+    @NeverInline
+    public static int indirectDoubleTestNegated(int i, int j, int k) {
+      if (doubleTest(i, j, k) != 2) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+
+    @AlwaysInline
+    public static int test(int i, int j) {
+      return i == j ? 1 : 2;
+    }
+
+    @NeverInline
+    public static int indirectTest(int i, int j) {
+      if (test(i, j) == 2) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+
+    @NeverInline
+    public static int indirectTestNegated(int i, int j) {
+      if (test(i, j) != 2) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+
+    @AlwaysInline
+    public static int cmp(int i, int j) {
+      return i <= j ? 1 : 2;
+    }
+
+    @NeverInline
+    public static int indirectCmp(int i, int j) {
+      if (cmp(i, j) < 2) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+
+    @NeverInline
+    public static int indirectCmpNegated(int i, int j) {
+      if (cmp(i, j) > 1) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/ifs/DoubleDiamondFloatTest.java b/src/test/java/com/android/tools/r8/ir/optimize/ifs/DoubleDiamondFloatTest.java
new file mode 100644
index 0000000..b433e5b
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/ifs/DoubleDiamondFloatTest.java
@@ -0,0 +1,158 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.ifs;
+
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.AlwaysInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.FoundMethodSubject;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class DoubleDiamondFloatTest extends TestBase {
+
+  private final TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  public DoubleDiamondFloatTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .enableInliningAnnotations()
+        .enableAlwaysInliningAnnotations()
+        .setMinApi(parameters)
+        .compile()
+        .inspect(this::inspect)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines(
+            "5", "1", "1", "5", "5", "5", "1", "1", "1", "5", "5", "5", "1", "1", "1", "5");
+  }
+
+  private void inspect(CodeInspector inspector) {
+    for (FoundMethodSubject method : inspector.clazz(Main.class).allMethods()) {
+      if (!method.getOriginalName().equals("main")) {
+        long count = method.streamInstructions().filter(InstructionSubject::isIf).count();
+        assertEquals(method.getOriginalName().contains("Double") ? 2 : 1, count);
+      }
+    }
+  }
+
+  public static class Main {
+
+    public static void main(String[] args) {
+      System.out.println(indirectEquals(2.0f, 6.0f));
+      System.out.println(indirectEquals(3.0f, 3.0f));
+
+      System.out.println(indirectEqualsNegated(2.0f, 6.0f));
+      System.out.println(indirectEqualsNegated(3.0f, 3.0f));
+
+      System.out.println(indirectDoubleEquals(2.0f, 6.0f, 6.0f));
+      System.out.println(indirectDoubleEquals(7.0f, 7.0f, 3.0f));
+      System.out.println(indirectDoubleEquals(1.0f, 1.0f, 1.0f));
+
+      System.out.println(indirectDoubleEqualsNegated(2.0f, 6.0f, 6.0f));
+      System.out.println(indirectDoubleEqualsNegated(2.0f, 2.0f, 6.0f));
+      System.out.println(indirectDoubleEqualsNegated(7.0f, 7.0f, 7.0f));
+
+      System.out.println(indirectDoubleEqualsSplit(2.0f, 6.0f, 6.0f));
+      System.out.println(indirectDoubleEqualsSplit(7.0f, 7.0f, 3.0f));
+      System.out.println(indirectDoubleEqualsSplit(1.0f, 1.0f, 1.0f));
+
+      System.out.println(indirectDoubleEqualsSplitNegated(2.0f, 6.0f, 6.0f));
+      System.out.println(indirectDoubleEqualsSplitNegated(2.0f, 2.0f, 6.0f));
+      System.out.println(indirectDoubleEqualsSplitNegated(7.0f, 7.0f, 7.0f));
+    }
+
+    @AlwaysInline
+    public static boolean doubleEqualsSplit(float i, float j, float k) {
+      if (i != j) {
+        return false;
+      }
+      return j == k;
+    }
+
+    @NeverInline
+    public static int indirectDoubleEqualsSplit(float i, float j, float k) {
+      if (doubleEqualsSplit(i, j, k)) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+
+    @NeverInline
+    public static int indirectDoubleEqualsSplitNegated(float i, float j, float k) {
+      if (!doubleEqualsSplit(i, j, k)) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+
+    @AlwaysInline
+    public static boolean doubleEquals(float i, float j, float k) {
+      return i == j && j == k;
+    }
+
+    @NeverInline
+    public static int indirectDoubleEquals(float i, float j, float k) {
+      if (doubleEquals(i, j, k)) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+
+    @NeverInline
+    public static int indirectDoubleEqualsNegated(float i, float j, float k) {
+      if (!doubleEquals(i, j, k)) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+
+    @AlwaysInline
+    public static boolean equals(float i, float j) {
+      return i == j;
+    }
+
+    @NeverInline
+    public static int indirectEquals(float i, float j) {
+      if (equals(i, j)) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+
+    @NeverInline
+    public static int indirectEqualsNegated(float i, float j) {
+      if (!equals(i, j)) {
+        return 1;
+      } else {
+        return 5;
+      }
+    }
+  }
+}