Extend Binop rewriter

- associative property for add, mul, and, or, xor
- shift combination
- sub combination
- mixed add/sub combination

Change-Id: I68b1fa1d7cab4f9787b368055acd68703354183b
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/BinopRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/BinopRewriter.java
index 93ad3ff..f642a95 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/BinopRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/BinopRewriter.java
@@ -4,9 +4,11 @@
 
 package com.android.tools.r8.ir.conversion.passes;
 
+import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.AppInfo;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.analysis.type.TypeElement;
 import com.android.tools.r8.ir.code.Add;
 import com.android.tools.r8.ir.code.And;
 import com.android.tools.r8.ir.code.Binop;
@@ -42,18 +44,17 @@
 
   private Map<Class<?>, BinopDescriptor> createBinopDescriptors() {
     ImmutableMap.Builder<Class<?>, BinopDescriptor> builder = ImmutableMap.builder();
-    builder.put(Add.class, new BinopDescriptor(0, 0, null, null));
-    builder.put(Sub.class, new BinopDescriptor(null, 0, null, null));
-    builder.put(Mul.class, new BinopDescriptor(1, 1, 0, 0));
-    // The following two can be improved if we handle ZeroDivide.
-    builder.put(Div.class, new BinopDescriptor(null, 1, null, null));
-    builder.put(Rem.class, new BinopDescriptor(null, null, null, null));
-    builder.put(And.class, new BinopDescriptor(ALL_BITS_SET, ALL_BITS_SET, 0, 0));
-    builder.put(Or.class, new BinopDescriptor(0, 0, ALL_BITS_SET, ALL_BITS_SET));
-    builder.put(Xor.class, new BinopDescriptor(0, 0, null, null));
-    builder.put(Shl.class, new BinopDescriptor(null, 0, 0, null));
-    builder.put(Shr.class, new BinopDescriptor(null, 0, 0, null));
-    builder.put(Ushr.class, new BinopDescriptor(null, 0, 0, null));
+    builder.put(Add.class, BinopDescriptor.ADD);
+    builder.put(Sub.class, BinopDescriptor.SUB);
+    builder.put(Mul.class, BinopDescriptor.MUL);
+    builder.put(Div.class, BinopDescriptor.DIV);
+    builder.put(Rem.class, BinopDescriptor.REM);
+    builder.put(And.class, BinopDescriptor.AND);
+    builder.put(Or.class, BinopDescriptor.OR);
+    builder.put(Xor.class, BinopDescriptor.XOR);
+    builder.put(Shl.class, BinopDescriptor.SHL);
+    builder.put(Shr.class, BinopDescriptor.SHR);
+    builder.put(Ushr.class, BinopDescriptor.USHR);
     return builder.build();
   }
 
@@ -64,24 +65,176 @@
    * - i is right identity if for each x in K, x * i = x.
    * - a is left absorbing if for each x in K, a * x = a.
    * - a is right absorbing if for each x in K, x * a = a.
+   * In a space K, a binop * is associative if for each x,y,z in K, (x * y) * z = x * (y * z).
    * </code>
    */
-  private static class BinopDescriptor {
+  private enum BinopDescriptor {
+    ADD(0, 0, null, null, true) {
+      @Override
+      Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+        return Add.create(numericType, dest, left, right);
+      }
+
+      @Override
+      int evaluate(int left, int right) {
+        return left + right;
+      }
+
+      @Override
+      long evaluate(long left, long right) {
+        return left + right;
+      }
+    },
+    SUB(null, 0, null, null, false) {
+      @Override
+      Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+        return new Sub(numericType, dest, left, right);
+      }
+
+      @Override
+      int evaluate(int left, int right) {
+        return left - right;
+      }
+
+      @Override
+      long evaluate(long left, long right) {
+        return left - right;
+      }
+    },
+    MUL(1, 1, 0, 0, true) {
+      @Override
+      Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+        return Mul.create(numericType, dest, left, right);
+      }
+
+      @Override
+      int evaluate(int left, int right) {
+        return left * right;
+      }
+
+      @Override
+      long evaluate(long left, long right) {
+        return left * right;
+      }
+    },
+    // The following two can be improved if we handle ZeroDivide.
+    DIV(null, 1, null, null, false),
+    REM(null, null, null, null, false),
+    AND(ALL_BITS_SET, ALL_BITS_SET, 0, 0, true) {
+      @Override
+      Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+        return And.create(numericType, dest, left, right);
+      }
+
+      @Override
+      int evaluate(int left, int right) {
+        return left & right;
+      }
+
+      @Override
+      long evaluate(long left, long right) {
+        return left & right;
+      }
+    },
+    OR(0, 0, ALL_BITS_SET, ALL_BITS_SET, true) {
+      @Override
+      Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+        return Or.create(numericType, dest, left, right);
+      }
+
+      @Override
+      int evaluate(int left, int right) {
+        return left | right;
+      }
+
+      @Override
+      long evaluate(long left, long right) {
+        return left | right;
+      }
+    },
+    XOR(0, 0, null, null, true) {
+      @Override
+      Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+        return Xor.create(numericType, dest, left, right);
+      }
+
+      @Override
+      int evaluate(int left, int right) {
+        return left ^ right;
+      }
+
+      @Override
+      long evaluate(long left, long right) {
+        return left ^ right;
+      }
+    },
+    SHL(null, 0, 0, null, false) {
+      @Override
+      Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+        return new Shl(numericType, dest, left, right);
+      }
+
+      @Override
+      boolean isShift() {
+        return true;
+      }
+    },
+    SHR(null, 0, 0, null, false) {
+      @Override
+      Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+        return new Shr(numericType, dest, left, right);
+      }
+
+      @Override
+      boolean isShift() {
+        return true;
+      }
+    },
+    USHR(null, 0, 0, null, false) {
+      @Override
+      Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+        return new Ushr(numericType, dest, left, right);
+      }
+
+      @Override
+      boolean isShift() {
+        return true;
+      }
+    };
 
     final Integer leftIdentity;
     final Integer rightIdentity;
     final Integer leftAbsorbing;
     final Integer rightAbsorbing;
