Disable shift optimizations on shifts with negative right operands
It turns out most shift operations with negative right operands
tend to overflow into 0 immediately, and most of them are undefined.
I think it's safe not to optimize these cases
Bug: b/342067836
Change-Id: I7d28af3d22963727c1fc6062ec7800fb87134697
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/BinopRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/BinopRewriter.java
index 7a480d6..1aa30bc 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/BinopRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/BinopRewriter.java
@@ -136,8 +136,8 @@
} else if (binopDescriptor.isShift()) {
// x shift: a shift: b => x shift: (a + b) where a + b is a constant.
if (constBRight != null && constARight != null) {
- rewriteIntoConstThenBinop(
- iterator, ADD, binopDescriptor, constB, constA, input, false, code);
+ rewriteSuccessiveShift(
+ iterator, binop, binopDescriptor, constBRight, constARight, input, code);
return true;
}
} else if (binop.isSub() && constBRight != null) {
@@ -174,6 +174,30 @@
return false;
}
+ private void rewriteSuccessiveShift(
+ InstructionListIterator iterator,
+ Binop binop,
+ BinopDescriptor binopDescriptor,
+ ConstNumber constBRight,
+ ConstNumber constARight,
+ Value input,
+ IRCode code) {
+ int mask = input.outType().isWide() ? 63 : 31;
+ int intA = constARight.getIntValue() & mask;
+ int intB = constBRight.getIntValue() & mask;
+ if (intA + intB > mask) {
+ ConstNumber zero = code.createNumberConstant(0, binop.outValue().getType());
+ iterator.replaceCurrentInstruction(zero);
+ } else {
+ iterator.previous();
+ Value newConstantValue =
+ iterator.insertConstNumberInstruction(
+ code, appView.options(), intA + intB, TypeElement.getInt());
+ iterator.next();
+ replaceBinop(iterator, code, input, newConstantValue, binopDescriptor);
+ }
+ }
+
private boolean successiveLogicalSimplificationNoConstant(
InstructionListIterator iterator, Binop binop, BinopDescriptor binopDescriptor, IRCode code) {
if (!(binop.isAnd() || binop.isOr())) {
@@ -402,7 +426,7 @@
if (binop.leftValue() == binop.rightValue()) {
if (binop.isXor() || binop.isSub()) {
// a ^ a => 0, a - a => 0
- ConstNumber zero = new ConstNumber(code.createValue(binop.outValue().getType()), 0);
+ ConstNumber zero = code.createNumberConstant(0, binop.outValue().getType());
iterator.replaceCurrentInstruction(zero);
} else if (binop.isAnd() || binop.isOr()) {
// a & a => a, a | a => a.
diff --git a/src/test/java/com/android/tools/r8/ir/ShiftOppositeSignTest.java b/src/test/java/com/android/tools/r8/ir/ShiftOppositeSignTest.java
new file mode 100644
index 0000000..62cd987
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/ShiftOppositeSignTest.java
@@ -0,0 +1,138 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.StringUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class ShiftOppositeSignTest extends TestBase {
+
+ private static final String EXPECTED_RESULT =
+ StringUtils.lines(
+ "=== 1 ===",
+ "shl: 0 134217728 0 512",
+ "shr: 0 0 0 0",
+ "ushr: 0 0 0 0",
+ "shl: 0 64 268435456 0",
+ "shr: 0 0 0 0",
+ "ushr: 0 0 0 0",
+ "=== 123456 ===",
+ "shl: 0 0 0 63209472",
+ "shr: 0 0 0 241",
+ "ushr: 0 0 0 241",
+ "shl: 0 7901184 0 0",
+ "shr: 0 1929 0 0",
+ "ushr: 0 1929 0 0");
+
+ @Parameterized.Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withCfRuntimes().withDexRuntimes().withAllApiLevels().build();
+ }
+
+ private final TestParameters parameters;
+
+ public ShiftOppositeSignTest(TestParameters parameters) {
+ this.parameters = parameters;
+ }
+
+ @Test
+ public void testD8() throws Exception {
+ testForRuntime(parameters)
+ .addProgramClasses(Main.class)
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutput(EXPECTED_RESULT);
+ }
+
+ @Test
+ public void testR8() throws Exception {
+ testForR8(parameters.getBackend())
+ .addProgramClasses(Main.class)
+ .addKeepMainRule(Main.class)
+ .enableInliningAnnotations()
+ .setMinApi(parameters)
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutput(EXPECTED_RESULT);
+ }
+
+ public static class Main {
+
+ public static void main(String[] args) {
+ test(1);
+ test(123456);
+ }
+
+ @NeverInline
+ private static void test(int i) {
+ System.out.println("=== " + i + " ===");
+
+ System.out.print("shl: ");
+ System.out.print(i << -2 << 7);
+ System.out.print(" ");
+ System.out.print(i << 2 << -7);
+ System.out.print(" ");
+ System.out.print(i << -2 << -7);
+ System.out.print(" ");
+ System.out.print(i << 2 << 7);
+ System.out.println();
+
+ System.out.print("shr: ");
+ System.out.print(i >> -2 >> 7);
+ System.out.print(" ");
+ System.out.print(i >> 2 >> -7);
+ System.out.print(" ");
+ System.out.print(i >> -2 >> -7);
+ System.out.print(" ");
+ System.out.print(i >> 2 >> 7);
+ System.out.println();
+
+ System.out.print("ushr: ");
+ System.out.print(i >>> -2 >>> 7);
+ System.out.print(" ");
+ System.out.print(i >>> 2 >>> -7);
+ System.out.print(" ");
+ System.out.print(i >>> -2 >>> -7);
+ System.out.print(" ");
+ System.out.print(i >>> 2 >>> 7);
+ System.out.println();
+
+ System.out.print("shl: ");
+ System.out.print(i << -5 << 31);
+ System.out.print(" ");
+ System.out.print(i << 5 << -31);
+ System.out.print(" ");
+ System.out.print(i << -5 << -31);
+ System.out.print(" ");
+ System.out.print(i << 5 << 31);
+ System.out.println();
+
+ System.out.print("shr: ");
+ System.out.print(i >> -5 >> 31);
+ System.out.print(" ");
+ System.out.print(i >> 5 >> -31);
+ System.out.print(" ");
+ System.out.print(i >> -5 >> -31);
+ System.out.print(" ");
+ System.out.print(i >> 5 >> 31);
+ System.out.println();
+
+ System.out.print("ushr: ");
+ System.out.print(i >>> -5 >>> 31);
+ System.out.print(" ");
+ System.out.print(i >>> 5 >>> -31);
+ System.out.print(" ");
+ System.out.print(i >>> -5 >>> -31);
+ System.out.print(" ");
+ System.out.print(i >>> 5 >>> 31);
+ System.out.println();
+ }
+ }
+}
diff --git a/src/test/java/com/android/tools/r8/regress/B342067836Test.java b/src/test/java/com/android/tools/r8/regress/B342067836Test.java
index 703abc2..06a2b7f 100644
--- a/src/test/java/com/android/tools/r8/regress/B342067836Test.java
+++ b/src/test/java/com/android/tools/r8/regress/B342067836Test.java
@@ -28,7 +28,6 @@
}
private static final List<String> EXPECTED_OUTPUT = ImmutableList.of("50");
- private static final List<String> NOT_EXPECTED_OUTPUT = ImmutableList.of("864691135031803904");
@Test
public void testJvm() throws Exception {
@@ -56,7 +55,7 @@
.addKeepMainRule(TestClass.class)
.setMinApi(parameters)
.run(parameters.getRuntime(), TestClass.class)
- .assertSuccessWithOutputLines(NOT_EXPECTED_OUTPUT);
+ .assertSuccessWithOutputLines(EXPECTED_OUTPUT);
}
static class TestClass {