Extend if-simplication to trivial object comparisons
Change-Id: I336435a937ab47bade74b74a8dd456a5e9a72cc5
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index d765d76..6aa39c0 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -2482,141 +2482,101 @@
// Simplify if conditions when possible.
If theIf = block.exit().asIf();
- Value lhs = theIf.lhs();
- Value rhs = theIf.isZeroTest() ? null : theIf.rhs();
- if (lhs.isConstNumber() && (theIf.isZeroTest() || rhs.isConstNumber())) {
+ if (theIf.isZeroTest()) {
+ simplifyIfZeroTest(code, block, theIf);
+ continue;
+ }
+
+ Value lhs = theIf.lhs();
+ Value lhsRoot = lhs.getAliasedValue();
+ Value rhs = theIf.rhs();
+ Value rhsRoot = rhs.getAliasedValue();
+
+ if (lhsRoot == rhsRoot) {
+ // Comparing the same value.
+ simplifyIfWithKnownCondition(code, block, theIf, theIf.targetFromCondition(0));
+ } else if (lhsRoot.isDefinedByInstructionSatisfying(Instruction::isCreatingInstanceOrArray)
+ && rhsRoot.isDefinedByInstructionSatisfying(Instruction::isCreatingInstanceOrArray)) {
+ // Comparing two newly created objects.
+ assert theIf.getType() == Type.EQ || theIf.getType() == Type.NE;
+ simplifyIfWithKnownCondition(code, block, theIf, theIf.targetFromCondition(1));
+ } else if (lhs.isConstNumber() && rhs.isConstNumber()) {
// Zero test with a constant of comparison between between two constants.
- if (theIf.isZeroTest()) {
- ConstNumber cond = lhs.getConstInstruction().asConstNumber();
- BasicBlock target = theIf.targetFromCondition(cond);
- simplifyIfWithKnownCondition(code, block, theIf, target);
- } else {
- ConstNumber left = lhs.getConstInstruction().asConstNumber();
- ConstNumber right = rhs.getConstInstruction().asConstNumber();
- BasicBlock target = theIf.targetFromCondition(left, right);
- simplifyIfWithKnownCondition(code, block, theIf, target);
- }
- } else if (lhs.hasValueRange() && (theIf.isZeroTest() || rhs.hasValueRange())) {
+ ConstNumber left = lhs.getConstInstruction().asConstNumber();
+ ConstNumber right = rhs.getConstInstruction().asConstNumber();
+ BasicBlock target = theIf.targetFromCondition(left, right);
+ simplifyIfWithKnownCondition(code, block, theIf, target);
+ } else if (lhs.hasValueRange() && rhs.hasValueRange()) {
// Zero test with a value range, or comparison between between two values,
// each with a value ranges.
- if (theIf.isZeroTest()) {
- LongInterval interval = lhs.getValueRange();
- if (!interval.containsValue(0)) {
- // Interval doesn't contain zero at all.
- int sign = Long.signum(interval.getMin());
- simplifyIfWithKnownCondition(code, block, theIf, sign);
- } else {
- // Interval contains zero.
- switch (theIf.getType()) {
- case GE:
- case LT:
- // [a, b] >= 0 is always true if a >= 0.
- // [a, b] < 0 is always false if a >= 0.
- // In both cases a zero condition takes the right branch.
- if (interval.getMin() == 0) {
- simplifyIfWithKnownCondition(code, block, theIf, 0);
- }
- break;
- case LE:
- case GT:
- // [a, b] <= 0 is always true if b <= 0.
- // [a, b] > 0 is always false if b <= 0.
- if (interval.getMax() == 0) {
- simplifyIfWithKnownCondition(code, block, theIf, 0);
- }
- break;
- case EQ:
- case NE:
- // Only a single element interval [0, 0] can be dealt with here.
- // Such intervals should have been replaced by constants.
- assert !interval.isSingleValue();
- break;
- }
- }
+ LongInterval leftRange = lhs.getValueRange();
+ LongInterval rightRange = rhs.getValueRange();
+ // Two overlapping ranges. Check for single point overlap.
+ if (!leftRange.overlapsWith(rightRange)) {
+ // No overlap.
+ int cond = Long.signum(leftRange.getMin() - rightRange.getMin());
+ simplifyIfWithKnownCondition(code, block, theIf, cond);
} else {
- LongInterval leftRange = lhs.getValueRange();
- LongInterval rightRange = rhs.getValueRange();
- // Two overlapping ranges. Check for single point overlap.
- if (!leftRange.overlapsWith(rightRange)) {
- // No overlap.
- int cond = Long.signum(leftRange.getMin() - rightRange.getMin());
- simplifyIfWithKnownCondition(code, block, theIf, cond);
- } else {
- // The two intervals overlap. We can simplify if they overlap at the end points.
- switch (theIf.getType()) {
- case LT:
- case GE:
- // [a, b] < [c, d] is always false when a == d.
- // [a, b] >= [c, d] is always true when a == d.
- // In both cases 0 condition will choose the right branch.
- if (leftRange.getMin() == rightRange.getMax()) {
- simplifyIfWithKnownCondition(code, block, theIf, 0);
- }
- break;
- case GT:
- case LE:
- // [a, b] > [c, d] is always false when b == c.
- // [a, b] <= [c, d] is always true when b == c.
- // In both cases 0 condition will choose the right branch.
- if (leftRange.getMax() == rightRange.getMin()) {
- simplifyIfWithKnownCondition(code, block, theIf, 0);
- }
- break;
- case EQ:
- case NE:
- // Since there is overlap EQ and NE cannot be determined.
- break;
- }
+ // The two intervals overlap. We can simplify if they overlap at the end points.
+ switch (theIf.getType()) {
+ case LT:
+ case GE:
+ // [a, b] < [c, d] is always false when a == d.
+ // [a, b] >= [c, d] is always true when a == d.
+ // In both cases 0 condition will choose the right branch.
+ if (leftRange.getMin() == rightRange.getMax()) {
+ simplifyIfWithKnownCondition(code, block, theIf, 0);
+ }
+ break;
+ case GT:
+ case LE:
+ // [a, b] > [c, d] is always false when b == c.
+ // [a, b] <= [c, d] is always true when b == c.
+ // In both cases 0 condition will choose the right branch.
+ if (leftRange.getMax() == rightRange.getMin()) {
+ simplifyIfWithKnownCondition(code, block, theIf, 0);
+ }
+ break;
+ case EQ:
+ case NE:
+ // Since there is overlap EQ and NE cannot be determined.
+ break;
}
}
} else if (theIf.getType() == Type.EQ || theIf.getType() == Type.NE) {
- if (theIf.isZeroTest()) {
- if (!lhs.isConstNumber()) {
- TypeElement l = lhs.getType();
- if (l.isReferenceType() && lhs.isNeverNull()) {
- simplifyIfWithKnownCondition(code, block, theIf, 1);
- } else {
- if (!l.isPrimitiveType() && !l.isNullable()) {
- simplifyIfWithKnownCondition(code, block, theIf, 1);
- }
- }
+ ProgramMethod context = code.context();
+ AbstractValue abstractValue = lhs.getAbstractValue(appView, context);
+ if (abstractValue.isSingleConstClassValue()) {
+ AbstractValue otherAbstractValue = rhs.getAbstractValue(appView, context);
+ if (otherAbstractValue.isSingleConstClassValue()) {
+ SingleConstClassValue singleConstClassValue = abstractValue.asSingleConstClassValue();
+ SingleConstClassValue otherSingleConstClassValue =
+ otherAbstractValue.asSingleConstClassValue();
+ simplifyIfWithKnownCondition(
+ code,
+ block,
+ theIf,
+ BooleanUtils.intValue(
+ singleConstClassValue.getType() != otherSingleConstClassValue.getType()));
}
- } else {
- ProgramMethod context = code.context();
- AbstractValue abstractValue = lhs.getAbstractValue(appView, context);
- if (abstractValue.isSingleConstClassValue()) {
- AbstractValue otherAbstractValue = rhs.getAbstractValue(appView, context);
- if (otherAbstractValue.isSingleConstClassValue()) {
- SingleConstClassValue singleConstClassValue =
- abstractValue.asSingleConstClassValue();
- SingleConstClassValue otherSingleConstClassValue =
- otherAbstractValue.asSingleConstClassValue();
- simplifyIfWithKnownCondition(
- code,
- block,
- theIf,
- BooleanUtils.intValue(
- singleConstClassValue.getType() != otherSingleConstClassValue.getType()));
- }
- } else if (abstractValue.isSingleFieldValue()) {
- AbstractValue otherAbstractValue = rhs.getAbstractValue(appView, context);
- if (otherAbstractValue.isSingleFieldValue()) {
- SingleFieldValue singleFieldValue = abstractValue.asSingleFieldValue();
- SingleFieldValue otherSingleFieldValue = otherAbstractValue.asSingleFieldValue();
- if (singleFieldValue.getField() == otherSingleFieldValue.getField()) {
- simplifyIfWithKnownCondition(code, block, theIf, 0);
- } else {
- DexClass holder = appView.definitionForHolder(singleFieldValue.getField());
- DexEncodedField field = singleFieldValue.getField().lookupOnClass(holder);
- if (field != null && field.isEnum()) {
- DexClass otherHolder =
- appView.definitionForHolder(otherSingleFieldValue.getField());
- DexEncodedField otherField =
- otherSingleFieldValue.getField().lookupOnClass(otherHolder);
- if (otherField != null && otherField.isEnum()) {
- simplifyIfWithKnownCondition(code, block, theIf, 1);
- }
+ } else if (abstractValue.isSingleFieldValue()) {
+ AbstractValue otherAbstractValue = rhs.getAbstractValue(appView, context);
+ if (otherAbstractValue.isSingleFieldValue()) {
+ SingleFieldValue singleFieldValue = abstractValue.asSingleFieldValue();
+ SingleFieldValue otherSingleFieldValue = otherAbstractValue.asSingleFieldValue();
+ if (singleFieldValue.getField() == otherSingleFieldValue.getField()) {
+ simplifyIfWithKnownCondition(code, block, theIf, 0);
+ } else {
+ DexClass holder = appView.definitionForHolder(singleFieldValue.getField());
+ DexEncodedField field = singleFieldValue.getField().lookupOnClass(holder);
+ if (field != null && field.isEnum()) {
+ DexClass otherHolder =
+ appView.definitionForHolder(otherSingleFieldValue.getField());
+ DexEncodedField otherField =
+ otherSingleFieldValue.getField().lookupOnClass(otherHolder);
+ if (otherField != null && otherField.isEnum()) {
+ simplifyIfWithKnownCondition(code, block, theIf, 1);
}
}
}
@@ -2633,6 +2593,73 @@
return !affectedValues.isEmpty();
}
+ private void simplifyIfZeroTest(IRCode code, BasicBlock block, If theIf) {
+ Value lhs = theIf.lhs();
+ Value lhsRoot = lhs.getAliasedValue();
+ if (lhsRoot.isConstNumber()) {
+ ConstNumber cond = lhs.getConstInstruction().asConstNumber();
+ BasicBlock target = theIf.targetFromCondition(cond);
+ simplifyIfWithKnownCondition(code, block, theIf, target);
+ return;
+ }
+
+ if (theIf.isNullTest()) {
+ assert theIf.getType() == Type.EQ || theIf.getType() == Type.NE;
+
+ if (lhs.isAlwaysNull(appView)) {
+ simplifyIfWithKnownCondition(code, block, theIf, theIf.targetFromNullObject());
+ return;
+ }
+
+ if (lhs.isNeverNull()) {
+ simplifyIfWithKnownCondition(code, block, theIf, theIf.targetFromNonNullObject());
+ return;
+ }
+ }
+
+ if (lhs.hasValueRange()) {
+ LongInterval interval = lhs.getValueRange();
+ if (!interval.containsValue(0)) {
+ // Interval doesn't contain zero at all.
+ int sign = Long.signum(interval.getMin());
+ simplifyIfWithKnownCondition(code, block, theIf, sign);
+ return;
+ }
+
+ // Interval contains zero.
+ switch (theIf.getType()) {
+ case GE:
+ case LT:
+ // [a, b] >= 0 is always true if a >= 0.
+ // [a, b] < 0 is always false if a >= 0.
+ // In both cases a zero condition takes the right branch.
+ if (interval.getMin() == 0) {
+ simplifyIfWithKnownCondition(code, block, theIf, 0);
+ return;
+ }
+ break;
+
+ case LE:
+ case GT:
+ // [a, b] <= 0 is always true if b <= 0.
+ // [a, b] > 0 is always false if b <= 0.
+ // In both cases a zero condition takes the right branch.
+ if (interval.getMax() == 0) {
+ simplifyIfWithKnownCondition(code, block, theIf, 0);
+ return;
+ }
+ break;
+
+ case EQ:
+ case NE:
+ // Only a single element interval [0, 0] can be dealt with here.
+ // Such intervals should have been replaced by constants.
+ assert !interval.isSingleValue();
+ break;
+ }
+ }
+ }
+
private void simplifyIfWithKnownCondition(
IRCode code, BasicBlock block, If theIf, BasicBlock target) {
BasicBlock deadTarget =
diff --git a/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java b/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java
index 0b9c896..7157943 100644
--- a/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java
+++ b/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java
@@ -107,15 +107,14 @@
.withOptionConsumer(opts -> opts.enableClassInlining = false)
.withBuilderTransformation(
b -> b.addProguardConfiguration(PROGUARD_OPTIONS, Origin.unknown()))
- .withDexCheck(inspector -> checkLambdaCount(inspector, 102, "lambdadesugaring"))
+ .withDexCheck(inspector -> checkLambdaCount(inspector, 101, "lambdadesugaring"))
.run();
test("lambdadesugaring", "lambdadesugaring", "LambdaDesugaring")
.withMinApiLevel(ToolHelper.getMinApiLevelForDexVmNoHigherThan(AndroidApiLevel.K))
- .withOptionConsumer(opts -> opts.enableClassInlining = true)
.withBuilderTransformation(
b -> b.addProguardConfiguration(PROGUARD_OPTIONS, Origin.unknown()))
- .withDexCheck(inspector -> checkLambdaCount(inspector, 7, "lambdadesugaring"))
+ .withDexCheck(inspector -> checkLambdaCount(inspector, 6, "lambdadesugaring"))
.run();
}
@@ -147,15 +146,14 @@
.withOptionConsumer(opts -> opts.enableClassInlining = false)
.withBuilderTransformation(
b -> b.addProguardConfiguration(PROGUARD_OPTIONS, Origin.unknown()))
- .withDexCheck(inspector -> checkLambdaCount(inspector, 102, "lambdadesugaring"))
+ .withDexCheck(inspector -> checkLambdaCount(inspector, 101, "lambdadesugaring"))
.run();
test("lambdadesugaring", "lambdadesugaring", "LambdaDesugaring")
.withMinApiLevel(AndroidApiLevel.N)
- .withOptionConsumer(opts -> opts.enableClassInlining = true)
.withBuilderTransformation(
b -> b.addProguardConfiguration(PROGUARD_OPTIONS, Origin.unknown()))
- .withDexCheck(inspector -> checkLambdaCount(inspector, 7, "lambdadesugaring"))
+ .withDexCheck(inspector -> checkLambdaCount(inspector, 6, "lambdadesugaring"))
.run();
}
@@ -175,7 +173,6 @@
test("lambdadesugaringnplus", "lambdadesugaringnplus", "LambdasWithStaticAndDefaultMethods")
.withMinApiLevel(ToolHelper.getMinApiLevelForDexVmNoHigherThan(AndroidApiLevel.K))
.withInterfaceMethodDesugaring(OffOrAuto.Auto)
- .withOptionConsumer(opts -> opts.enableClassInlining = true)
.withBuilderTransformation(ToolHelper::allowTestProguardOptions)
.withBuilderTransformation(
b -> b.addProguardConfiguration(PROGUARD_OPTIONS_N_PLUS, Origin.unknown()))
@@ -199,7 +196,6 @@
test("lambdadesugaringnplus", "lambdadesugaringnplus", "LambdasWithStaticAndDefaultMethods")
.withMinApiLevel(AndroidApiLevel.N)
.withInterfaceMethodDesugaring(OffOrAuto.Auto)
- .withOptionConsumer(opts -> opts.enableClassInlining = true)
.withBuilderTransformation(ToolHelper::allowTestProguardOptions)
.withBuilderTransformation(
b -> b.addProguardConfiguration(PROGUARD_OPTIONS_N_PLUS, Origin.unknown()))
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/ifs/TrivialObjectEqualsTest.java b/src/test/java/com/android/tools/r8/ir/optimize/ifs/TrivialObjectEqualsTest.java
new file mode 100644
index 0000000..5e5f50c
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/ifs/TrivialObjectEqualsTest.java
@@ -0,0 +1,88 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.ifs;
+
+import static com.android.tools.r8.utils.codeinspector.CodeMatchers.invokesMethodWithName;
+import static com.android.tools.r8.utils.codeinspector.Matchers.isAbsent;
+import static org.hamcrest.CoreMatchers.not;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assume.assumeTrue;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class TrivialObjectEqualsTest extends TestBase {
+
+ private final TestParameters parameters;
+
+ @Parameterized.Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withAllRuntimesAndApiLevels().build();
+ }
+
+ public TrivialObjectEqualsTest(TestParameters parameters) {
+ this.parameters = parameters;
+ }
+
+ @Test
+ public void testD8() throws Exception {
+ assumeTrue(parameters.isDexRuntime());
+ testForD8(parameters.getBackend())
+ .addInnerClasses(getClass())
+ .setMinApi(parameters.getApiLevel())
+ .compile()
+ .inspect(
+ inspector ->
+ assertThat(
+ inspector.clazz(Main.class).uniqueMethodWithName("dead"),
+ not(invokesMethodWithName("dead"))))
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutputLines("Hello world!");
+ }
+
+ @Test
+ public void testR8() throws Exception {
+ testForR8(parameters.getBackend())
+ .addInnerClasses(getClass())
+ .addKeepMainRule(Main.class)
+ .enableInliningAnnotations()
+ .setMinApi(parameters.getApiLevel())
+ .compile()
+ .inspect(
+ inspector ->
+ assertThat(inspector.clazz(Main.class).uniqueMethodWithName("dead"), isAbsent()))
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutputLines("Hello world!");
+ }
+
+ static class Main {
+
+ public static void main(String[] args) {
+ Object o1 = new Object();
+ Object o2 = new Object();
+ if (o1 == o1) {
+ System.out.print("Hello");
+ } else {
+ dead();
+ }
+ if (o1 == o2) {
+ dead();
+ } else {
+ System.out.println(" world!");
+ }
+ }
+
+ @NeverInline
+ static void dead() {
+ System.out.println("Unexpected!");
+ }
+ }
+}