+    final boolean associativeAndCommutative;
 
-    private BinopDescriptor(
+    BinopDescriptor(
         Integer leftIdentity,
         Integer rightIdentity,
         Integer leftAbsorbing,
-        Integer rightAbsorbing) {
+        Integer rightAbsorbing,
+        boolean associativeAndCommutative) {
       this.leftIdentity = leftIdentity;
       this.rightIdentity = rightIdentity;
       this.leftAbsorbing = leftAbsorbing;
       this.rightAbsorbing = rightAbsorbing;
+      this.associativeAndCommutative = associativeAndCommutative;
+    }
+
+    Binop instantiate(NumericType numericType, Value dest, Value left, Value right) {
+      throw new Unreachable();
+    }
+
+    int evaluate(int left, int right) {
+      throw new Unreachable();
+    }
+
+    long evaluate(long left, long right) {
+      throw new Unreachable();
+    }
+
+    boolean isShift() {
+      return false;
     }
   }
 
@@ -106,30 +259,10 @@
             || binop.getNumericType() == NumericType.LONG) {
           BinopDescriptor binopDescriptor = descriptors.get(binop.getClass());
           assert binopDescriptor != null;
-          ConstNumber constNumber = getConstNumber(binop.leftValue());
-          if (constNumber != null) {
-            if (simplify(
-                binop,
-                iterator,
-                constNumber,
-                binopDescriptor.leftIdentity,
-                binop.rightValue(),
-                binopDescriptor.leftAbsorbing,
-                binop.leftValue())) {
-              continue;
-            }
+          if (identityAbsorbingSimplification(iterator, binop, binopDescriptor)) {
+            continue;
           }
-          constNumber = getConstNumber(binop.rightValue());
-          if (constNumber != null) {
-            simplify(
-                binop,
-                iterator,
-                constNumber,
-                binopDescriptor.rightIdentity,
-                binop.leftValue(),
-                binopDescriptor.rightAbsorbing,
-                binop.rightValue());
-          }
+          successiveSimplification(iterator, binop, binopDescriptor, code);
         }
       }
     }
@@ -137,6 +270,138 @@
     assert code.isConsistentSSA(appView);
   }
 
