Implement updateChangedFlags on abstract inputs

Bug: b/302483644
Change-Id: Idfce8ebde85f0897935e7933d4f4a7ec18a565bd
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/value/arithmetic/AbstractCalculator.java b/src/main/java/com/android/tools/r8/ir/analysis/value/arithmetic/AbstractCalculator.java
index 741efd0..ed6c059 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/value/arithmetic/AbstractCalculator.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/value/arithmetic/AbstractCalculator.java
@@ -10,7 +10,7 @@
 public class AbstractCalculator {
 
   public static AbstractValue andIntegers(
-      AbstractValue left, AbstractValue right, AppView<?> appView) {
+      AppView<?> appView, AbstractValue left, AbstractValue right) {
     if (left.isZero()) {
       return left;
     }
@@ -44,7 +44,7 @@
   }
 
   public static AbstractValue orIntegers(
-      AbstractValue left, AbstractValue right, AppView<?> appView) {
+      AppView<?> appView, AbstractValue left, AbstractValue right) {
     if (left.isZero()) {
       return right;
     }
@@ -77,12 +77,26 @@
     return AbstractValue.unknown();
   }
 
+  public static AbstractValue orIntegers(
+      AppView<?> appView,
+      AbstractValue first,
+      AbstractValue second,
+      AbstractValue third,
+      AbstractValue fourth) {
+    return orIntegers(
+        appView, first, orIntegers(appView, second, orIntegers(appView, third, fourth)));
+  }
+
   public static AbstractValue shlIntegers(
-      AbstractValue left, AbstractValue right, AppView<?> appView) {
+      AppView<?> appView, AbstractValue left, AbstractValue right) {
     if (!right.isSingleNumberValue()) {
       return AbstractValue.unknown();
     }
     int rightConst = right.asSingleNumberValue().getIntValue();
+    return shlIntegers(appView, left, rightConst);
+  }
+
+  public static AbstractValue shlIntegers(AppView<?> appView, AbstractValue left, int rightConst) {
     if (rightConst == 0) {
       return left;
     }
@@ -104,11 +118,15 @@
   }
 
   public static AbstractValue shrIntegers(
-      AbstractValue left, AbstractValue right, AppView<?> appView) {
+      AppView<?> appView, AbstractValue left, AbstractValue right) {
     if (!right.isSingleNumberValue()) {
       return AbstractValue.unknown();
     }
     int rightConst = right.asSingleNumberValue().getIntValue();
+    return shrIntegers(appView, left, rightConst);
+  }
+
+  public static AbstractValue shrIntegers(AppView<?> appView, AbstractValue left, int rightConst) {
     if (rightConst == 0) {
       return left;
     }
@@ -127,7 +145,7 @@
   }
 
   public static AbstractValue ushrIntegers(
-      AbstractValue left, AbstractValue right, AppView<?> appView) {
+      AppView<?> appView, AbstractValue left, AbstractValue right) {
     if (!right.isSingleNumberValue()) {
       return AbstractValue.unknown();
     }
@@ -153,7 +171,7 @@
   }
 
   public static AbstractValue xorIntegers(
-      AbstractValue left, AbstractValue right, AppView<?> appView) {
+      AppView<?> appView, AbstractValue left, AbstractValue right) {
     if (left.isSingleNumberValue() && right.isSingleNumberValue()) {
       int result =
           left.asSingleNumberValue().getIntValue() ^ right.asSingleNumberValue().getIntValue();
diff --git a/src/main/java/com/android/tools/r8/ir/code/And.java b/src/main/java/com/android/tools/r8/ir/code/And.java
index 08b399f..522c7c0 100644
--- a/src/main/java/com/android/tools/r8/ir/code/And.java
+++ b/src/main/java/com/android/tools/r8/ir/code/And.java
@@ -105,7 +105,7 @@
 
   @Override
   AbstractValue foldIntegers(AbstractValue left, AbstractValue right, AppView<?> appView) {
-    return AbstractCalculator.andIntegers(left, right, appView);
+    return AbstractCalculator.andIntegers(appView, left, right);
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/ir/code/Or.java b/src/main/java/com/android/tools/r8/ir/code/Or.java
index c098fe2..8fd53b3 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Or.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Or.java
@@ -104,7 +104,7 @@
 
   @Override
   AbstractValue foldIntegers(AbstractValue left, AbstractValue right, AppView<?> appView) {
-    return AbstractCalculator.orIntegers(left, right, appView);
+    return AbstractCalculator.orIntegers(appView, left, right);
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/ir/code/Shl.java b/src/main/java/com/android/tools/r8/ir/code/Shl.java
index 5977cad..73c84e8 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Shl.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Shl.java
@@ -99,7 +99,7 @@
 
   @Override
   AbstractValue foldIntegers(AbstractValue left, AbstractValue right, AppView<?> appView) {
-    return AbstractCalculator.shlIntegers(left, right, appView);
+    return AbstractCalculator.shlIntegers(appView, left, right);
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/ir/code/Shr.java b/src/main/java/com/android/tools/r8/ir/code/Shr.java
index d1b6a1c..52c7cb0 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Shr.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Shr.java
@@ -99,7 +99,7 @@
 
   @Override
   AbstractValue foldIntegers(AbstractValue left, AbstractValue right, AppView<?> appView) {
-    return AbstractCalculator.shrIntegers(left, right, appView);
+    return AbstractCalculator.shrIntegers(appView, left, right);
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/ir/code/Ushr.java b/src/main/java/com/android/tools/r8/ir/code/Ushr.java
index 88eb857..c80f280 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Ushr.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Ushr.java
@@ -99,7 +99,7 @@
 
   @Override
   AbstractValue foldIntegers(AbstractValue left, AbstractValue right, AppView<?> appView) {
-    return AbstractCalculator.ushrIntegers(left, right, appView);
+    return AbstractCalculator.ushrIntegers(appView, left, right);
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/ir/code/Xor.java b/src/main/java/com/android/tools/r8/ir/code/Xor.java
index fff451f..bddba7e 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Xor.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Xor.java
@@ -99,7 +99,7 @@
 
   @Override
   AbstractValue foldIntegers(AbstractValue left, AbstractValue right, AppView<?> appView) {
-    return AbstractCalculator.xorIntegers(left, right, appView);
+    return AbstractCalculator.xorIntegers(appView, left, right);
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/OrAbstractFunction.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/OrAbstractFunction.java
index 3c02dd2..ca1223f 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/OrAbstractFunction.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/OrAbstractFunction.java
@@ -33,7 +33,7 @@
       ConcreteValueState inState) {
     ConcretePrimitiveTypeValueState inPrimitiveState = inState.asPrimitiveState();
     AbstractValue result =
-        AbstractCalculator.orIntegers(inPrimitiveState.getAbstractValue(), constant, appView);
+        AbstractCalculator.orIntegers(appView, inPrimitiveState.getAbstractValue(), constant);
     return ConcretePrimitiveTypeValueState.create(result, inPrimitiveState.copyInFlow());
   }
 
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/UpdateChangedFlagsAbstractFunction.java b/src/main/java/com/android/tools/r8/optimize/compose/UpdateChangedFlagsAbstractFunction.java
index ee25e29..2d940e9 100644
--- a/src/main/java/com/android/tools/r8/optimize/compose/UpdateChangedFlagsAbstractFunction.java
+++ b/src/main/java/com/android/tools/r8/optimize/compose/UpdateChangedFlagsAbstractFunction.java
@@ -4,8 +4,13 @@
 package com.android.tools.r8.optimize.compose;
 
 import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.ir.analysis.value.AbstractValue;
+import com.android.tools.r8.ir.analysis.value.AbstractValueFactory;
+import com.android.tools.r8.ir.analysis.value.SingleNumberValue;
+import com.android.tools.r8.ir.analysis.value.arithmetic.AbstractCalculator;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.AbstractFunction;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.BaseInFlow;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcretePrimitiveTypeValueState;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteValueState;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.FlowGraphStateProvider;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.InFlow;
@@ -17,6 +22,10 @@
 
 public class UpdateChangedFlagsAbstractFunction implements AbstractFunction {
 
+  private static final int changedLowBitMask = 0b001_001_001_001_001_001_001_001_001_001_0;
+  private static final int changedHighBitMask = changedLowBitMask << 1;
+  private static final int changedMask = ~(changedLowBitMask | changedHighBitMask);
+
   private final InFlow inFlow;
 
   public UpdateChangedFlagsAbstractFunction(InFlow inFlow) {
@@ -36,9 +45,63 @@
     } else {
       inState = baseInState;
     }
-    // TODO(b/302483644): Implement this abstract function to allow correct value propagation of
-    //  updateChangedFlags(x | 1).
-    return inState;
+    if (!inState.isPrimitiveState()) {
+      assert inState.isBottom() || inState.isUnknown();
+      return inState;
+    }
+    AbstractValue result = apply(appView, inState.asPrimitiveState().getAbstractValue());
+    return ConcretePrimitiveTypeValueState.create(result);
+  }
+
+  /**
+   * Applies the following function to the given {@param abstractValue}.
+   *
+   * <pre>
+   * private const val changedLowBitMask = 0b001_001_001_001_001_001_001_001_001_001_0
+   * private const val changedHighBitMask = changedLowBitMask shl 1
+   * private const val changedMask = (changedLowBitMask or changedHighBitMask).inv()
+   *
+   * internal fun updateChangedFlags(flags: Int): Int {
+   *     val lowBits = flags and changedLowBitMask
+   *     val highBits = flags and changedHighBitMask
+   *     return ((flags and changedMask) or
+   *         (lowBits or (highBits shr 1)) or ((lowBits shl 1) and highBits))
+   * }
+   * </pre>
+   */
+  private AbstractValue apply(AppView<AppInfoWithLiveness> appView, AbstractValue flagsValue) {
+    if (flagsValue.isSingleNumberValue()) {
+      return apply(appView, flagsValue.asSingleNumberValue().getIntValue());
+    }
+    AbstractValueFactory factory = appView.abstractValueFactory();
+    // Load constants.
+    AbstractValue changedLowBitMaskValue =
+        factory.createUncheckedSingleNumberValue(changedLowBitMask);
+    AbstractValue changedHighBitMaskValue =
+        factory.createUncheckedSingleNumberValue(changedHighBitMask);
+    AbstractValue changedMaskValue = factory.createUncheckedSingleNumberValue(changedMask);
+    // Evaluate expression.
+    AbstractValue lowBitsValue =
+        AbstractCalculator.andIntegers(appView, flagsValue, changedLowBitMaskValue);
+    AbstractValue highBitsValue =
+        AbstractCalculator.andIntegers(appView, flagsValue, changedHighBitMaskValue);
+    AbstractValue changedBitsValue =
+        AbstractCalculator.andIntegers(appView, flagsValue, changedMaskValue);
+    return AbstractCalculator.orIntegers(
+        appView,
+        changedBitsValue,
+        lowBitsValue,
+        AbstractCalculator.shrIntegers(appView, highBitsValue, 1),
+        AbstractCalculator.andIntegers(
+            appView, AbstractCalculator.shlIntegers(appView, lowBitsValue, 1), highBitsValue));
+  }
+
+  private SingleNumberValue apply(AppView<AppInfoWithLiveness> appView, int flags) {
+    int lowBits = flags & changedLowBitMask;
+    int highBits = flags & changedHighBitMask;
+    int changedBits = flags & changedMask;
+    int result = changedBits | lowBits | (highBits >> 1) | ((lowBits << 1) & highBits);
+    return appView.abstractValueFactory().createUncheckedSingleNumberValue(result);
   }
 
   @Override