Extend Binop rewriter
- associative property for add, mul, and, or, xor
- shift combination
- sub combination
- mixed add/sub combination
Change-Id: I68b1fa1d7cab4f9787b368055acd68703354183b
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/BinopRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/BinopRewriter.java
index 93ad3ff..f642a95 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/BinopRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/BinopRewriter.java
@@ -4,9 +4,11 @@
package com.android.tools.r8.ir.conversion.passes;
+import com.android.tools.r8.errors.Unreachable;
import com.android.tools.r8.graph.AppInfo;
import com.android.tools.r8.graph.AppView;
import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.analysis.type.TypeElement;
import com.android.tools.r8.ir.code.Add;
import com.android.tools.r8.ir.code.And;
import com.android.tools.r8.ir.code.Binop;
@@ -42,18 +44,17 @@
private Map<Class<?>, BinopDescriptor> createBinopDescriptors() {
ImmutableMap.Builder<Class<?>, BinopDescriptor> builder = ImmutableMap.builder();
- builder.put(Add.class, new BinopDescriptor(0, 0, null, null));
- builder.put(Sub.class, new BinopDescriptor(null, 0, null, null));
- builder.put(Mul.class, new BinopDescriptor(1, 1, 0, 0));
- // The following two can be improved if we handle ZeroDivide.
- builder.put(Div.class, new BinopDescriptor(null, 1, null, null));
- builder.put(Rem.class, new BinopDescriptor(null, null, null, null));
- builder.put(And.class, new BinopDescriptor(ALL_BITS_SET, ALL_BITS_SET, 0, 0));
- builder.put(Or.class, new BinopDescriptor(0, 0, ALL_BITS_SET, ALL_BITS_SET));
- builder.put(Xor.class, new BinopDescriptor(0, 0, null, null));
- builder.put(Shl.class, new BinopDescriptor(null, 0, 0, null));
- builder.put(Shr.class, new BinopDescriptor(null, 0, 0, null));
- builder.put(Ushr.class, new BinopDescriptor(null, 0, 0, null));
+ builder.put(Add.class, BinopDescriptor.ADD);
+ builder.put(Sub.class, BinopDescriptor.SUB);
+ builder.put(Mul.class, BinopDescriptor.MUL);
+ builder.put(Div.class, BinopDescriptor.DIV);
+ builder.put(Rem.class, BinopDescriptor.REM);
+ builder.put(And.class, BinopDescriptor.AND);
+ builder.put(Or.class, BinopDescriptor.OR);
+ builder.put(Xor.class, BinopDescriptor.XOR);
+ builder.put(Shl.class, BinopDescriptor.SHL);
+ builder.put(Shr.class, BinopDescriptor.SHR);
+ builder.put(Ushr.class, BinopDescriptor.USHR);
return builder.build();
}
@@ -64,24 +65,176 @@
* - i is right identity if for each x in K, x * i = x.
* - a is left absorbing if for each x in K, a * x = a.
* - a is right absorbing if for each x in K, x * a = a.
+ * In a space K, a binop * is associative if for each x,y,z in K, (x * y) * z = x * (y * z).
* </code>
*/
- private static class BinopDescriptor {
+ private enum BinopDescriptor {
+ ADD(0, 0, null, null, true) {
+ @Override
+ Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+ return Add.create(numericType, dest, left, right);
+ }
+
+ @Override
+ int evaluate(int left, int right) {
+ return left + right;
+ }
+
+ @Override
+ long evaluate(long left, long right) {
+ return left + right;
+ }
+ },
+ SUB(null, 0, null, null, false) {
+ @Override
+ Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+ return new Sub(numericType, dest, left, right);
+ }
+
+ @Override
+ int evaluate(int left, int right) {
+ return left - right;
+ }
+
+ @Override
+ long evaluate(long left, long right) {
+ return left - right;
+ }
+ },
+ MUL(1, 1, 0, 0, true) {
+ @Override
+ Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+ return Mul.create(numericType, dest, left, right);
+ }
+
+ @Override
+ int evaluate(int left, int right) {
+ return left * right;
+ }
+
+ @Override
+ long evaluate(long left, long right) {
+ return left * right;
+ }
+ },
+ // The following two can be improved if we handle ZeroDivide.
+ DIV(null, 1, null, null, false),
+ REM(null, null, null, null, false),
+ AND(ALL_BITS_SET, ALL_BITS_SET, 0, 0, true) {
+ @Override
+ Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+ return And.create(numericType, dest, left, right);
+ }
+
+ @Override
+ int evaluate(int left, int right) {
+ return left & right;
+ }
+
+ @Override
+ long evaluate(long left, long right) {
+ return left & right;
+ }
+ },
+ OR(0, 0, ALL_BITS_SET, ALL_BITS_SET, true) {
+ @Override
+ Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+ return Or.create(numericType, dest, left, right);
+ }
+
+ @Override
+ int evaluate(int left, int right) {
+ return left | right;
+ }
+
+ @Override
+ long evaluate(long left, long right) {
+ return left | right;
+ }
+ },
+ XOR(0, 0, null, null, true) {
+ @Override
+ Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+ return Xor.create(numericType, dest, left, right);
+ }
+
+ @Override
+ int evaluate(int left, int right) {
+ return left ^ right;
+ }
+
+ @Override
+ long evaluate(long left, long right) {
+ return left ^ right;
+ }
+ },
+ SHL(null, 0, 0, null, false) {
+ @Override
+ Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+ return new Shl(numericType, dest, left, right);
+ }
+
+ @Override
+ boolean isShift() {
+ return true;
+ }
+ },
+ SHR(null, 0, 0, null, false) {
+ @Override
+ Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+ return new Shr(numericType, dest, left, right);
+ }
+
+ @Override
+ boolean isShift() {
+ return true;
+ }
+ },
+ USHR(null, 0, 0, null, false) {
+ @Override
+ Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+ return new Ushr(numericType, dest, left, right);
+ }
+
+ @Override
+ boolean isShift() {
+ return true;
+ }
+ };
final Integer leftIdentity;
final Integer rightIdentity;
final Integer leftAbsorbing;
final Integer rightAbsorbing;
+ final boolean associativeAndCommutative;
- private BinopDescriptor(
+ BinopDescriptor(
Integer leftIdentity,
Integer rightIdentity,
Integer leftAbsorbing,
- Integer rightAbsorbing) {
+ Integer rightAbsorbing,
+ boolean associativeAndCommutative) {
this.leftIdentity = leftIdentity;
this.rightIdentity = rightIdentity;
this.leftAbsorbing = leftAbsorbing;
this.rightAbsorbing = rightAbsorbing;
+ this.associativeAndCommutative = associativeAndCommutative;
+ }
+
+ Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+ throw new Unreachable();
+ }
+
+ int evaluate(int left, int right) {
+ throw new Unreachable();
+ }
+
+ long evaluate(long left, long right) {
+ throw new Unreachable();
+ }
+
+ boolean isShift() {
+ return false;
}
}
@@ -106,30 +259,10 @@
|| binop.getNumericType() == NumericType.LONG) {
BinopDescriptor binopDescriptor = descriptors.get(binop.getClass());
assert binopDescriptor != null;
- ConstNumber constNumber = getConstNumber(binop.leftValue());
- if (constNumber != null) {
- if (simplify(
- binop,
- iterator,
- constNumber,
- binopDescriptor.leftIdentity,
- binop.rightValue(),
- binopDescriptor.leftAbsorbing,
- binop.leftValue())) {
- continue;
- }
+ if (identityAbsorbingSimplification(iterator, binop, binopDescriptor)) {
+ continue;
}
- constNumber = getConstNumber(binop.rightValue());
- if (constNumber != null) {
- simplify(
- binop,
- iterator,
- constNumber,
- binopDescriptor.rightIdentity,
- binop.leftValue(),
- binopDescriptor.rightAbsorbing,
- binop.rightValue());
- }
+ successiveSimplification(iterator, binop, binopDescriptor, code);
}
}
}
@@ -137,6 +270,138 @@
assert code.isConsistentSSA(appView);
}
+ private void successiveSimplification(
+ InstructionListIterator iterator, Binop binop, BinopDescriptor binopDescriptor, IRCode code) {
+ if (binop.outValue().hasDebugUsers()) {
+ return;
+ }
+ ConstNumber constLeft = getConstNumber(binop.leftValue());
+ ConstNumber constRight = getConstNumber(binop.rightValue());
+ if ((constLeft != null && constRight != null) || (constLeft == null && constRight == null)) {
+ return;
+ }
+ Value otherValue = constLeft == null ? binop.leftValue() : binop.rightValue();
+ if (otherValue.isPhi() || !otherValue.getDefinition().isBinop()) {
+ return;
+ }
+ Binop prevBinop = otherValue.getDefinition().asBinop();
+ ConstNumber prevConstLeft = getConstNumber(prevBinop.leftValue());
+ ConstNumber prevConstRight = getConstNumber(prevBinop.rightValue());
+ if ((prevConstLeft != null && prevConstRight != null)
+ || (prevConstLeft == null && prevConstRight == null)) {
+ return;
+ }
+ ConstNumber constB = constLeft == null ? constRight : constLeft;
+ ConstNumber constA = prevConstLeft == null ? prevConstRight : prevConstLeft;
+ Value input = prevConstLeft == null ? prevBinop.leftValue() : prevBinop.rightValue();
+ // We have two successive binops so that a,b constants, x the input and a * x * b.
+ if (prevBinop.getClass() == binop.getClass()) {
+ if (binopDescriptor.associativeAndCommutative) {
+ // a * x * b => x * (a * b) where (a * b) is a constant.
+ assert binop.isCommutative();
+ Value newConst = addNewConstNumber(code, iterator, constB, constA, binopDescriptor);
+ iterator.replaceCurrentInstruction(
+ instantiateBinop(code, input, newConst, binopDescriptor));
+ } else if (binopDescriptor.isShift()) {
+ // x shift: a shift: b => x shift: (a + b) where a + b is a constant.
+ if (constRight != null && prevConstRight != null) {
+ Value newConst = addNewConstNumber(code, iterator, constB, constA, BinopDescriptor.ADD);
+ iterator.replaceCurrentInstruction(
+ instantiateBinop(code, input, newConst, binopDescriptor));
+ }
+ } else if (binop.isSub()) {
+ // a - x - b => (a - b) - x where (a - b) is a constant.
+ // x - a - b => x - (a + b) where (a + b) is a constant.
+ if (prevConstRight == null) {
+ Value newConst = addNewConstNumber(code, iterator, constA, constB, BinopDescriptor.SUB);
+ iterator.replaceCurrentInstruction(
+ instantiateBinop(code, newConst, input, BinopDescriptor.SUB));
+ } else {
+ Value newConst = addNewConstNumber(code, iterator, constB, constA, BinopDescriptor.ADD);
+ iterator.replaceCurrentInstruction(
+ instantiateBinop(code, input, newConst, BinopDescriptor.SUB));
+ }
+ }
+ } else {
+ if (binop.isSub() && prevBinop.isAdd()) {
+ // x + a - b => x + (a - b) where (a - b) is a constant.
+ // a + x - b => x + (a - b) where (a - b) is a constant.
+ Value newConst = addNewConstNumber(code, iterator, constA, constB, BinopDescriptor.SUB);
+ iterator.replaceCurrentInstruction(
+ instantiateBinop(code, newConst, input, BinopDescriptor.ADD));
+ } else if (binop.isAdd() && prevBinop.isSub()) {
+ // x - a + b => x - (a - b) where (a - b) is a constant.
+ // a - x + b => (a + b) - x where (a + b) is a constant.
+ if (prevConstLeft == null) {
+ Value newConst = addNewConstNumber(code, iterator, constA, constB, BinopDescriptor.SUB);
+ iterator.replaceCurrentInstruction(
+ instantiateBinop(code, input, newConst, BinopDescriptor.SUB));
+ } else {
+ Value newConst = addNewConstNumber(code, iterator, constB, constA, BinopDescriptor.ADD);
+ iterator.replaceCurrentInstruction(
+ instantiateBinop(code, newConst, input, BinopDescriptor.SUB));
+ }
+ }
+ }
+ }
+
+ private Instruction instantiateBinop(
+ IRCode code, Value left, Value right, BinopDescriptor descriptor) {
+ TypeElement representative = left.getType().isInt() ? right.getType() : left.getType();
+ Value newValue = code.createValue(representative);
+ NumericType numericType = representative.isInt() ? NumericType.INT : NumericType.LONG;
+ return descriptor.instantiate(numericType, newValue, left, right);
+ }
+
+ private Value addNewConstNumber(
+ IRCode code,
+ InstructionListIterator iterator,
+ ConstNumber left,
+ ConstNumber right,
+ BinopDescriptor descriptor) {
+ TypeElement representative =
+ left.outValue().getType().isInt() ? right.outValue().getType() : left.outValue().getType();
+ long result =
+ representative.isInt()
+ ? descriptor.evaluate(left.getIntValue(), right.getIntValue())
+ : descriptor.evaluate(left.getLongValue(), right.getLongValue());
+ iterator.previous();
+ Value value =
+ iterator.insertConstNumberInstruction(
+ code, appView.options(), result, left.outValue().getType());
+ iterator.next();
+ return value;
+ }
+
+ private boolean identityAbsorbingSimplification(
+ InstructionListIterator iterator, Binop binop, BinopDescriptor binopDescriptor) {
+ ConstNumber constNumber = getConstNumber(binop.leftValue());
+ if (constNumber != null) {
+ if (simplify(
+ binop,
+ iterator,
+ constNumber,
+ binopDescriptor.leftIdentity,
+ binop.rightValue(),
+ binopDescriptor.leftAbsorbing,
+ binop.leftValue())) {
+ return true;
+ }
+ }
+ constNumber = getConstNumber(binop.rightValue());
+ if (constNumber != null) {
+ return simplify(
+ binop,
+ iterator,
+ constNumber,
+ binopDescriptor.rightIdentity,
+ binop.leftValue(),
+ binopDescriptor.rightAbsorbing,
+ binop.rightValue());
+ }
+ return false;
+ }
+
private ConstNumber getConstNumber(Value val) {
ConstNumber constNumber = getConstNumberIfConstant(val);
if (constNumber != null) {
diff --git a/src/test/java/com/android/tools/r8/ir/AssociativeIntTest.java b/src/test/java/com/android/tools/r8/ir/AssociativeIntTest.java
new file mode 100644
index 0000000..8a6f9c0
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/AssociativeIntTest.java
@@ -0,0 +1,424 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir;
+
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.StringUtils;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.FoundMethodSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class AssociativeIntTest extends TestBase {
+
+ private static final String EXPECTED_RESULT =
+ StringUtils.lines(
+ "Associative",
+ "7",
+ "47",
+ "-2147483644",
+ "-2147483643",
+ "12",
+ "252",
+ "-6",
+ "0",
+ "2",
+ "2",
+ "2",
+ "0",
+ "3",
+ "43",
+ "2147483647",
+ "-2147483645",
+ "3",
+ "43",
+ "2147483646",
+ "-2147483647",
+ "Shift",
+ "64",
+ "1344",
+ "-32",
+ "0",
+ "0",
+ "1",
+ "67108863",
+ "-67108864",
+ "0",
+ "1",
+ "67108863",
+ "67108864",
+ "Sub",
+ "-1",
+ "-41",
+ "-2147483646",
+ "-2147483647",
+ "-3",
+ "37",
+ "2147483642",
+ "2147483643",
+ "Mixed",
+ "3",
+ "43",
+ "-2147483648",
+ "-2147483647",
+ "3",
+ "-37",
+ "-2147483642",
+ "-2147483643",
+ "3",
+ "43",
+ "-2147483648",
+ "-2147483647",
+ "1",
+ "41",
+ "2147483646",
+ "2147483647",
+ "Double Associative",
+ "12",
+ "52",
+ "84",
+ "1764",
+ "2",
+ "2",
+ "7",
+ "47",
+ "4",
+ "44",
+ "Double Shift",
+ "128",
+ "2688",
+ "0",
+ "0",
+ "0",
+ "0",
+ "Double Sub",
+ "-1",
+ "-41",
+ "-10",
+ "30",
+ "Double Mixed",
+ "-4",
+ "36",
+ "7",
+ "-33",
+ "-4",
+ "36",
+ "5",
+ "45");
+ private final TestParameters parameters;
+
+ @Parameterized.Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withCfRuntimes().withDexRuntimes().withAllApiLevels().build();
+ }
+
+ public AssociativeIntTest(TestParameters parameters) {
+ this.parameters = parameters;
+ }
+
+ @Test
+ public void testD8() throws Exception {
+ testForRuntime(parameters)
+ .addProgramClasses(Main.class)
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutput(EXPECTED_RESULT);
+ }
+
+ @Test
+ public void testR8() throws Exception {
+ testForR8(parameters.getBackend())
+ .addProgramClasses(Main.class)
+ .addKeepMainRule(Main.class)
+ .enableInliningAnnotations()
+ .setMinApi(parameters)
+ .compile()
+ .inspect(this::inspect)
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutput(EXPECTED_RESULT);
+ }
+
+ private void inspect(CodeInspector inspector) {
+ ClassSubject clazz = inspector.clazz(Main.class);
+ for (FoundMethodSubject method :
+ clazz.allMethods(m -> m.getParameters().size() > 0 && m.getParameter(0).is("int"))) {
+ assertEquals(
+ 1,
+ method
+ .streamInstructions()
+ .filter(i -> i.isIntArithmeticBinop() || i.isIntLogicalBinop())
+ .count());
+ }
+ }
+
+ public static class Main {
+
+ public static void main(String[] args) {
+ simple();
+ doubleOps();
+ }
+
+ @NeverInline
+ private static void simple() {
+ // Associative + * & | ^.
+ System.out.println("Associative");
+ add(2);
+ add(42);
+ add(Integer.MAX_VALUE);
+ add(Integer.MIN_VALUE);
+ mul(2);
+ mul(42);
+ mul(Integer.MAX_VALUE);
+ mul(Integer.MIN_VALUE);
+ and(2);
+ and(42);
+ and(Integer.MAX_VALUE);
+ and(Integer.MIN_VALUE);
+ or(2);
+ or(42);
+ or(Integer.MAX_VALUE);
+ or(Integer.MIN_VALUE);
+ xor(2);
+ xor(42);
+ xor(Integer.MAX_VALUE);
+ xor(Integer.MIN_VALUE);
+
+ // Shift composition.
+ System.out.println("Shift");
+ shl(2);
+ shl(42);
+ shl(Integer.MAX_VALUE);
+ shl(Integer.MIN_VALUE);
+ shr(2);
+ shr(42);
+ shr(Integer.MAX_VALUE);
+ shr(Integer.MIN_VALUE);
+ ushr(2);
+ ushr(42);
+ ushr(Integer.MAX_VALUE);
+ ushr(Integer.MIN_VALUE);
+
+ // Special for -.
+ System.out.println("Sub");
+ sub(2);
+ sub(42);
+ sub(Integer.MAX_VALUE);
+ sub(Integer.MIN_VALUE);
+ sub2(2);
+ sub2(42);
+ sub2(Integer.MAX_VALUE);
+ sub2(Integer.MIN_VALUE);
+
+ // Mixed for + and -.
+ System.out.println("Mixed");
+ addSub(2);
+ addSub(42);
+ addSub(Integer.MAX_VALUE);
+ addSub(Integer.MIN_VALUE);
+ subAdd(2);
+ subAdd(42);
+ subAdd(Integer.MAX_VALUE);
+ subAdd(Integer.MIN_VALUE);
+ addSub2(2);
+ addSub2(42);
+ addSub2(Integer.MAX_VALUE);
+ addSub2(Integer.MIN_VALUE);
+ subAdd2(2);
+ subAdd2(42);
+ subAdd2(Integer.MAX_VALUE);
+ subAdd2(Integer.MIN_VALUE);
+ }
+
+ @NeverInline
+ private static void doubleOps() {
+ // Associative + * & | ^.
+ System.out.println("Double Associative");
+ addDouble(2);
+ addDouble(42);
+ mulDouble(2);
+ mulDouble(42);
+ andDouble(2);
+ andDouble(42);
+ orDouble(2);
+ orDouble(42);
+ xorDouble(2);
+ xorDouble(42);
+
+ // Shift composition.
+ System.out.println("Double Shift");
+ shlDouble(2);
+ shlDouble(42);
+ shrDouble(2);
+ shrDouble(42);
+ ushrDouble(2);
+ ushrDouble(42);
+
+ // Special for -.
+ System.out.println("Double Sub");
+ subDouble(2);
+ subDouble(42);
+ sub2Double(2);
+ sub2Double(42);
+
+ // Mixed for + and -.
+ System.out.println("Double Mixed");
+ addSubDouble(2);
+ addSubDouble(42);
+ subAddDouble(2);
+ subAddDouble(42);
+ addSub2Double(2);
+ addSub2Double(42);
+ subAdd2Double(2);
+ subAdd2Double(42);
+ }
+
+ @NeverInline
+ public static void add(int x) {
+ System.out.println(3 + x + 2);
+ }
+
+ @NeverInline
+ public static void mul(int x) {
+ System.out.println(3 * x * 2);
+ }
+
+ @NeverInline
+ public static void and(int x) {
+ System.out.println(3 & x & 2);
+ }
+
+ @NeverInline
+ public static void or(int x) {
+ System.out.println(3 | x | 2);
+ }
+
+ @NeverInline
+ public static void xor(int x) {
+ System.out.println(3 ^ x ^ 2);
+ }
+
+ @NeverInline
+ public static void shl(int x) {
+ System.out.println(x << 2 << 3);
+ }
+
+ @NeverInline
+ public static void shr(int x) {
+ System.out.println(x >> 2 >> 3);
+ }
+
+ @NeverInline
+ public static void ushr(int x) {
+ System.out.println(x >>> 2 >>> 3);
+ }
+
+ @NeverInline
+ public static void sub(int x) {
+ System.out.println(3 - x - 2);
+ }
+
+ @NeverInline
+ public static void sub2(int x) {
+ System.out.println(x - 3 - 2);
+ }
+
+ @NeverInline
+ public static void addSub(int x) {
+ System.out.println(3 + x - 2);
+ }
+
+ @NeverInline
+ public static void addSub2(int x) {
+ System.out.println(x + 3 - 2);
+ }
+
+ @NeverInline
+ public static void subAdd(int x) {
+ System.out.println(3 - x + 2);
+ }
+
+ @NeverInline
+ public static void subAdd2(int x) {
+ System.out.println(x - 3 + 2);
+ }
+
+ @NeverInline
+ public static void addDouble(int x) {
+ System.out.println(3 + x + 2 + 5);
+ }
+
+ @NeverInline
+ public static void mulDouble(int x) {
+ System.out.println(3 * x * 2 * 7);
+ }
+
+ @NeverInline
+ public static void andDouble(int x) {
+ System.out.println(3 & x & 2 & 7);
+ }
+
+ @NeverInline
+ public static void orDouble(int x) {
+ System.out.println(3 | x | 2 | 7);
+ }
+
+ @NeverInline
+ public static void xorDouble(int x) {
+ System.out.println(3 ^ x ^ 2 ^ 7);
+ }
+
+ @NeverInline
+ public static void shlDouble(int x) {
+ System.out.println(x << 2 << 3 << 1);
+ }
+
+ @NeverInline
+ public static void shrDouble(int x) {
+ System.out.println(x >> 2 >> 3 >> 1);
+ }
+
+ @NeverInline
+ public static void ushrDouble(int x) {
+ System.out.println(x >>> 2 >>> 3 >>> 1);
+ }
+
+ @NeverInline
+ public static void subDouble(int x) {
+ System.out.println(3 - x - 2);
+ }
+
+ @NeverInline
+ public static void sub2Double(int x) {
+ System.out.println(x - 3 - 2 - 7);
+ }
+
+ @NeverInline
+ public static void addSubDouble(int x) {
+ System.out.println(3 + x - 2 - 7);
+ }
+
+ @NeverInline
+ public static void addSub2Double(int x) {
+ System.out.println(x + 3 - 2 - 7);
+ }
+
+ @NeverInline
+ public static void subAddDouble(int x) {
+ System.out.println(3 - x + 2 + 4);
+ }
+
+ @NeverInline
+ public static void subAdd2Double(int x) {
+ System.out.println(x - 3 + 2 + 4);
+ }
+ }
+}
diff --git a/src/test/java/com/android/tools/r8/ir/AssociativeLongTest.java b/src/test/java/com/android/tools/r8/ir/AssociativeLongTest.java
new file mode 100644
index 0000000..22ea7e8
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/AssociativeLongTest.java
@@ -0,0 +1,424 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir;
+
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.StringUtils;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.FoundMethodSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class AssociativeLongTest extends TestBase {
+
+ private static final String EXPECTED_RESULT =
+ StringUtils.lines(
+ "Associative",
+ "7",
+ "47",
+ "-9223372036854775804",
+ "-9223372036854775803",
+ "12",
+ "252",
+ "-6",
+ "0",
+ "2",
+ "2",
+ "2",
+ "0",
+ "3",
+ "43",
+ "9223372036854775807",
+ "-9223372036854775805",
+ "3",
+ "43",
+ "9223372036854775806",
+ "-9223372036854775807",
+ "Shift",
+ "64",
+ "1344",
+ "-32",
+ "0",
+ "0",
+ "1",
+ "288230376151711743",
+ "-288230376151711744",
+ "0",
+ "1",
+ "288230376151711743",
+ "288230376151711744",
+ "Sub",
+ "-1",
+ "-41",
+ "-9223372036854775806",
+ "-9223372036854775807",
+ "-3",
+ "37",
+ "9223372036854775802",
+ "9223372036854775803",
+ "Mixed",
+ "3",
+ "43",
+ "-9223372036854775808",
+ "-9223372036854775807",
+ "3",
+ "-37",
+ "-9223372036854775802",
+ "-9223372036854775803",
+ "3",
+ "43",
+ "-9223372036854775808",
+ "-9223372036854775807",
+ "1",
+ "41",
+ "9223372036854775806",
+ "9223372036854775807",
+ "Double Associative",
+ "12",
+ "52",
+ "84",
+ "1764",
+ "2",
+ "2",
+ "7",
+ "47",
+ "4",
+ "44",
+ "Double Shift",
+ "128",
+ "2688",
+ "0",
+ "0",
+ "0",
+ "0",
+ "Double Sub",
+ "-1",
+ "-41",
+ "-10",
+ "30",
+ "Double Mixed",
+ "-4",
+ "36",
+ "7",
+ "-33",
+ "-4",
+ "36",
+ "5",
+ "45");
+ private final TestParameters parameters;
+
+ @Parameterized.Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withCfRuntimes().withDexRuntimes().withAllApiLevels().build();
+ }
+
+ public AssociativeLongTest(TestParameters parameters) {
+ this.parameters = parameters;
+ }
+
+ @Test
+ public void testD8() throws Exception {
+ testForRuntime(parameters)
+ .addProgramClasses(Main.class)
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutput(EXPECTED_RESULT);
+ }
+
+ @Test
+ public void testR8() throws Exception {
+ testForR8(parameters.getBackend())
+ .addProgramClasses(Main.class)
+ .addKeepMainRule(Main.class)
+ .enableInliningAnnotations()
+ .setMinApi(parameters)
+ .compile()
+ .inspect(this::inspect)
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutput(EXPECTED_RESULT);
+ }
+
+ private void inspect(CodeInspector inspector) {
+ ClassSubject clazz = inspector.clazz(Main.class);
+ for (FoundMethodSubject method :
+ clazz.allMethods(m -> m.getParameters().size() > 0 && m.getParameter(0).is("long"))) {
+ assertEquals(
+ 1,
+ method
+ .streamInstructions()
+ .filter(i -> i.isIntArithmeticBinop() || i.isLongLogicalBinop())
+ .count());
+ }
+ }
+
+ public static class Main {
+
+ public static void main(String[] args) {
+ simple();
+ doubleOps();
+ }
+
+ @NeverInline
+ private static void simple() {
+ // Associative + * & | ^.
+ System.out.println("Associative");
+ add(2);
+ add(42);
+ add(Long.MAX_VALUE);
+ add(Long.MIN_VALUE);
+ mul(2);
+ mul(42);
+ mul(Long.MAX_VALUE);
+ mul(Long.MIN_VALUE);
+ and(2);
+ and(42);
+ and(Long.MAX_VALUE);
+ and(Long.MIN_VALUE);
+ or(2);
+ or(42);
+ or(Long.MAX_VALUE);
+ or(Long.MIN_VALUE);
+ xor(2);
+ xor(42);
+ xor(Long.MAX_VALUE);
+ xor(Long.MIN_VALUE);
+
+ // Shift composition.
+ System.out.println("Shift");
+ shl(2);
+ shl(42);
+ shl(Long.MAX_VALUE);
+ shl(Long.MIN_VALUE);
+ shr(2);
+ shr(42);
+ shr(Long.MAX_VALUE);
+ shr(Long.MIN_VALUE);
+ ushr(2);
+ ushr(42);
+ ushr(Long.MAX_VALUE);
+ ushr(Long.MIN_VALUE);
+
+ // Special for -.
+ System.out.println("Sub");
+ sub(2);
+ sub(42);
+ sub(Long.MAX_VALUE);
+ sub(Long.MIN_VALUE);
+ sub2(2);
+ sub2(42);
+ sub2(Long.MAX_VALUE);
+ sub2(Long.MIN_VALUE);
+
+ // Mixed for + and -.
+ System.out.println("Mixed");
+ addSub(2);
+ addSub(42);
+ addSub(Long.MAX_VALUE);
+ addSub(Long.MIN_VALUE);
+ subAdd(2);
+ subAdd(42);
+ subAdd(Long.MAX_VALUE);
+ subAdd(Long.MIN_VALUE);
+ addSub2(2);
+ addSub2(42);
+ addSub2(Long.MAX_VALUE);
+ addSub2(Long.MIN_VALUE);
+ subAdd2(2);
+ subAdd2(42);
+ subAdd2(Long.MAX_VALUE);
+ subAdd2(Long.MIN_VALUE);
+ }
+
+ @NeverInline
+ private static void doubleOps() {
+ // Associative + * & | ^.
+ System.out.println("Double Associative");
+ addDouble(2);
+ addDouble(42);
+ mulDouble(2);
+ mulDouble(42);
+ andDouble(2);
+ andDouble(42);
+ orDouble(2);
+ orDouble(42);
+ xorDouble(2);
+ xorDouble(42);
+
+ // Shift composition.
+ System.out.println("Double Shift");
+ shlDouble(2);
+ shlDouble(42);
+ shrDouble(2);
+ shrDouble(42);
+ ushrDouble(2);
+ ushrDouble(42);
+
+ // Special for -.
+ System.out.println("Double Sub");
+ subDouble(2);
+ subDouble(42);
+ sub2Double(2);
+ sub2Double(42);
+
+ // Mixed for + and -.
+ System.out.println("Double Mixed");
+ addSubDouble(2);
+ addSubDouble(42);
+ subAddDouble(2);
+ subAddDouble(42);
+ addSub2Double(2);
+ addSub2Double(42);
+ subAdd2Double(2);
+ subAdd2Double(42);
+ }
+
+ @NeverInline
+ public static void add(long x) {
+ System.out.println(3L + x + 2L);
+ }
+
+ @NeverInline
+ public static void mul(long x) {
+ System.out.println(3L * x * 2L);
+ }
+
+ @NeverInline
+ public static void and(long x) {
+ System.out.println(3L & x & 2L);
+ }
+
+ @NeverInline
+ public static void or(long x) {
+ System.out.println(3L | x | 2L);
+ }
+
+ @NeverInline
+ public static void xor(long x) {
+ System.out.println(3L ^ x ^ 2L);
+ }
+
+ @NeverInline
+ public static void shl(long x) {
+ System.out.println(x << 2L << 3L);
+ }
+
+ @NeverInline
+ public static void shr(long x) {
+ System.out.println(x >> 2L >> 3L);
+ }
+
+ @NeverInline
+ public static void ushr(long x) {
+ System.out.println(x >>> 2L >>> 3L);
+ }
+
+ @NeverInline
+ public static void sub(long x) {
+ System.out.println(3L - x - 2L);
+ }
+
+ @NeverInline
+ public static void sub2(long x) {
+ System.out.println(x - 3L - 2L);
+ }
+
+ @NeverInline
+ public static void addSub(long x) {
+ System.out.println(3L + x - 2L);
+ }
+
+ @NeverInline
+ public static void addSub2(long x) {
+ System.out.println(x + 3L - 2L);
+ }
+
+ @NeverInline
+ public static void subAdd(long x) {
+ System.out.println(3L - x + 2L);
+ }
+
+ @NeverInline
+ public static void subAdd2(long x) {
+ System.out.println(x - 3L + 2L);
+ }
+
+ @NeverInline
+ public static void addDouble(long x) {
+ System.out.println(3L + x + 2L + 5);
+ }
+
+ @NeverInline
+ public static void mulDouble(long x) {
+ System.out.println(3L * x * 2L * 7L);
+ }
+
+ @NeverInline
+ public static void andDouble(long x) {
+ System.out.println(3L & x & 2L & 7L);
+ }
+
+ @NeverInline
+ public static void orDouble(long x) {
+ System.out.println(3L | x | 2L | 7L);
+ }
+
+ @NeverInline
+ public static void xorDouble(long x) {
+ System.out.println(3L ^ x ^ 2L ^ 7L);
+ }
+
+ @NeverInline
+ public static void shlDouble(long x) {
+ System.out.println(x << 2L << 3L << 1L);
+ }
+
+ @NeverInline
+ public static void shrDouble(long x) {
+ System.out.println(x >> 2L >> 3L >> 1L);
+ }
+
+ @NeverInline
+ public static void ushrDouble(long x) {
+ System.out.println(x >>> 2L >>> 3L >>> 1L);
+ }
+
+ @NeverInline
+ public static void subDouble(long x) {
+ System.out.println(3L - x - 2L);
+ }
+
+ @NeverInline
+ public static void sub2Double(long x) {
+ System.out.println(x - 3L - 2L - 7L);
+ }
+
+ @NeverInline
+ public static void addSubDouble(long x) {
+ System.out.println(3L + x - 2L - 7L);
+ }
+
+ @NeverInline
+ public static void addSub2Double(long x) {
+ System.out.println(x + 3L - 2L - 7L);
+ }
+
+ @NeverInline
+ public static void subAddDouble(long x) {
+ System.out.println(3L - x + 2L + 4L);
+ }
+
+ @NeverInline
+ public static void subAdd2Double(long x) {
+ System.out.println(x - 3L + 2L + 4L);
+ }
+ }
+}
diff --git a/src/test/java/com/android/tools/r8/ir/IdentityAbsorbingTest.java b/src/test/java/com/android/tools/r8/ir/IdentityAbsorbingTest.java
index 390d6ca..3f9a8b2 100644
--- a/src/test/java/com/android/tools/r8/ir/IdentityAbsorbingTest.java
+++ b/src/test/java/com/android/tools/r8/ir/IdentityAbsorbingTest.java
@@ -853,7 +853,11 @@
assertTrue(
m.streamInstructions()
.noneMatch(
- i -> i.isIntOrLongLogicalBinop() || i.isIntOrLongArithmeticBinop())));
+ i ->
+ i.isIntLogicalBinop()
+ || i.isLongLogicalBinop()
+ || i.isIntArithmeticBinop()
+ || i.isLongArithmeticBinop())));
}
static class Main {
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java b/src/test/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java
index 45d6202..79425a3 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java
@@ -339,17 +339,27 @@
}
@Override
- public boolean isIntOrLongArithmeticBinop() {
+ public boolean isIntArithmeticBinop() {
return instruction instanceof CfArithmeticBinop
- && (((CfArithmeticBinop) instruction).getType() == NumericType.INT
- || ((CfArithmeticBinop) instruction).getType() == NumericType.LONG);
+ && ((CfArithmeticBinop) instruction).getType() == NumericType.INT;
}
@Override
- public boolean isIntOrLongLogicalBinop() {
+ public boolean isIntLogicalBinop() {
return instruction instanceof CfLogicalBinop
- && (((CfLogicalBinop) instruction).getType() == NumericType.INT
- || ((CfLogicalBinop) instruction).getType() == NumericType.LONG);
+ && ((CfLogicalBinop) instruction).getType() == NumericType.INT;
+ }
+
+ @Override
+ public boolean isLongArithmeticBinop() {
+ return instruction instanceof CfArithmeticBinop
+ && ((CfArithmeticBinop) instruction).getType() == NumericType.LONG;
+ }
+
+ @Override
+ public boolean isLongLogicalBinop() {
+ return instruction instanceof CfLogicalBinop
+ && ((CfLogicalBinop) instruction).getType() == NumericType.LONG;
}
@Override
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java b/src/test/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java
index 38421a2..bd68422 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java
@@ -122,6 +122,8 @@
import com.android.tools.r8.dex.code.DexReturn;
import com.android.tools.r8.dex.code.DexReturnObject;
import com.android.tools.r8.dex.code.DexReturnVoid;
+import com.android.tools.r8.dex.code.DexRsubInt;
+import com.android.tools.r8.dex.code.DexRsubIntLit8;
import com.android.tools.r8.dex.code.DexSget;
import com.android.tools.r8.dex.code.DexSgetBoolean;
import com.android.tools.r8.dex.code.DexSgetByte;
@@ -501,69 +503,77 @@
return instruction instanceof DexSparseSwitch;
}
- public boolean isIntOrLongArithmeticBinop() {
+ public boolean isIntArithmeticBinop() {
return instruction instanceof DexMulInt
|| instruction instanceof DexMulIntLit8
|| instruction instanceof DexMulIntLit16
|| instruction instanceof DexMulInt2Addr
- || instruction instanceof DexMulLong
- || instruction instanceof DexMulLong2Addr
|| instruction instanceof DexAddInt
|| instruction instanceof DexAddIntLit8
|| instruction instanceof DexAddIntLit16
|| instruction instanceof DexAddInt2Addr
- || instruction instanceof DexAddLong
- || instruction instanceof DexAddLong2Addr
+ || instruction instanceof DexRsubInt
+ || instruction instanceof DexRsubIntLit8
|| instruction instanceof DexSubInt
|| instruction instanceof DexSubInt2Addr
- || instruction instanceof DexSubLong
- || instruction instanceof DexSubLong2Addr
|| instruction instanceof DexDivInt
|| instruction instanceof DexDivIntLit8
|| instruction instanceof DexDivIntLit16
|| instruction instanceof DexDivInt2Addr
- || instruction instanceof DexDivLong
- || instruction instanceof DexDivLong2Addr
|| instruction instanceof DexRemInt
|| instruction instanceof DexRemIntLit8
|| instruction instanceof DexRemIntLit16
- || instruction instanceof DexRemInt2Addr
+ || instruction instanceof DexRemInt2Addr;
+ }
+
+ public boolean isLongArithmeticBinop() {
+ return instruction instanceof DexMulLong
+ || instruction instanceof DexMulLong2Addr
+ || instruction instanceof DexAddLong
+ || instruction instanceof DexAddLong2Addr
+ || instruction instanceof DexSubLong
+ || instruction instanceof DexSubLong2Addr
+ || instruction instanceof DexDivLong
+ || instruction instanceof DexDivLong2Addr
|| instruction instanceof DexRemLong
|| instruction instanceof DexRemLong2Addr;
}
- public boolean isIntOrLongLogicalBinop() {
+ public boolean isIntLogicalBinop() {
return instruction instanceof DexAndInt
|| instruction instanceof DexAndIntLit8
|| instruction instanceof DexAndIntLit16
|| instruction instanceof DexAndInt2Addr
- || instruction instanceof DexAndLong
- || instruction instanceof DexAndLong2Addr
|| instruction instanceof DexOrInt
|| instruction instanceof DexOrIntLit8
|| instruction instanceof DexOrIntLit16
|| instruction instanceof DexOrInt2Addr
- || instruction instanceof DexOrLong
- || instruction instanceof DexOrLong2Addr
|| instruction instanceof DexXorInt
|| instruction instanceof DexXorIntLit8
|| instruction instanceof DexXorIntLit16
|| instruction instanceof DexXorInt2Addr
- || instruction instanceof DexXorLong
- || instruction instanceof DexXorLong2Addr
|| instruction instanceof DexShrInt
|| instruction instanceof DexShrIntLit8
|| instruction instanceof DexShrInt2Addr
- || instruction instanceof DexShrLong
- || instruction instanceof DexShrLong2Addr
|| instruction instanceof DexShlInt
|| instruction instanceof DexShlIntLit8
|| instruction instanceof DexShlInt2Addr
- || instruction instanceof DexShlLong
- || instruction instanceof DexShlLong2Addr
|| instruction instanceof DexUshrInt
|| instruction instanceof DexUshrIntLit8
- || instruction instanceof DexUshrInt2Addr
+ || instruction instanceof DexUshrInt2Addr;
+ }
+
+ public boolean isLongLogicalBinop() {
+ return instruction instanceof DexAndLong
+ || instruction instanceof DexAndLong2Addr
+ || instruction instanceof DexOrLong
+ || instruction instanceof DexOrLong2Addr
+ || instruction instanceof DexXorLong
+ || instruction instanceof DexXorLong2Addr
+ || instruction instanceof DexShrLong
+ || instruction instanceof DexShrLong2Addr
+ || instruction instanceof DexShlLong
+ || instruction instanceof DexShlLong2Addr
|| instruction instanceof DexUshrLong
|| instruction instanceof DexUshrLong2Addr;
}
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java b/src/test/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java
index 6d1bd94..b77029f 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java
@@ -140,9 +140,13 @@
boolean isSparseSwitch();
- boolean isIntOrLongArithmeticBinop();
+ boolean isIntArithmeticBinop();
- boolean isIntOrLongLogicalBinop();
+ boolean isIntLogicalBinop();
+
+ boolean isLongArithmeticBinop();
+
+ boolean isLongLogicalBinop();
boolean isMultiplication();