Unsound, experimental removal of unused Composable arguments

For method parameters of Composable functions, this CL introduces a condition over the arguments to the Composable function, which evaluates to true if-and-only-if the argument can be skipped.

Since Composable method parameters generally always have uses, this CL ignores uses that are dominated by calls to skipToGroupEnd(), which is potentially unsound.

Bug: b/302281503
Change-Id: I3d7c2862f6aaabb9cd651540073cf8efa2391aa9
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/path/state/ConcretePathConstraintAnalysisState.java b/src/main/java/com/android/tools/r8/ir/analysis/path/state/ConcretePathConstraintAnalysisState.java
index 28bdcf1..8f209ac 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/path/state/ConcretePathConstraintAnalysisState.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/path/state/ConcretePathConstraintAnalysisState.java
@@ -6,8 +6,10 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
 import com.android.tools.r8.optimize.argumentpropagation.computation.ComputationTreeNode;
+import com.android.tools.r8.utils.MapUtils;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.IdentityHashMap;
 import java.util.Map;
 import java.util.Map.Entry;
 
@@ -43,7 +45,21 @@
 public class ConcretePathConstraintAnalysisState extends PathConstraintAnalysisState {
 
   // TODO(b/302281503): Consider changing this to an ImmutableMap.
-  private final Map<ComputationTreeNode, PathConstraintKind> pathConstraints;
+  // We use an IdentityHashMap during the execution of the analysis. This is important for the
+  // ability to handle control flow with repeated branches such as the following:
+  //
+  //   1. if ((arg & 1) != 0) { ... }
+  //   2. ...
+  //   3. if ((arg & 1) != 0) { ... }
+  //
+  // By using a HashMap backing, after the join point at line 2, we would map the path constraint
+  // `(arg & 1) != 0` to DISABLED. Since the constraint of the IF-condition in line 3 is the same,
+  // this constraint would continue to be mapped to DISABLED in both the else- and then-branch.
+  //
+  // We want to ensure that we have the path constraint `(arg & 1) != 0` mapped to POSITIVE and
+  // NEGATIVE in the then- and else-branch of the branch in line 3, respectively. This is achieved
+  // using a IdentityHashMap as we then no longer conflate the conditions of the two branches.
+  private Map<ComputationTreeNode, PathConstraintKind> pathConstraints;
 
   public ConcretePathConstraintAnalysisState() {
     this.pathConstraints = Collections.emptyMap();
@@ -76,7 +92,7 @@
     // No jumps can dominate the entry of their own block, so when adding the condition of a jump
     // this cannot currently be active.
     Map<ComputationTreeNode, PathConstraintKind> newPathConstraints =
-        new HashMap<>(pathConstraints.size() + 1);
+        new IdentityHashMap<>(pathConstraints.size() + 1);
     newPathConstraints.putAll(pathConstraints);
     newPathConstraints.put(pathConstraint, newKind);
     return new ConcretePathConstraintAnalysisState(newPathConstraints);
@@ -96,6 +112,22 @@
     return this;
   }
 
+  public void ensureHashMapBacking() {
+    if (pathConstraints.isEmpty()) {
+      assert pathConstraints == Collections.<ComputationTreeNode, PathConstraintKind>emptyMap();
+    } else if (!(pathConstraints instanceof HashMap)) {
+      assert pathConstraints instanceof IdentityHashMap;
+      // Copy to a HashMap.
+      pathConstraints =
+          MapUtils.transform(
+              pathConstraints,
+              HashMap::new,
+              pathConstraint -> pathConstraint,
+              kind -> kind,
+              (pathConstraint, kind, otherKind) -> kind.meet(otherKind));
+    }
+  }
+
   @Override
   public boolean isGreaterThanOrEquals(AppView<?> appView, PathConstraintAnalysisState state) {
     if (state.isConcrete()) {
@@ -153,6 +185,10 @@
     return AbstractValue.unknown();
   }
 
+  public PathConstraintKind getKind(ComputationTreeNode pathConstraint) {
+    return pathConstraints.get(pathConstraint);
+  }
+
   public ConcretePathConstraintAnalysisState join(ConcretePathConstraintAnalysisState other) {
     if (isGreaterThanOrEquals(other)) {
       return this;
@@ -161,7 +197,7 @@
       return other;
     }
     Map<ComputationTreeNode, PathConstraintKind> newPathConstraints =
-        new HashMap<>(pathConstraints.size() + other.pathConstraints.size());
+        new IdentityHashMap<>(pathConstraints.size() + other.pathConstraints.size());
     join(other, newPathConstraints);
     other.join(this, newPathConstraints);
     return new ConcretePathConstraintAnalysisState(newPathConstraints);
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/path/state/PathConstraintKind.java b/src/main/java/com/android/tools/r8/ir/analysis/path/state/PathConstraintKind.java
index 6891163..dad7b75 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/path/state/PathConstraintKind.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/path/state/PathConstraintKind.java
@@ -29,4 +29,18 @@
     }
     return this == other ? this : DISABLED;
   }
+
+  PathConstraintKind meet(PathConstraintKind other) {
+    assert other != null;
+    if (this == other) {
+      return this;
+    }
+    if (this == DISABLED) {
+      return other;
+    }
+    if (other == DISABLED) {
+      return this;
+    }
+    return DISABLED;
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/code/IfType.java b/src/main/java/com/android/tools/r8/ir/code/IfType.java
index 8a0d96b..36a961c 100644
--- a/src/main/java/com/android/tools/r8/ir/code/IfType.java
+++ b/src/main/java/com/android/tools/r8/ir/code/IfType.java
@@ -77,6 +77,9 @@
   };
 
   public AbstractValue evaluate(AbstractValue operand, AppView<AppInfoWithLiveness> appView) {
+    if (operand.isBottom()) {
+      return AbstractValue.bottom();
+    }
     if (operand.isSingleNumberValue()) {
       int operandValue = operand.asSingleNumberValue().getIntValue();
       boolean result = evaluate(operandValue);
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/unusedarguments/EffectivelyUnusedArgumentsAnalysis.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/unusedarguments/EffectivelyUnusedArgumentsAnalysis.java
index ef389b7..8c88cfd 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/unusedarguments/EffectivelyUnusedArgumentsAnalysis.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/unusedarguments/EffectivelyUnusedArgumentsAnalysis.java
@@ -13,22 +13,27 @@
 import com.android.tools.r8.graph.PrunedItems;
 import com.android.tools.r8.ir.analysis.path.PathConstraintSupplier;
 import com.android.tools.r8.ir.analysis.path.state.ConcretePathConstraintAnalysisState;
+import com.android.tools.r8.ir.analysis.path.state.PathConstraintKind;
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
 import com.android.tools.r8.ir.code.Argument;
 import com.android.tools.r8.ir.code.BasicBlock;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.Instruction;
 import com.android.tools.r8.ir.code.InvokeMethod;
+import com.android.tools.r8.ir.code.InvokeVirtual;
+import com.android.tools.r8.ir.code.LazyDominatorTree;
 import com.android.tools.r8.ir.code.Phi;
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodParameter;
 import com.android.tools.r8.optimize.argumentpropagation.computation.ComputationTreeNode;
 import com.android.tools.r8.optimize.argumentpropagation.computation.ComputationTreeUnopCompareNode;
 import com.android.tools.r8.optimize.argumentpropagation.utils.ParameterRemovalUtils;
+import com.android.tools.r8.optimize.compose.ComposeReferences;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.ListUtils;
 import com.android.tools.r8.utils.WorkList;
 import com.android.tools.r8.utils.collections.ProgramMethodSet;
+import com.google.common.collect.Sets;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
@@ -88,6 +93,12 @@
   // for the method parameter p to be effectively unused.
   private final Map<MethodParameter, Set<MethodParameter>> constraints = new ConcurrentHashMap<>();
 
+  // Maps Composable parameters to a condition. If the condition evaluates to true for a given
+  // invoke, then the argument to this method parameter can be ignored. Note that this does not
+  // imply that the method parameter is unused and can be removed.
+  private final Map<MethodParameter, ComputationTreeNode> ignoreComposableArgumentConditions =
+      new ConcurrentHashMap<>();
+
   // Set of virtual methods that can definitely be optimized.
   //
   // We conservatively exclude virtual methods with dynamic dispatch from this set, since the
@@ -100,6 +111,10 @@
   }
 
   public ComputationTreeNode getEffectivelyUnusedCondition(MethodParameter methodParameter) {
+    if (ignoreComposableArgumentConditions.containsKey(methodParameter)) {
+      assert appView.options().callSiteOptimizationOptions().isComposableArgumentRemovalEnabled();
+      return ignoreComposableArgumentConditions.get(methodParameter);
+    }
     return conditions.getOrDefault(methodParameter, AbstractValue.unknown());
   }
 
@@ -186,6 +201,8 @@
     private final IRCode code;
     private final PathConstraintSupplier pathConstraintSupplier;
 
+    private Set<BasicBlock> skipToGroupEndBlocks;
+
     private Analyzer(
         ProgramMethod method, IRCode code, PathConstraintSupplier pathConstraintSupplier) {
       this.method = method;
@@ -231,7 +248,7 @@
         return;
       }
       Value usedValue;
-      if (argumentValue.hasPhiUsers()) {
+      if (argumentValue.hasSingleUniquePhiUser()) {
         // If the argument has one or more phi users, we check if there is a single phi due to a
         // default value for the argument. If so, we record this condition and mark the users of the
         // phi as effectively unused argument constraints for the current argument.
@@ -239,12 +256,14 @@
             argumentValue, effectivelyUnusedConditionsConsumer)) {
           return;
         }
-        assert argumentValue.hasSingleUniquePhiUser();
         Phi user = argumentValue.singleUniquePhiUser();
         if (user.hasDebugUsers() || user.hasPhiUsers()) {
           return;
         }
         usedValue = user;
+      } else if (argumentValue.hasPhiUsers()) {
+        computeIgnoreComposableArgumentCondition(argument, argumentValue);
+        return;
       } else {
         usedValue = argumentValue;
       }
@@ -280,37 +299,115 @@
 
     private boolean computeEffectivelyUnusedCondition(
         Value argumentValue, Consumer<ComputationTreeNode> effectivelyUnusedConditionsConsumer) {
-      assert argumentValue.hasPhiUsers();
-      if (argumentValue.hasUsers() || !argumentValue.hasSingleUniquePhiUser()) {
+      assert argumentValue.hasSingleUniquePhiUser();
+      if (argumentValue.hasUsers()) {
         return false;
       }
       Phi phi = argumentValue.singleUniquePhiUser();
-      if (phi.getOperands().size() != 2) {
+      ComputationTreeNode condition = getUnusedCondition(argumentValue, phi);
+      if (condition == null) {
         return false;
       }
+      effectivelyUnusedConditionsConsumer.accept(condition);
+      return true;
+    }
+
+    private ComputationTreeUnopCompareNode getUnusedCondition(Value argumentValue, Phi phi) {
+      if (phi.getOperands().size() != 2) {
+        return null;
+      }
       BasicBlock block = phi.getBlock();
       ConcretePathConstraintAnalysisState leftState =
           pathConstraintSupplier.getPathConstraint(block.getPredecessor(0)).asConcreteState();
       ConcretePathConstraintAnalysisState rightState =
           pathConstraintSupplier.getPathConstraint(block.getPredecessor(1)).asConcreteState();
       if (leftState == null || rightState == null) {
-        return false;
+        return null;
       }
       // Find a condition that can be used to distinguish program paths coming from the two
       // predecessors.
       ComputationTreeNode condition = leftState.getDifferentiatingPathConstraint(rightState);
-      if (!condition.isArgumentBitSetCompareNode()) {
-        return false;
+      if (condition == null || !condition.isArgumentBitSetCompareNode()) {
+        return null;
       }
       // Extract the state corresponding to the program path where the argument is unused. If the
       // condition evaluates to false on this program path then negate the condition.
+      ComputationTreeUnopCompareNode compareCondition = (ComputationTreeUnopCompareNode) condition;
       ConcretePathConstraintAnalysisState unusedState =
           phi.getOperand(0) == argumentValue ? rightState : leftState;
-      if (unusedState.isNegated(condition)) {
-        condition = ((ComputationTreeUnopCompareNode) condition).negate();
+      if (unusedState.isNegated(compareCondition)) {
+        compareCondition = compareCondition.negate();
       }
-      effectivelyUnusedConditionsConsumer.accept(condition);
-      return true;
+      return compareCondition;
+    }
+
+    private void computeIgnoreComposableArgumentCondition(Argument argument, Value argumentValue) {
+      if (!appView.options().callSiteOptimizationOptions().isComposableArgumentRemovalEnabled()
+          || !appView.getComposeReferences().isComposable(method)) {
+        return;
+      }
+      Phi phi = getSingleUniquePhiUserInComposableIgnoringSkipToGroupEndPaths(argumentValue);
+      if (phi == null) {
+        return;
+      }
+      ComputationTreeUnopCompareNode condition = getUnusedCondition(argumentValue, phi);
+      if (condition == null) {
+        return;
+      }
+      for (Instruction user : argumentValue.uniqueUsers()) {
+        ConcretePathConstraintAnalysisState pathConstraint =
+            pathConstraintSupplier.getPathConstraint(user.getBlock()).asConcreteState();
+        if (pathConstraint == null) {
+          return;
+        }
+        pathConstraint.ensureHashMapBacking();
+        // Check that the condition being true implies that this use is unreachable.
+        if (pathConstraint.getKind(condition) == PathConstraintKind.NEGATIVE
+            || pathConstraint.getKind(condition.negate()) == PathConstraintKind.POSITIVE) {
+          continue;
+        }
+        return;
+      }
+      ignoreComposableArgumentConditions.put(
+          new MethodParameter(method, argument.getIndex()), condition);
+    }
+
+    private Phi getSingleUniquePhiUserInComposableIgnoringSkipToGroupEndPaths(Value argumentValue) {
+      assert appView.options().callSiteOptimizationOptions().isComposableArgumentRemovalEnabled();
+      assert appView.getComposeReferences().isComposable(method);
+      Phi result = null;
+      for (Phi phi : argumentValue.uniquePhiUsers()) {
+        for (int operandIndex = 0; operandIndex < phi.getOperands().size(); operandIndex++) {
+          Value operand = phi.getOperand(operandIndex);
+          if (operand != argumentValue
+              || ignoreComposablePath(phi.getBlock().getPredecessor(operandIndex))) {
+            continue;
+          }
+          if (result != null) {
+            return null;
+          }
+          result = phi;
+          break;
+        }
+      }
+      return result;
+    }
+
+    private boolean ignoreComposablePath(BasicBlock block) {
+      assert appView.options().callSiteOptimizationOptions().isComposableArgumentRemovalEnabled();
+      assert appView.getComposeReferences().isComposable(method);
+      if (skipToGroupEndBlocks == null) {
+        skipToGroupEndBlocks = Sets.newIdentityHashSet();
+        LazyDominatorTree dominatorTree = new LazyDominatorTree(code);
+        ComposeReferences references = appView.getComposeReferences();
+        for (InvokeVirtual invoke :
+            code.<InvokeVirtual>instructions(Instruction::isInvokeVirtual)) {
+          if (invoke.getInvokedMethod().getName().isIdenticalTo(references.skipToGroupEndName)) {
+            skipToGroupEndBlocks.addAll(dominatorTree.get().dominatedBlocks(invoke.getBlock()));
+          }
+        }
+      }
+      return skipToGroupEndBlocks.contains(block);
     }
 
     private boolean isUnoptimizable(ProgramMethod method) {
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/ComposeReferences.java b/src/main/java/com/android/tools/r8/optimize/compose/ComposeReferences.java
index 633008d..bc2a500 100644
--- a/src/main/java/com/android/tools/r8/optimize/compose/ComposeReferences.java
+++ b/src/main/java/com/android/tools/r8/optimize/compose/ComposeReferences.java
@@ -12,19 +12,13 @@
 
 public class ComposeReferences {
 
-  public final DexString changedFieldName;
-
+  public final DexString skipToGroupEndName;
   public final DexType composableType;
-  public final DexType composerType;
-
   public final DexMethod updatedChangedFlagsMethod;
 
   public ComposeReferences(DexItemFactory factory) {
-    changedFieldName = factory.createString("$$changed");
-
+    skipToGroupEndName = factory.createString("skipToGroupEnd");
     composableType = factory.createType("Landroidx/compose/runtime/Composable;");
-    composerType = factory.createType("Landroidx/compose/runtime/Composer;");
-
     updatedChangedFlagsMethod =
         factory.createMethod(
             factory.createType("Landroidx/compose/runtime/RecomposeScopeImplKt;"),
@@ -33,13 +27,9 @@
   }
 
   public ComposeReferences(
-      DexString changedFieldName,
-      DexType composableType,
-      DexType composerType,
-      DexMethod updatedChangedFlagsMethod) {
-    this.changedFieldName = changedFieldName;
+      DexString skipToGroupEndName, DexType composableType, DexMethod updatedChangedFlagsMethod) {
+    this.skipToGroupEndName = skipToGroupEndName;
     this.composableType = composableType;
-    this.composerType = composerType;
     this.updatedChangedFlagsMethod = updatedChangedFlagsMethod;
   }
 
@@ -50,9 +40,8 @@
 
   public ComposeReferences rewrittenWithLens(GraphLens graphLens, GraphLens codeLens) {
     return new ComposeReferences(
-        changedFieldName,
+        skipToGroupEndName,
         graphLens.lookupClassType(composableType, codeLens),
-        graphLens.lookupClassType(composerType, codeLens),
         graphLens.getRenamedMethodSignature(updatedChangedFlagsMethod, codeLens));
   }
 }
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index 32658ea..87ea197 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -1719,6 +1719,11 @@
   public class CallSiteOptimizationOptions {
 
     private boolean enabled = true;
+    // Unsound optimization for Composable argument removal. Do not enable except for running
+    // experiments.
+    private boolean enableComposableArgumentRemoval =
+        SystemPropertyUtils.parseSystemPropertyForDevelopmentOrDefault(
+            "com.android.tools.r8.enableComposableArgumentRemoval", false);
     private boolean enableMethodStaticizing = true;
 
     private boolean forceSyntheticsForInstanceInitializers = false;
@@ -1738,6 +1743,10 @@
       return enabled;
     }
 
+    public boolean isComposableArgumentRemovalEnabled() {
+      return enableComposableArgumentRemoval;
+    }
+
     public boolean isForceSyntheticsForInstanceInitializersEnabled() {
       return forceSyntheticsForInstanceInitializers;
     }