+  private void successiveSimplification(
+      InstructionListIterator iterator, Binop binop, BinopDescriptor binopDescriptor, IRCode code) {
+    if (binop.outValue().hasDebugUsers()) {
+      return;
+    }
+    ConstNumber constLeft = getConstNumber(binop.leftValue());
+    ConstNumber constRight = getConstNumber(binop.rightValue());
+    if ((constLeft != null && constRight != null) || (constLeft == null && constRight == null)) {
+      return;
+    }
+    Value otherValue = constLeft == null ? binop.leftValue() : binop.rightValue();
+    if (otherValue.isPhi() || !otherValue.getDefinition().isBinop()) {
+      return;
+    }
+    Binop prevBinop = otherValue.getDefinition().asBinop();
+    ConstNumber prevConstLeft = getConstNumber(prevBinop.leftValue());
+    ConstNumber prevConstRight = getConstNumber(prevBinop.rightValue());
+    if ((prevConstLeft != null && prevConstRight != null)
+        || (prevConstLeft == null && prevConstRight == null)) {
+      return;
+    }
+    ConstNumber constB = constLeft == null ? constRight : constLeft;
+    ConstNumber constA = prevConstLeft == null ? prevConstRight : prevConstLeft;
+    Value input = prevConstLeft == null ? prevBinop.leftValue() : prevBinop.rightValue();
+    // We have two successive binops so that a,b constants, x the input and a * x * b.
+    if (prevBinop.getClass() == binop.getClass()) {
+      if (binopDescriptor.associativeAndCommutative) {
+        // a * x * b => x * (a * b) where (a * b) is a constant.
+        assert binop.isCommutative();
+        Value newConst = addNewConstNumber(code, iterator, constB, constA, binopDescriptor);
+        iterator.replaceCurrentInstruction(
+            instantiateBinop(code, input, newConst, binopDescriptor));
+      } else if (binopDescriptor.isShift()) {
+        // x shift: a shift: b => x shift: (a + b) where a + b is a constant.
+        if (constRight != null && prevConstRight != null) {
+          Value newConst = addNewConstNumber(code, iterator, constB, constA, BinopDescriptor.ADD);
+          iterator.replaceCurrentInstruction(
+              instantiateBinop(code, input, newConst, binopDescriptor));
+        }
+      } else if (binop.isSub()) {
+        // a - x - b => (a - b) - x where (a - b) is a constant.
+        // x - a - b => x - (a + b) where (a + b) is a constant.
+        if (prevConstRight == null) {
+          Value newConst = addNewConstNumber(code, iterator, constA, constB, BinopDescriptor.SUB);
+          iterator.replaceCurrentInstruction(
+              instantiateBinop(code, newConst, input, BinopDescriptor.SUB));
+        } else {
+          Value newConst = addNewConstNumber(code, iterator, constB, constA, BinopDescriptor.ADD);
+          iterator.replaceCurrentInstruction(
+              instantiateBinop(code, input, newConst, BinopDescriptor.SUB));
+        }
+      }
+    } else {
+      if (binop.isSub() && prevBinop.isAdd()) {
+        // x + a - b => x + (a - b) where (a - b) is a constant.
+        // a + x - b => x + (a - b) where (a - b) is a constant.
+        Value newConst = addNewConstNumber(code, iterator, constA, constB, BinopDescriptor.SUB);
+        iterator.replaceCurrentInstruction(
+            instantiateBinop(code, newConst, input, BinopDescriptor.ADD));
+      } else if (binop.isAdd() && prevBinop.isSub()) {
+        // x - a + b => x - (a - b) where (a - b) is a constant.
+        // a - x + b => (a + b) - x where (a + b) is a constant.
+        if (prevConstLeft == null) {
+          Value newConst = addNewConstNumber(code, iterator, constA, constB, BinopDescriptor.SUB);
+          iterator.replaceCurrentInstruction(
+              instantiateBinop(code, input, newConst, BinopDescriptor.SUB));
+        } else {
+          Value newConst = addNewConstNumber(code, iterator, constB, constA, BinopDescriptor.ADD);
+          iterator.replaceCurrentInstruction(
+              instantiateBinop(code, newConst, input, BinopDescriptor.SUB));
+        }
+      }
+    }
+  }
+
+  private Instruction instantiateBinop(
+      IRCode code, Value left, Value right, BinopDescriptor descriptor) {
+    TypeElement representative = left.getType().isInt() ? right.getType() : left.getType();
+    Value newValue = code.createValue(representative);
+    NumericType numericType = representative.isInt() ? NumericType.INT : NumericType.LONG;
+    return descriptor.instantiate(numericType, newValue, left, right);
+  }
+
+  private Value addNewConstNumber(
+      IRCode code,
+      InstructionListIterator iterator,
+      ConstNumber left,
+      ConstNumber right,
+      BinopDescriptor descriptor) {
+    TypeElement representative =
+        left.outValue().getType().isInt() ? right.outValue().getType() : left.outValue().getType();
+    long result =
+        representative.isInt()
+            ? descriptor.evaluate(left.getIntValue(), right.getIntValue())
+            : descriptor.evaluate(left.getLongValue(), right.getLongValue());
+    iterator.previous();
+    Value value =
+        iterator.insertConstNumberInstruction(
+            code, appView.options(), result, left.outValue().getType());
+    iterator.next();
+    return value;
+  }
+
+  private boolean identityAbsorbingSimplification(
+      InstructionListIterator iterator, Binop binop, BinopDescriptor binopDescriptor) {
+    ConstNumber constNumber = getConstNumber(binop.leftValue());
+    if (constNumber != null) {
+      if (simplify(
+          binop,
+          iterator,
+          constNumber,
+          binopDescriptor.leftIdentity,
+          binop.rightValue(),
+          binopDescriptor.leftAbsorbing,
+          binop.leftValue())) {
+        return true;
+      }
+    }
+    constNumber = getConstNumber(binop.rightValue());
+    if (constNumber != null) {
+      return simplify(
+          binop,
+          iterator,
+          constNumber,
+          binopDescriptor.rightIdentity,
+          binop.leftValue(),
+          binopDescriptor.rightAbsorbing,
+          binop.rightValue());
+    }
+    return false;
+  }
+
   private ConstNumber getConstNumber(Value val) {
     ConstNumber constNumber = getConstNumberIfConstant(val);
     if (constNumber != null) {
diff --git a/src/test/java/com/android/tools/r8/ir/AssociativeIntTest.java b/src/test/java/com/android/tools/r8/ir/AssociativeIntTest.java
new file mode 100644
index 0000000..8a6f9c0
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/AssociativeIntTest.java
@@ -0,0 +1,424 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir;
+
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.StringUtils;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.FoundMethodSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class AssociativeIntTest extends TestBase {
+
+  private static final String EXPECTED_RESULT =
+      StringUtils.lines(
+          "Associative",
+          "7",
+          "47",
+          "-2147483644",
+          "-2147483643",
+          "12",
+          "252",
+          "-6",
+          "0",
+          "2",
+          "2",
+          "2",
+          "0",
+          "3",
+          "43",
+          "2147483647",
+          "-2147483645",
+          "3",
+          "43",
+          "2147483646",
+          "-2147483647",
+          "Shift",
+          "64",
+          "1344",
+          "-32",
+          "0",
+          "0",
+          "1",
+          "67108863",
+          "-67108864",
+          "0",
+          "1",
+          "67108863",
+          "67108864",
+          "Sub",
+          "-1",
+          "-41",
+          "-2147483646",
+          "-2147483647",
+          "-3",
+          "37",
+          "2147483642",
+          "2147483643",
+          "Mixed",
+          "3",
+          "43",
+          "-2147483648",
+          "-2147483647",
+          "3",
+          "-37",
+          "-2147483642",
+          "-2147483643",
+          "3",
+          "43",
+          "-2147483648",
+          "-2147483647",
+          "1",
+          "41",
+          "2147483646",
+          "2147483647",
+          "Double Associative",
+          "12",
+          "52",
+          "84",
+          "1764",
+          "2",
+          "2",
+          "7",
+          "47",
+          "4",
+          "44",
+          "Double Shift",
+          "128",
+          "2688",
+          "0",
+          "0",
+          "0",
+          "0",
+          "Double Sub",
+          "-1",
+          "-41",
+          "-10",
+          "30",
+          "Double Mixed",
+          "-4",
+          "36",
+          "7",
+          "-33",
+          "-4",
+          "36",
+          "5",
+          "45");
+  private final TestParameters parameters;
+
+  @Parameterized.Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withCfRuntimes().withDexRuntimes().withAllApiLevels().build();
+  }
+
+  public AssociativeIntTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void testD8() throws Exception {
+    testForRuntime(parameters)
+        .addProgramClasses(Main.class)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutput(EXPECTED_RESULT);
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    testForR8(parameters.getBackend())
+        .addProgramClasses(Main.class)
+        .addKeepMainRule(Main.class)
+        .enableInliningAnnotations()
+        .setMinApi(parameters)
+        .compile()
+        .inspect(this::inspect)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutput(EXPECTED_RESULT);
+  }
+
+  private void inspect(CodeInspector inspector) {
+    ClassSubject clazz = inspector.clazz(Main.class);
+    for (FoundMethodSubject method :
+        clazz.allMethods(m -> m.getParameters().size() > 0 && m.getParameter(0).is("int"))) {
+      assertEquals(
+          1,
+          method
+              .streamInstructions()
+              .filter(i -> i.isIntArithmeticBinop() || i.isIntLogicalBinop())
+              .count());
+    }
+  }
+
+  public static class Main {
+
+    public static void main(String[] args) {
+      simple();
+      doubleOps();
+    }
+
+    @NeverInline
+    private static void simple() {
+      // Associative + * & | ^.
+      System.out.println("Associative");
+      add(2);
+      add(42);
+      add(Integer.MAX_VALUE);
+      add(Integer.MIN_VALUE);
+      mul(2);
+      mul(42);
+      mul(Integer.MAX_VALUE);
+      mul(Integer.MIN_VALUE);
+      and(2);
+      and(42);
+      and(Integer.MAX_VALUE);
+      and(Integer.MIN_VALUE);
+      or(2);
+      or(42);
+      or(Integer.MAX_VALUE);
+      or(Integer.MIN_VALUE);
+      xor(2);
+      xor(42);
+      xor(Integer.MAX_VALUE);
+      xor(Integer.MIN_VALUE);
+
+      // Shift composition.
+      System.out.println("Shift");
+      shl(2);
+      shl(42);
+      shl(Integer.MAX_VALUE);
+      shl(Integer.MIN_VALUE);
+      shr(2);
+      shr(42);
+      shr(Integer.MAX_VALUE);
+      shr(Integer.MIN_VALUE);
+      ushr(2);
+      ushr(42);
+      ushr(Integer.MAX_VALUE);
+      ushr(Integer.MIN_VALUE);
+
+      // Special for -.
+      System.out.println("Sub");
+      sub(2);
+      sub(42);
+      sub(Integer.MAX_VALUE);
+      sub(Integer.MIN_VALUE);
+      sub2(2);
+      sub2(42);
+      sub2(Integer.MAX_VALUE);
+      sub2(Integer.MIN_VALUE);
+
+      // Mixed for + and -.
+      System.out.println("Mixed");
+      addSub(2);
+      addSub(42);
+      addSub(Integer.MAX_VALUE);
+      addSub(Integer.MIN_VALUE);
+      subAdd(2);
+      subAdd(42);
+      subAdd(Integer.MAX_VALUE);
+      subAdd(Integer.MIN_VALUE);
+      addSub2(2);
+      addSub2(42);
+      addSub2(Integer.MAX_VALUE);
+      addSub2(Integer.MIN_VALUE);
+      subAdd2(2);
+      subAdd2(42);
+      subAdd2(Integer.MAX_VALUE);
+      subAdd2(Integer.MIN_VALUE);
+    }
+
+    @NeverInline
+    private static void doubleOps() {
+      // Associative + * & | ^.
+      System.out.println("Double Associative");
+      addDouble(2);
+      addDouble(42);
+      mulDouble(2);
+      mulDouble(42);
+      andDouble(2);
+      andDouble(42);
+      orDouble(2);
+      orDouble(42);
+      xorDouble(2);
+      xorDouble(42);
+
+      // Shift composition.
+      System.out.println("Double Shift");
+      shlDouble(2);
+      shlDouble(42);
+      shrDouble(2);
+      shrDouble(42);
+      ushrDouble(2);
+      ushrDouble(42);
+
+      // Special for -.
+      System.out.println("Double Sub");
+      subDouble(2);
+      subDouble(42);
+      sub2Double(2);
+      sub2Double(42);
+
+      // Mixed for + and -.
+      System.out.println("Double Mixed");
+      addSubDouble(2);
+      addSubDouble(42);
+      subAddDouble(2);
+      subAddDouble(42);
+      addSub2Double(2);
+      addSub2Double(42);
+      subAdd2Double(2);
+      subAdd2Double(42);
+    }
+
+    @NeverInline
+    public static void add(int x) {
+      System.out.println(3 + x + 2);
+    }
+
+    @NeverInline
+    public static void mul(int x) {
+      System.out.println(3 * x * 2);
+    }
+
+    @NeverInline
+    public static void and(int x) {
+      System.out.println(3 & x & 2);
+    }
+
+    @NeverInline
+    public static void or(int x) {
+      System.out.println(3 | x | 2);
+    }
+
+    @NeverInline
+    public static void xor(int x) {
+      System.out.println(3 ^ x ^ 2);
+    }
+
+    @NeverInline
+    public static void shl(int x) {
+      System.out.println(x << 2 << 3);
+    }
+
+    @NeverInline
+    public static void shr(int x) {
+      System.out.println(x >> 2 >> 3);
+    }
+
+    @NeverInline
+    public static void ushr(int x) {
+      System.out.println(x >>> 2 >>> 3);
+    }
+
+    @NeverInline
+    public static void sub(int x) {
+      System.out.println(3 - x - 2);
+    }
+
+    @NeverInline
+    public static void sub2(int x) {
+      System.out.println(x - 3 - 2);
+    }
+
+    @NeverInline
+    public static void addSub(int x) {
+      System.out.println(3 + x - 2);
+    }
+
+    @NeverInline
+    public static void addSub2(int x) {
+      System.out.println(x + 3 - 2);
+    }
+
+    @NeverInline
+    public static void subAdd(int x) {
+      System.out.println(3 - x + 2);
+    }
+
+    @NeverInline
+    public static void subAdd2(int x) {
+      System.out.println(x - 3 + 2);
+    }
+
+    @NeverInline
+    public static void addDouble(int x) {
+      System.out.println(3 + x + 2 + 5);
+    }
+
+    @NeverInline
+    public static void mulDouble(int x) {
+      System.out.println(3 * x * 2 * 7);
+    }
+
+    @NeverInline
+    public static void andDouble(int x) {
+      System.out.println(3 & x & 2 & 7);
+    }
+
+    @NeverInline
+    public static void orDouble(int x) {
+      System.out.println(3 | x | 2 | 7);
+    }
+
+    @NeverInline
+    public static void xorDouble(int x) {
+      System.out.println(3 ^ x ^ 2 ^ 7);
+    }
+
+    @NeverInline
+    public static void shlDouble(int x) {
+      System.out.println(x << 2 << 3 << 1);
+    }
+
+    @NeverInline
+    public static void shrDouble(int x) {
+      System.out.println(x >> 2 >> 3 >> 1);
+    }
+
+    @NeverInline
+    public static void ushrDouble(int x) {
+      System.out.println(x >>> 2 >>> 3 >>> 1);
+    }
+
+    @NeverInline
+    public static void subDouble(int x) {
+      System.out.println(3 - x - 2);
+    }
+
+    @NeverInline
+    public static void sub2Double(int x) {
+      System.out.println(x - 3 - 2 - 7);
+    }
+
+    @NeverInline
+    public static void addSubDouble(int x) {
+      System.out.println(3 + x - 2 - 7);
+    }
+
+    @NeverInline
+    public static void addSub2Double(int x) {
+      System.out.println(x + 3 - 2 - 7);
+    }
+
+    @NeverInline
+    public static void subAddDouble(int x) {
+      System.out.println(3 - x + 2 + 4);
+    }
+
+    @NeverInline
+    public static void subAdd2Double(int x) {
+      System.out.println(x - 3 + 2 + 4);
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/ir/AssociativeLongTest.java b/src/test/java/com/android/tools/r8/ir/AssociativeLongTest.java
new file mode 100644
index 0000000..22ea7e8
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/AssociativeLongTest.java
@@ -0,0 +1,424 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir;
+
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.StringUtils;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.FoundMethodSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class AssociativeLongTest extends TestBase {
+
+  private static final String EXPECTED_RESULT =
+      StringUtils.lines(
+          "Associative",
+          "7",
+          "47",
+          "-9223372036854775804",
+          "-9223372036854775803",
+          "12",
+          "252",
+          "-6",
+          "0",
+          "2",
+          "2",
+          "2",
+          "0",
+          "3",
+          "43",
+          "9223372036854775807",
+          "-9223372036854775805",
+          "3",
+          "43",
+          "9223372036854775806",
+          "-9223372036854775807",
+          "Shift",
+          "64",
+          "1344",
+          "-32",
+          "0",
+          "0",
+          "1",
+          "288230376151711743",
+          "-288230376151711744",
+          "0",
+          "1",
+          "288230376151711743",
+          "288230376151711744",
+          "Sub",
+          "-1",
+          "-41",
+          "-9223372036854775806",
+          "-9223372036854775807",
+          "-3",
+          "37",
+          "9223372036854775802",
+          "9223372036854775803",
+          "Mixed",
+          "3",
+          "43",
+          "-9223372036854775808",
+          "-9223372036854775807",
+          "3",
+          "-37",
+          "-9223372036854775802",
+          "-9223372036854775803",
+          "3",
+          "43",
+          "-9223372036854775808",
+          "-9223372036854775807",
+          "1",
+          "41",
+          "9223372036854775806",
+          "9223372036854775807",
+          "Double Associative",
+          "12",
+          "52",
+          "84",
+          "1764",
+          "2",
+          "2",
+          "7",
+          "47",
+          "4",
+          "44",
+          "Double Shift",
+          "128",
+          "2688",
+          "0",
+          "0",
+          "0",
+          "0",
+          "Double Sub",
+          "-1",
+          "-41",
+          "-10",
+          "30",
+          "Double Mixed",
+          "-4",
+          "36",
+          "7",
+          "-33",
+          "-4",
+          "36",
+          "5",
+          "45");
+  private final TestParameters parameters;
+
+  @Parameterized.Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withCfRuntimes().withDexRuntimes().withAllApiLevels().build();
+  }
+
+  public AssociativeLongTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void testD8() throws Exception {
+    testForRuntime(parameters)
+        .addProgramClasses(Main.class)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutput(EXPECTED_RESULT);
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    testForR8(parameters.getBackend())
+        .addProgramClasses(Main.class)
+        .addKeepMainRule(Main.class)
+        .enableInliningAnnotations()
+        .setMinApi(parameters)
+        .compile()
+        .inspect(this::inspect)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutput(EXPECTED_RESULT);
+  }
+
+  private void inspect(CodeInspector inspector) {
+    ClassSubject clazz = inspector.clazz(Main.class);
+    for (FoundMethodSubject method :
+        clazz.allMethods(m -> m.getParameters().size() > 0 && m.getParameter(0).is("long"))) {
+      assertEquals(
+          1,
+          method
+              .streamInstructions()
+              .filter(i -> i.isIntArithmeticBinop() || i.isLongLogicalBinop())
+              .count());
+    }
+  }
+
+  public static class Main {
+
+    public static void main(String[] args) {
+      simple();
+      doubleOps();
+    }
+
+    @NeverInline
+    private static void simple() {
+      // Associative + * & | ^.
+      System.out.println("Associative");
+      add(2);
+      add(42);
+      add(Long.MAX_VALUE);
+      add(Long.MIN_VALUE);
+      mul(2);
+      mul(42);
+      mul(Long.MAX_VALUE);
+      mul(Long.MIN_VALUE);
+      and(2);
+      and(42);
+      and(Long.MAX_VALUE);
+      and(Long.MIN_VALUE);
+      or(2);
+      or(42);
+      or(Long.MAX_VALUE);
+      or(Long.MIN_VALUE);
+      xor(2);
+      xor(42);
+      xor(Long.MAX_VALUE);
+      xor(Long.MIN_VALUE);
+
+      // Shift composition.
+      System.out.println("Shift");
+      shl(2);
+      shl(42);
+      shl(Long.MAX_VALUE);
+      shl(Long.MIN_VALUE);
+      shr(2);
+      shr(42);
+      shr(Long.MAX_VALUE);
+      shr(Long.MIN_VALUE);
+      ushr(2);
+      ushr(42);
+      ushr(Long.MAX_VALUE);
+      ushr(Long.MIN_VALUE);
+
+      // Special for -.
+      System.out.println("Sub");
+      sub(2);
+      sub(42);
+      sub(Long.MAX_VALUE);
+      sub(Long.MIN_VALUE);
+      sub2(2);
+      sub2(42);
+      sub2(Long.MAX_VALUE);
+      sub2(Long.MIN_VALUE);
+
+      // Mixed for + and -.
+      System.out.println("Mixed");
+      addSub(2);
+      addSub(42);
+      addSub(Long.MAX_VALUE);
+      addSub(Long.MIN_VALUE);
+      subAdd(2);
+      subAdd(42);
+      subAdd(Long.MAX_VALUE);
+      subAdd(Long.MIN_VALUE);
+      addSub2(2);
+      addSub2(42);
+      addSub2(Long.MAX_VALUE);
+      addSub2(Long.MIN_VALUE);
+      subAdd2(2);
+      subAdd2(42);
+      subAdd2(Long.MAX_VALUE);
+      subAdd2(Long.MIN_VALUE);
+    }
+
+    @NeverInline
+    private static void doubleOps() {
+      // Associative + * & | ^.
+      System.out.println("Double Associative");
+      addDouble(2);
+      addDouble(42);
+      mulDouble(2);
+      mulDouble(42);
+      andDouble(2);
+      andDouble(42);
+      orDouble(2);
+      orDouble(42);
+      xorDouble(2);
+      xorDouble(42);
+
+      // Shift composition.
+      System.out.println("Double Shift");
+      shlDouble(2);
+      shlDouble(42);
+      shrDouble(2);
+      shrDouble(42);
+      ushrDouble(2);
+      ushrDouble(42);
+
+      // Special for -.
+      System.out.println("Double Sub");
+      subDouble(2);
+      subDouble(42);
+      sub2Double(2);
+      sub2Double(42);
+
+      // Mixed for + and -.
+      System.out.println("Double Mixed");
+      addSubDouble(2);
+      addSubDouble(42);
+      subAddDouble(2);
+      subAddDouble(42);
+      addSub2Double(2);
+      addSub2Double(42);
+      subAdd2Double(2);
+      subAdd2Double(42);
+    }
+
+    @NeverInline
+    public static void add(long x) {
+      System.out.println(3L + x + 2L);
+    }
+
+    @NeverInline
+    public static void mul(long x) {
+      System.out.println(3L * x * 2L);
+    }
+
+    @NeverInline
+    public static void and(long x) {
+      System.out.println(3L & x & 2L);
+    }
+
+    @NeverInline
+    public static void or(long x) {
+      System.out.println(3L | x | 2L);
+    }
+
+    @NeverInline
+    public static void xor(long x) {
+      System.out.println(3L ^ x ^ 2L);
+    }
+
+    @NeverInline
+    public static void shl(long x) {
+      System.out.println(x << 2L << 3L);
+    }
+
+    @NeverInline
+    public static void shr(long x) {
+      System.out.println(x >> 2L >> 3L);
+    }
+
+    @NeverInline
+    public static void ushr(long x) {
+      System.out.println(x >>> 2L >>> 3L);
+    }
+
+    @NeverInline
+    public static void sub(long x) {
+      System.out.println(3L - x - 2L);
+    }
+
+    @NeverInline
+    public static void sub2(long x) {
+      System.out.println(x - 3L - 2L);
+    }
+
+    @NeverInline
+    public static void addSub(long x) {
+      System.out.println(3L + x - 2L);
+    }
+
+    @NeverInline
+    public static void addSub2(long x) {
+      System.out.println(x + 3L - 2L);
+    }
+
+    @NeverInline
+    public static void subAdd(long x) {
+      System.out.println(3L - x + 2L);
+    }
+
+    @NeverInline
+    public static void subAdd2(long x) {
+      System.out.println(x - 3L + 2L);
+    }
+
+    @NeverInline
+    public static void addDouble(long x) {
+      System.out.println(3L + x + 2L + 5);
+    }
+
+    @NeverInline
+    public static void mulDouble(long x) {
+      System.out.println(3L * x * 2L * 7L);
+    }
+
+    @NeverInline
+    public static void andDouble(long x) {
+      System.out.println(3L & x & 2L & 7L);
+    }
+
+    @NeverInline
+    public static void orDouble(long x) {
+      System.out.println(3L | x | 2L | 7L);
+    }
+
+    @NeverInline
+    public static void xorDouble(long x) {
+      System.out.println(3L ^ x ^ 2L ^ 7L);
+    }
+
+    @NeverInline
+    public static void shlDouble(long x) {
+      System.out.println(x << 2L << 3L << 1L);
+    }
+
+    @NeverInline
+    public static void shrDouble(long x) {
+      System.out.println(x >> 2L >> 3L >> 1L);
+    }
+
+    @NeverInline
+    public static void ushrDouble(long x) {
+      System.out.println(x >>> 2L >>> 3L >>> 1L);
+    }
+
+    @NeverInline
+    public static void subDouble(long x) {
+      System.out.println(3L - x - 2L);
+    }
+
+    @NeverInline
+    public static void sub2Double(long x) {
+      System.out.println(x - 3L - 2L - 7L);
+    }
+
+    @NeverInline
+    public static void addSubDouble(long x) {
+      System.out.println(3L + x - 2L - 7L);
+    }
+
+    @NeverInline
+    public static void addSub2Double(long x) {
+      System.out.println(x + 3L - 2L - 7L);
+    }
+
+    @NeverInline
+    public static void subAddDouble(long x) {
+      System.out.println(3L - x + 2L + 4L);
+    }
+
+    @NeverInline
+    public static void subAdd2Double(long x) {
+      System.out.println(x - 3L + 2L + 4L);
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/ir/IdentityAbsorbingTest.java b/src/test/java/com/android/tools/r8/ir/IdentityAbsorbingTest.java
index 390d6ca..3f9a8b2 100644
--- a/src/test/java/com/android/tools/r8/ir/IdentityAbsorbingTest.java
+++ b/src/test/java/com/android/tools/r8/ir/IdentityAbsorbingTest.java
@@ -853,7 +853,11 @@
                 assertTrue(
                     m.streamInstructions()
                         .noneMatch(
-                            i -> i.isIntOrLongLogicalBinop() || i.isIntOrLongArithmeticBinop())));
+                            i ->
+                                i.isIntLogicalBinop()
+                                    || i.isLongLogicalBinop()
+                                    || i.isIntArithmeticBinop()
+                                    || i.isLongArithmeticBinop())));
   }
 
   static class Main {
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java b/src/test/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java
index 45d6202..79425a3 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java
@@ -339,17 +339,27 @@
   }
 
   @Override
-  public boolean isIntOrLongArithmeticBinop() {
+  public boolean isIntArithmeticBinop() {
     return instruction instanceof CfArithmeticBinop
-        && (((CfArithmeticBinop) instruction).getType() == NumericType.INT
-            || ((CfArithmeticBinop) instruction).getType() == NumericType.LONG);
+        && ((CfArithmeticBinop) instruction).getType() == NumericType.INT;
   }
 
   @Override
-  public boolean isIntOrLongLogicalBinop() {
+  public boolean isIntLogicalBinop() {
     return instruction instanceof CfLogicalBinop
-        && (((CfLogicalBinop) instruction).getType() == NumericType.INT
-            || ((CfLogicalBinop) instruction).getType() == NumericType.LONG);
+        && ((CfLogicalBinop) instruction).getType() == NumericType.INT;
+  }
+
+  @Override
+  public boolean isLongArithmeticBinop() {
+    return instruction instanceof CfArithmeticBinop
+        && ((CfArithmeticBinop) instruction).getType() == NumericType.LONG;
+  }
+
+  @Override
+  public boolean isLongLogicalBinop() {
+    return instruction instanceof CfLogicalBinop
+        && ((CfLogicalBinop) instruction).getType() == NumericType.LONG;
   }
 
   @Override
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java b/src/test/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java
index 38421a2..bd68422 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java
@@ -122,6 +122,8 @@
 import com.android.tools.r8.dex.code.DexReturn;
 import com.android.tools.r8.dex.code.DexReturnObject;
 import com.android.tools.r8.dex.code.DexReturnVoid;
+import com.android.tools.r8.dex.code.DexRsubInt;
+import com.android.tools.r8.dex.code.DexRsubIntLit8;
 import com.android.tools.r8.dex.code.DexSget;
 import com.android.tools.r8.dex.code.DexSgetBoolean;
 import com.android.tools.r8.dex.code.DexSgetByte;
@@ -501,69 +503,77 @@
     return instruction instanceof DexSparseSwitch;
   }
 
-  public boolean isIntOrLongArithmeticBinop() {
+  public boolean isIntArithmeticBinop() {
     return instruction instanceof DexMulInt
         || instruction instanceof DexMulIntLit8
         || instruction instanceof DexMulIntLit16
         || instruction instanceof DexMulInt2Addr
-        || instruction instanceof DexMulLong
-        || instruction instanceof DexMulLong2Addr
         || instruction instanceof DexAddInt
         || instruction instanceof DexAddIntLit8
         || instruction instanceof DexAddIntLit16
         || instruction instanceof DexAddInt2Addr
-        || instruction instanceof DexAddLong
-        || instruction instanceof DexAddLong2Addr
+        || instruction instanceof DexRsubInt
+        || instruction instanceof DexRsubIntLit8
         || instruction instanceof DexSubInt
         || instruction instanceof DexSubInt2Addr
-        || instruction instanceof DexSubLong
-        || instruction instanceof DexSubLong2Addr
         || instruction instanceof DexDivInt
         || instruction instanceof DexDivIntLit8
         || instruction instanceof DexDivIntLit16
         || instruction instanceof DexDivInt2Addr
-        || instruction instanceof DexDivLong
-        || instruction instanceof DexDivLong2Addr
         || instruction instanceof DexRemInt
         || instruction instanceof DexRemIntLit8
         || instruction instanceof DexRemIntLit16
-        || instruction instanceof DexRemInt2Addr
+        || instruction instanceof DexRemInt2Addr;
+  }
+
+  public boolean isLongArithmeticBinop() {
+    return instruction instanceof DexMulLong
+        || instruction instanceof DexMulLong2Addr
+        || instruction instanceof DexAddLong
+        || instruction instanceof DexAddLong2Addr
+        || instruction instanceof DexSubLong
+        || instruction instanceof DexSubLong2Addr
+        || instruction instanceof DexDivLong
+        || instruction instanceof DexDivLong2Addr
         || instruction instanceof DexRemLong
         || instruction instanceof DexRemLong2Addr;
   }
 
-  public boolean isIntOrLongLogicalBinop() {
+  public boolean isIntLogicalBinop() {
     return instruction instanceof DexAndInt
         || instruction instanceof DexAndIntLit8
         || instruction instanceof DexAndIntLit16
         || instruction instanceof DexAndInt2Addr
-        || instruction instanceof DexAndLong
-        || instruction instanceof DexAndLong2Addr
         || instruction instanceof DexOrInt
         || instruction instanceof DexOrIntLit8
         || instruction instanceof DexOrIntLit16
         || instruction instanceof DexOrInt2Addr
-        || instruction instanceof DexOrLong
-        || instruction instanceof DexOrLong2Addr
         || instruction instanceof DexXorInt
         || instruction instanceof DexXorIntLit8
         || instruction instanceof DexXorIntLit16
         || instruction instanceof DexXorInt2Addr
-        || instruction instanceof DexXorLong
-        || instruction instanceof DexXorLong2Addr
         || instruction instanceof DexShrInt
         || instruction instanceof DexShrIntLit8
         || instruction instanceof DexShrInt2Addr
-        || instruction instanceof DexShrLong
-        || instruction instanceof DexShrLong2Addr
         || instruction instanceof DexShlInt
         || instruction instanceof DexShlIntLit8
         || instruction instanceof DexShlInt2Addr
-        || instruction instanceof DexShlLong
-        || instruction instanceof DexShlLong2Addr
         || instruction instanceof DexUshrInt
         || instruction instanceof DexUshrIntLit8
-        || instruction instanceof DexUshrInt2Addr
+        || instruction instanceof DexUshrInt2Addr;
+  }
+
+  public boolean isLongLogicalBinop() {
+    return instruction instanceof DexAndLong
+        || instruction instanceof DexAndLong2Addr
+        || instruction instanceof DexOrLong
+        || instruction instanceof DexOrLong2Addr
+        || instruction instanceof DexXorLong
+        || instruction instanceof DexXorLong2Addr
+        || instruction instanceof DexShrLong
+        || instruction instanceof DexShrLong2Addr
+        || instruction instanceof DexShlLong
+        || instruction instanceof DexShlLong2Addr
         || instruction instanceof DexUshrLong
         || instruction instanceof DexUshrLong2Addr;
   }
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java b/src/test/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java
index 6d1bd94..b77029f 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java
@@ -140,9 +140,13 @@
 
   boolean isSparseSwitch();
 
-  boolean isIntOrLongArithmeticBinop();
+  boolean isIntArithmeticBinop();
 
-  boolean isIntOrLongLogicalBinop();
+  boolean isIntLogicalBinop();
+
+  boolean isLongArithmeticBinop();
+
+  boolean isLongLogicalBinop();
 
   boolean isMultiplication();