Extend if-simplication to trivial object comparisons

Change-Id: I336435a937ab47bade74b74a8dd456a5e9a72cc5
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index d765d76..6aa39c0 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -2482,141 +2482,101 @@
 
         // Simplify if conditions when possible.
         If theIf = block.exit().asIf();
-        Value lhs = theIf.lhs();
-        Value rhs = theIf.isZeroTest() ? null : theIf.rhs();
 
-        if (lhs.isConstNumber() && (theIf.isZeroTest() || rhs.isConstNumber())) {
+        if (theIf.isZeroTest()) {
+          simplifyIfZeroTest(code, block, theIf);
+          continue;
+        }
+
+        Value lhs = theIf.lhs();
+        Value lhsRoot = lhs.getAliasedValue();
+        Value rhs = theIf.rhs();
+        Value rhsRoot = rhs.getAliasedValue();
+
+        if (lhsRoot == rhsRoot) {
+          // Comparing the same value.
+          simplifyIfWithKnownCondition(code, block, theIf, theIf.targetFromCondition(0));
+        } else if (lhsRoot.isDefinedByInstructionSatisfying(Instruction::isCreatingInstanceOrArray)
+            && rhsRoot.isDefinedByInstructionSatisfying(Instruction::isCreatingInstanceOrArray)) {
+          // Comparing two newly created objects.
+          assert theIf.getType() == Type.EQ || theIf.getType() == Type.NE;
+          simplifyIfWithKnownCondition(code, block, theIf, theIf.targetFromCondition(1));
+        } else if (lhs.isConstNumber() && rhs.isConstNumber()) {
           // Zero test with a constant of comparison between between two constants.
-          if (theIf.isZeroTest()) {
-            ConstNumber cond = lhs.getConstInstruction().asConstNumber();
-            BasicBlock target = theIf.targetFromCondition(cond);
-            simplifyIfWithKnownCondition(code, block, theIf, target);
-          } else {
-            ConstNumber left = lhs.getConstInstruction().asConstNumber();
-            ConstNumber right = rhs.getConstInstruction().asConstNumber();
-            BasicBlock target = theIf.targetFromCondition(left, right);
-            simplifyIfWithKnownCondition(code, block, theIf, target);
-          }
-        } else if (lhs.hasValueRange() && (theIf.isZeroTest() || rhs.hasValueRange())) {
+          ConstNumber left = lhs.getConstInstruction().asConstNumber();
+          ConstNumber right = rhs.getConstInstruction().asConstNumber();
+          BasicBlock target = theIf.targetFromCondition(left, right);
+          simplifyIfWithKnownCondition(code, block, theIf, target);
+        } else if (lhs.hasValueRange() && rhs.hasValueRange()) {
           // Zero test with a value range, or comparison between between two values,
           // each with a value ranges.
-          if (theIf.isZeroTest()) {
-            LongInterval interval = lhs.getValueRange();
-            if (!interval.containsValue(0)) {
-              // Interval doesn't contain zero at all.
-              int sign = Long.signum(interval.getMin());
-              simplifyIfWithKnownCondition(code, block, theIf, sign);
-            } else {
-              // Interval contains zero.
-              switch (theIf.getType()) {
-                case GE:
-                case LT:
-                  // [a, b] >= 0 is always true if a >= 0.
-                  // [a, b] < 0 is always false if a >= 0.
-                  // In both cases a zero condition takes the right branch.
-                  if (interval.getMin() == 0) {
-                    simplifyIfWithKnownCondition(code, block, theIf, 0);
-                  }
-                  break;
-                case LE:
-                case GT:
-                  // [a, b] <= 0 is always true if b <= 0.
-                  // [a, b] > 0 is always false if b <= 0.
-                  if (interval.getMax() == 0) {
-                    simplifyIfWithKnownCondition(code, block, theIf, 0);
-                  }
-                  break;
-                case EQ:
-                case NE:
-                  // Only a single element interval [0, 0] can be dealt with here.
-                  // Such intervals should have been replaced by constants.
-                  assert !interval.isSingleValue();
-                  break;
-              }
-            }
+          LongInterval leftRange = lhs.getValueRange();
+          LongInterval rightRange = rhs.getValueRange();
+          // Two overlapping ranges. Check for single point overlap.
+          if (!leftRange.overlapsWith(rightRange)) {
+            // No overlap.
+            int cond = Long.signum(leftRange.getMin() - rightRange.getMin());
+            simplifyIfWithKnownCondition(code, block, theIf, cond);
           } else {
-            LongInterval leftRange = lhs.getValueRange();
-            LongInterval rightRange = rhs.getValueRange();
-            // Two overlapping ranges. Check for single point overlap.
-            if (!leftRange.overlapsWith(rightRange)) {
-              // No overlap.
-              int cond = Long.signum(leftRange.getMin() - rightRange.getMin());
-              simplifyIfWithKnownCondition(code, block, theIf, cond);
-            } else {
-              // The two intervals overlap. We can simplify if they overlap at the end points.
-              switch (theIf.getType()) {
-                case LT:
-                case GE:
-                  // [a, b] < [c, d] is always false when a == d.
-                  // [a, b] >= [c, d] is always true when a == d.
-                  // In both cases 0 condition will choose the right branch.
-                  if (leftRange.getMin() == rightRange.getMax()) {
-                    simplifyIfWithKnownCondition(code, block, theIf, 0);
-                  }
-                  break;
-                case GT:
-                case LE:
-                  // [a, b] > [c, d] is always false when b == c.
-                  // [a, b] <= [c, d] is always true when b == c.
-                  // In both cases 0 condition will choose the right branch.
-                  if (leftRange.getMax() == rightRange.getMin()) {
-                    simplifyIfWithKnownCondition(code, block, theIf, 0);
-                  }
-                  break;
-                case EQ:
-                case NE:
-                  // Since there is overlap EQ and NE cannot be determined.
-                  break;
-              }
+            // The two intervals overlap. We can simplify if they overlap at the end points.
+            switch (theIf.getType()) {
+              case LT:
+              case GE:
+                // [a, b] < [c, d] is always false when a == d.
+                // [a, b] >= [c, d] is always true when a == d.
+                // In both cases 0 condition will choose the right branch.
+                if (leftRange.getMin() == rightRange.getMax()) {
+                  simplifyIfWithKnownCondition(code, block, theIf, 0);
+                }
+                break;
+              case GT:
+              case LE:
+                // [a, b] > [c, d] is always false when b == c.
+                // [a, b] <= [c, d] is always true when b == c.
+                // In both cases 0 condition will choose the right branch.
+                if (leftRange.getMax() == rightRange.getMin()) {
+                  simplifyIfWithKnownCondition(code, block, theIf, 0);
+                }
+                break;
+              case EQ:
+              case NE:
+                // Since there is overlap EQ and NE cannot be determined.
+                break;
             }
           }
         } else if (theIf.getType() == Type.EQ || theIf.getType() == Type.NE) {
-          if (theIf.isZeroTest()) {
-            if (!lhs.isConstNumber()) {
-              TypeElement l = lhs.getType();
-              if (l.isReferenceType() && lhs.isNeverNull()) {
-                simplifyIfWithKnownCondition(code, block, theIf, 1);
-              } else {
-                if (!l.isPrimitiveType() && !l.isNullable()) {
-                  simplifyIfWithKnownCondition(code, block, theIf, 1);
-                }
-              }
+          ProgramMethod context = code.context();
+          AbstractValue abstractValue = lhs.getAbstractValue(appView, context);
+          if (abstractValue.isSingleConstClassValue()) {
+            AbstractValue otherAbstractValue = rhs.getAbstractValue(appView, context);
+            if (otherAbstractValue.isSingleConstClassValue()) {
+              SingleConstClassValue singleConstClassValue = abstractValue.asSingleConstClassValue();
+              SingleConstClassValue otherSingleConstClassValue =
+                  otherAbstractValue.asSingleConstClassValue();
+              simplifyIfWithKnownCondition(
+                  code,
+                  block,
+                  theIf,
+                  BooleanUtils.intValue(
+                      singleConstClassValue.getType() != otherSingleConstClassValue.getType()));
             }
-          } else {
-            ProgramMethod context = code.context();
-            AbstractValue abstractValue = lhs.getAbstractValue(appView, context);
-            if (abstractValue.isSingleConstClassValue()) {
-              AbstractValue otherAbstractValue = rhs.getAbstractValue(appView, context);
-              if (otherAbstractValue.isSingleConstClassValue()) {
-                SingleConstClassValue singleConstClassValue =
-                    abstractValue.asSingleConstClassValue();
-                SingleConstClassValue otherSingleConstClassValue =
-                    otherAbstractValue.asSingleConstClassValue();
-                simplifyIfWithKnownCondition(
-                    code,
-                    block,
-                    theIf,
-                    BooleanUtils.intValue(
-                        singleConstClassValue.getType() != otherSingleConstClassValue.getType()));
-              }
-            } else if (abstractValue.isSingleFieldValue()) {
-              AbstractValue otherAbstractValue = rhs.getAbstractValue(appView, context);
-              if (otherAbstractValue.isSingleFieldValue()) {
-                SingleFieldValue singleFieldValue = abstractValue.asSingleFieldValue();
-                SingleFieldValue otherSingleFieldValue = otherAbstractValue.asSingleFieldValue();
-                if (singleFieldValue.getField() == otherSingleFieldValue.getField()) {
-                  simplifyIfWithKnownCondition(code, block, theIf, 0);
-                } else {
-                  DexClass holder = appView.definitionForHolder(singleFieldValue.getField());
-                  DexEncodedField field = singleFieldValue.getField().lookupOnClass(holder);
-                  if (field != null && field.isEnum()) {
-                    DexClass otherHolder =
-                        appView.definitionForHolder(otherSingleFieldValue.getField());
-                    DexEncodedField otherField =
-                        otherSingleFieldValue.getField().lookupOnClass(otherHolder);
-                    if (otherField != null && otherField.isEnum()) {
-                      simplifyIfWithKnownCondition(code, block, theIf, 1);
-                    }
+          } else if (abstractValue.isSingleFieldValue()) {
+            AbstractValue otherAbstractValue = rhs.getAbstractValue(appView, context);
+            if (otherAbstractValue.isSingleFieldValue()) {
+              SingleFieldValue singleFieldValue = abstractValue.asSingleFieldValue();
+              SingleFieldValue otherSingleFieldValue = otherAbstractValue.asSingleFieldValue();
+              if (singleFieldValue.getField() == otherSingleFieldValue.getField()) {
+                simplifyIfWithKnownCondition(code, block, theIf, 0);
+              } else {
+                DexClass holder = appView.definitionForHolder(singleFieldValue.getField());
+                DexEncodedField field = singleFieldValue.getField().lookupOnClass(holder);
+                if (field != null && field.isEnum()) {
+                  DexClass otherHolder =
+                      appView.definitionForHolder(otherSingleFieldValue.getField());
+                  DexEncodedField otherField =
+                      otherSingleFieldValue.getField().lookupOnClass(otherHolder);
+                  if (otherField != null && otherField.isEnum()) {
+                    simplifyIfWithKnownCondition(code, block, theIf, 1);
                   }
                 }
               }
@@ -2633,6 +2593,73 @@
     return !affectedValues.isEmpty();
   }
 
+  private void simplifyIfZeroTest(IRCode code, BasicBlock block, If theIf) {
+    Value lhs = theIf.lhs();
+    Value lhsRoot = lhs.getAliasedValue();
+    if (lhsRoot.isConstNumber()) {
+      ConstNumber cond = lhs.getConstInstruction().asConstNumber();
+      BasicBlock target = theIf.targetFromCondition(cond);
+      simplifyIfWithKnownCondition(code, block, theIf, target);
+      return;
+    }
+
+    if (theIf.isNullTest()) {
+      assert theIf.getType() == Type.EQ || theIf.getType() == Type.NE;
+
+      if (lhs.isAlwaysNull(appView)) {
+        simplifyIfWithKnownCondition(code, block, theIf, theIf.targetFromNullObject());
+        return;
+      }
+
+      if (lhs.isNeverNull()) {
+        simplifyIfWithKnownCondition(code, block, theIf, theIf.targetFromNonNullObject());
+        return;
+      }
+    }
+
+    if (lhs.hasValueRange()) {
+      LongInterval interval = lhs.getValueRange();
+      if (!interval.containsValue(0)) {
+        // Interval doesn't contain zero at all.
+        int sign = Long.signum(interval.getMin());
+        simplifyIfWithKnownCondition(code, block, theIf, sign);
+        return;
+      }
+
+      // Interval contains zero.
+      switch (theIf.getType()) {
+        case GE:
+        case LT:
+          // [a, b] >= 0 is always true if a >= 0.
+          // [a, b] < 0 is always false if a >= 0.
+          // In both cases a zero condition takes the right branch.
+          if (interval.getMin() == 0) {
+            simplifyIfWithKnownCondition(code, block, theIf, 0);
+            return;
+          }
+          break;
+
+        case LE:
+        case GT:
+          // [a, b] <= 0 is always true if b <= 0.
+          // [a, b] > 0 is always false if b <= 0.
+          // In both cases a zero condition takes the right branch.
+          if (interval.getMax() == 0) {
+            simplifyIfWithKnownCondition(code, block, theIf, 0);
+            return;
+          }
+          break;
+
+        case EQ:
+        case NE:
+          // Only a single element interval [0, 0] can be dealt with here.
+          // Such intervals should have been replaced by constants.
+          assert !interval.isSingleValue();
+          break;
+      }
+    }
+  }
+
   private void simplifyIfWithKnownCondition(
       IRCode code, BasicBlock block, If theIf, BasicBlock target) {
     BasicBlock deadTarget =
diff --git a/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java b/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java
index 0b9c896..7157943 100644
--- a/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java
+++ b/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java
@@ -107,15 +107,14 @@
         .withOptionConsumer(opts -> opts.enableClassInlining = false)
         .withBuilderTransformation(
             b -> b.addProguardConfiguration(PROGUARD_OPTIONS, Origin.unknown()))
-        .withDexCheck(inspector -> checkLambdaCount(inspector, 102, "lambdadesugaring"))
+        .withDexCheck(inspector -> checkLambdaCount(inspector, 101, "lambdadesugaring"))
         .run();
 
     test("lambdadesugaring", "lambdadesugaring", "LambdaDesugaring")
         .withMinApiLevel(ToolHelper.getMinApiLevelForDexVmNoHigherThan(AndroidApiLevel.K))
-        .withOptionConsumer(opts -> opts.enableClassInlining = true)
         .withBuilderTransformation(
             b -> b.addProguardConfiguration(PROGUARD_OPTIONS, Origin.unknown()))
-        .withDexCheck(inspector -> checkLambdaCount(inspector, 7, "lambdadesugaring"))
+        .withDexCheck(inspector -> checkLambdaCount(inspector, 6, "lambdadesugaring"))
         .run();
   }
 
@@ -147,15 +146,14 @@
         .withOptionConsumer(opts -> opts.enableClassInlining = false)
         .withBuilderTransformation(
             b -> b.addProguardConfiguration(PROGUARD_OPTIONS, Origin.unknown()))
-        .withDexCheck(inspector -> checkLambdaCount(inspector, 102, "lambdadesugaring"))
+        .withDexCheck(inspector -> checkLambdaCount(inspector, 101, "lambdadesugaring"))
         .run();
 
     test("lambdadesugaring", "lambdadesugaring", "LambdaDesugaring")
         .withMinApiLevel(AndroidApiLevel.N)
-        .withOptionConsumer(opts -> opts.enableClassInlining = true)
         .withBuilderTransformation(
             b -> b.addProguardConfiguration(PROGUARD_OPTIONS, Origin.unknown()))
-        .withDexCheck(inspector -> checkLambdaCount(inspector, 7, "lambdadesugaring"))
+        .withDexCheck(inspector -> checkLambdaCount(inspector, 6, "lambdadesugaring"))
         .run();
   }
 
@@ -175,7 +173,6 @@
     test("lambdadesugaringnplus", "lambdadesugaringnplus", "LambdasWithStaticAndDefaultMethods")
         .withMinApiLevel(ToolHelper.getMinApiLevelForDexVmNoHigherThan(AndroidApiLevel.K))
         .withInterfaceMethodDesugaring(OffOrAuto.Auto)
-        .withOptionConsumer(opts -> opts.enableClassInlining = true)
         .withBuilderTransformation(ToolHelper::allowTestProguardOptions)
         .withBuilderTransformation(
             b -> b.addProguardConfiguration(PROGUARD_OPTIONS_N_PLUS, Origin.unknown()))
@@ -199,7 +196,6 @@
     test("lambdadesugaringnplus", "lambdadesugaringnplus", "LambdasWithStaticAndDefaultMethods")
         .withMinApiLevel(AndroidApiLevel.N)
         .withInterfaceMethodDesugaring(OffOrAuto.Auto)
-        .withOptionConsumer(opts -> opts.enableClassInlining = true)
         .withBuilderTransformation(ToolHelper::allowTestProguardOptions)
         .withBuilderTransformation(
             b -> b.addProguardConfiguration(PROGUARD_OPTIONS_N_PLUS, Origin.unknown()))
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/ifs/TrivialObjectEqualsTest.java b/src/test/java/com/android/tools/r8/ir/optimize/ifs/TrivialObjectEqualsTest.java
new file mode 100644
index 0000000..5e5f50c
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/ifs/TrivialObjectEqualsTest.java
@@ -0,0 +1,88 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.ifs;
+
+import static com.android.tools.r8.utils.codeinspector.CodeMatchers.invokesMethodWithName;
+import static com.android.tools.r8.utils.codeinspector.Matchers.isAbsent;
+import static org.hamcrest.CoreMatchers.not;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assume.assumeTrue;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class TrivialObjectEqualsTest extends TestBase {
+
+  private final TestParameters parameters;
+
+  @Parameterized.Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  public TrivialObjectEqualsTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void testD8() throws Exception {
+    assumeTrue(parameters.isDexRuntime());
+    testForD8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .inspect(
+            inspector ->
+                assertThat(
+                    inspector.clazz(Main.class).uniqueMethodWithName("dead"),
+                    not(invokesMethodWithName("dead"))))
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("Hello world!");
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .enableInliningAnnotations()
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .inspect(
+            inspector ->
+                assertThat(inspector.clazz(Main.class).uniqueMethodWithName("dead"), isAbsent()))
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("Hello world!");
+  }
+
+  static class Main {
+
+    public static void main(String[] args) {
+      Object o1 = new Object();
+      Object o2 = new Object();
+      if (o1 == o1) {
+        System.out.print("Hello");
+      } else {
+        dead();
+      }
+      if (o1 == o2) {
+        dead();
+      } else {
+        System.out.println(" world!");
+      }
+    }
+
+    @NeverInline
+    static void dead() {
+      System.out.println("Unexpected!");
+    }
+  }
+}