Optimize bitwise operations in nested @Composable functions

Bug: b/302483644
Change-Id: I8b4329b014fd6f8e4a3fe261dca960512f14739b
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/constant/SparseConditionalConstantPropagation.java b/src/main/java/com/android/tools/r8/ir/analysis/constant/SparseConditionalConstantPropagation.java
index b3b0924..38b433a 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/constant/SparseConditionalConstantPropagation.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/constant/SparseConditionalConstantPropagation.java
@@ -57,9 +57,13 @@
     return true;
   }
 
+  public Map<Value, AbstractValue> analyze(IRCode code) {
+    return new SparseConditionalConstantPropagationOnCode(code).analyze().mapping;
+  }
+
   @Override
   protected CodeRewriterResult rewriteCode(IRCode code) {
-    return new SparseConditionalConstantPropagationOnCode(code).run();
+    return new SparseConditionalConstantPropagationOnCode(code).analyze().run();
   }
 
   private class SparseConditionalConstantPropagationOnCode {
@@ -86,7 +90,7 @@
       visitedBlocks = new BitSet(maxBlockNumber);
     }
 
-    protected CodeRewriterResult run() {
+    public SparseConditionalConstantPropagationOnCode analyze() {
       BasicBlock firstBlock = code.entryBlock();
       visitInstructions(firstBlock);
 
@@ -113,6 +117,10 @@
           }
         }
       }
+      return this;
+    }
+
+    protected CodeRewriterResult run() {
       boolean hasChanged = rewriteConstants();
       return CodeRewriterResult.hasChanged(hasChanged);
     }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index b47b900..bdc3ec5 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -928,6 +928,10 @@
     appView.withArgumentPropagator(
         argumentPropagator -> argumentPropagator.scan(method, code, methodProcessor, timing));
 
+    if (methodProcessor.isComposeMethodProcessor()) {
+      methodProcessor.asComposeMethodProcessor().scan(method, code, timing);
+    }
+
     if (methodProcessor.isPrimaryMethodProcessor()) {
       enumUnboxer.analyzeEnums(code, methodProcessor);
     }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java b/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java
index 15dc3d2..a80055a 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java
@@ -5,9 +5,18 @@
 
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.ir.conversion.callgraph.CallSiteInformation;
+import com.android.tools.r8.optimize.compose.ComposeMethodProcessor;
 
 public abstract class MethodProcessor {
 
+  public boolean isComposeMethodProcessor() {
+    return false;
+  }
+
+  public ComposeMethodProcessor asComposeMethodProcessor() {
+    return null;
+  }
+
   public boolean isPrimaryMethodProcessor() {
     return false;
   }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/PrimaryMethodProcessor.java b/src/main/java/com/android/tools/r8/ir/conversion/PrimaryMethodProcessor.java
index 8d33d03..cd02f61 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/PrimaryMethodProcessor.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/PrimaryMethodProcessor.java
@@ -98,7 +98,6 @@
     InternalOptions options = appView.options();
     Deque<ProgramMethodSet> waves = new ArrayDeque<>();
     Collection<Node> nodes = callGraph.getNodes();
-    int waveCount = 1;
     while (!nodes.isEmpty()) {
       ProgramMethodSet wave = callGraph.extractLeaves();
       waves.addLast(wave);
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java
index 1a4a2ec..db30629 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java
@@ -25,6 +25,7 @@
 import com.android.tools.r8.lightir.LirCode;
 import com.android.tools.r8.optimize.MemberRebindingIdentityLens;
 import com.android.tools.r8.optimize.argumentpropagation.ArgumentPropagator;
+import com.android.tools.r8.optimize.compose.ComposableOptimizationPass;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.Timing;
@@ -226,6 +227,8 @@
       identifierNameStringMarker.decoupleIdentifierNameStringsInFields(executorService);
     }
 
+    ComposableOptimizationPass.run(appView, this, timing);
+
     // Assure that no more optimization feedback left after post processing.
     assert feedback.noUpdatesLeft();
     return appView.appInfo().app();
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/ArgumentPropagatorCodeScannerForComposableFunctions.java b/src/main/java/com/android/tools/r8/optimize/compose/ArgumentPropagatorCodeScannerForComposableFunctions.java
new file mode 100644
index 0000000..5d0b693
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/compose/ArgumentPropagatorCodeScannerForComposableFunctions.java
@@ -0,0 +1,45 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.optimize.compose;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.code.AbstractValueSupplier;
+import com.android.tools.r8.ir.code.InvokeMethod;
+import com.android.tools.r8.optimize.argumentpropagation.ArgumentPropagatorCodeScanner;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodParameter;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.Timing;
+
+public class ArgumentPropagatorCodeScannerForComposableFunctions
+    extends ArgumentPropagatorCodeScanner {
+
+  private final ComposableCallGraph callGraph;
+
+  public ArgumentPropagatorCodeScannerForComposableFunctions(
+      AppView<AppInfoWithLiveness> appView, ComposableCallGraph callGraph) {
+    super(appView);
+    this.callGraph = callGraph;
+  }
+
+  @Override
+  protected void addTemporaryMethodState(
+      InvokeMethod invoke,
+      ProgramMethod resolvedMethod,
+      AbstractValueSupplier abstractValueSupplier,
+      ProgramMethod context,
+      Timing timing) {
+    ComposableCallGraphNode node = callGraph.getNodes().get(resolvedMethod);
+    if (node != null && node.isComposable()) {
+      super.addTemporaryMethodState(invoke, resolvedMethod, abstractValueSupplier, context, timing);
+    }
+  }
+
+  @Override
+  protected boolean isMethodParameterAlreadyUnknown(
+      MethodParameter methodParameter, ProgramMethod method) {
+    // We haven't defined the virtual root mapping, so we can't tell.
+    return false;
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/ArgumentPropagatorComposeModeling.java b/src/main/java/com/android/tools/r8/optimize/compose/ArgumentPropagatorComposeModeling.java
index c4fd469..6ffeb6d 100644
--- a/src/main/java/com/android/tools/r8/optimize/compose/ArgumentPropagatorComposeModeling.java
+++ b/src/main/java/com/android/tools/r8/optimize/compose/ArgumentPropagatorComposeModeling.java
@@ -4,8 +4,10 @@
 package com.android.tools.r8.optimize.compose;
 
 import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexString;
+import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.graph.lens.GraphLens;
 import com.android.tools.r8.ir.code.InstanceGet;
@@ -19,16 +21,31 @@
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.BitUtils;
 import com.android.tools.r8.utils.BooleanUtils;
+import com.google.common.collect.Iterables;
 
 public class ArgumentPropagatorComposeModeling {
 
   private final AppView<AppInfoWithLiveness> appView;
-  private final ComposeReferences composeReferences;
+  private final ComposeReferences rewrittenComposeReferences;
+
+  private final DexType rewrittenFunction2Type;
+  private final DexString invokeName;
 
   public ArgumentPropagatorComposeModeling(AppView<AppInfoWithLiveness> appView) {
     assert appView.testing().modelUnknownChangedAndDefaultArgumentsToComposableFunctions;
     this.appView = appView;
-    this.composeReferences = appView.getComposeReferences();
+    this.rewrittenComposeReferences =
+        appView
+            .getComposeReferences()
+            .rewrittenWithLens(appView.graphLens(), GraphLens.getIdentityLens());
+    DexItemFactory dexItemFactory = appView.dexItemFactory();
+    this.rewrittenFunction2Type =
+        appView
+            .graphLens()
+            .lookupType(
+                dexItemFactory.createType("Lkotlin/jvm/functions/Function2;"),
+                GraphLens.getIdentityLens());
+    this.invokeName = dexItemFactory.createString("invoke");
   }
 
   /**
@@ -67,20 +84,29 @@
    *   }
    * </pre>
    */
-  // TODO(b/302483644): Only apply modeling when the context is recognized as being a restart
-  //  lambda.
   public ParameterState modelParameterStateForChangedOrDefaultArgumentToComposableFunction(
       InvokeMethod invoke,
       ProgramMethod singleTarget,
       int argumentIndex,
       Value argument,
       ProgramMethod context) {
+    // TODO(b/302483644): Add some robust way of detecting restart lambda contexts.
+    if (!context.getHolder().getInterfaces().contains(rewrittenFunction2Type)
+        || !invoke.getPosition().getOutermostCaller().getMethod().getName().isEqualTo(invokeName)
+        || Iterables.isEmpty(
+            context
+                .getHolder()
+                .instanceFields(
+                    f -> f.getName().isIdenticalTo(rewrittenComposeReferences.changedFieldName)))) {
+      return null;
+    }
+
     // First check if this is an invoke to a @Composable function.
     if (singleTarget == null
         || !singleTarget
             .getDefinition()
             .annotations()
-            .hasAnnotation(composeReferences.composableType)) {
+            .hasAnnotation(rewrittenComposeReferences.composableType)) {
       return null;
     }
 
@@ -109,7 +135,7 @@
         invokedMethod.getArity() - 2 - BooleanUtils.intValue(hasDefaultParameter);
     if (!invokedMethod
         .getParameter(composerParameterIndex)
-        .isIdenticalTo(composeReferences.composerType)) {
+        .isIdenticalTo(rewrittenComposeReferences.composerType)) {
       return null;
     }
 
@@ -128,13 +154,9 @@
       // We generally expect this argument to be defined by a call to updateChangedFlags().
       if (argument.isDefinedByInstructionSatisfying(Instruction::isInvokeStatic)) {
         InvokeStatic invokeStatic = argument.getDefinition().asInvokeStatic();
-        DexMethod maybeUpdateChangedFlagsMethod =
-            appView
-                .graphLens()
-                .getOriginalMethodSignature(
-                    invokeStatic.getInvokedMethod(), GraphLens.getIdentityLens());
+        DexMethod maybeUpdateChangedFlagsMethod = invokeStatic.getInvokedMethod();
         if (!maybeUpdateChangedFlagsMethod.isIdenticalTo(
-            composeReferences.updatedChangedFlagsMethod)) {
+            rewrittenComposeReferences.updatedChangedFlagsMethod)) {
           return null;
         }
         // Assume the call does not impact the $$changed capture and strip the call.
@@ -160,10 +182,10 @@
                     .createDefiniteBitsNumberValue(
                         BitUtils.ALL_BITS_SET_MASK, BitUtils.ALL_BITS_SET_MASK << 1));
       }
-      expectedFieldName = composeReferences.changedFieldName;
+      expectedFieldName = rewrittenComposeReferences.changedFieldName;
     } else {
       // We are looking at an argument to the $$default parameter of the @Composable function.
-      expectedFieldName = composeReferences.defaultFieldName;
+      expectedFieldName = rewrittenComposeReferences.defaultFieldName;
     }
 
     // At this point we expect that the restart lambda is reading either this.$$changed or
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/ComposableCallGraph.java b/src/main/java/com/android/tools/r8/optimize/compose/ComposableCallGraph.java
new file mode 100644
index 0000000..0cbf138
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/compose/ComposableCallGraph.java
@@ -0,0 +1,170 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.optimize.compose;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.Code;
+import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.graph.UseRegistry;
+import com.android.tools.r8.graph.lens.GraphLens;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.collections.ProgramMethodMap;
+import java.util.function.Consumer;
+
+/**
+ * A partial call graph that stores call edges to @Composable functions. By processing all the call
+ * sites of a given @Composable function we can reapply arguent propagation for the @Composable
+ * function.
+ */
+public class ComposableCallGraph {
+
+  private final ProgramMethodMap<ComposableCallGraphNode> nodes;
+
+  public ComposableCallGraph(ProgramMethodMap<ComposableCallGraphNode> nodes) {
+    this.nodes = nodes;
+  }
+
+  public static Builder builder(AppView<AppInfoWithLiveness> appView) {
+    return new Builder(appView);
+  }
+
+  public static ComposableCallGraph empty() {
+    return new ComposableCallGraph(ProgramMethodMap.empty());
+  }
+
+  public void forEachNode(Consumer<ComposableCallGraphNode> consumer) {
+    nodes.forEachValue(consumer);
+  }
+
+  public ProgramMethodMap<ComposableCallGraphNode> getNodes() {
+    return nodes;
+  }
+
+  public boolean isEmpty() {
+    return nodes.isEmpty();
+  }
+
+  public static class Builder {
+
+    private final AppView<AppInfoWithLiveness> appView;
+    private final ProgramMethodMap<ComposableCallGraphNode> nodes = ProgramMethodMap.create();
+
+    Builder(AppView<AppInfoWithLiveness> appView) {
+      this.appView = appView;
+    }
+
+    public ComposableCallGraph build() {
+      createCallGraphNodesForComposableFunctions();
+      if (!nodes.isEmpty()) {
+        addCallEdgesToComposableFunctions();
+      }
+      return new ComposableCallGraph(nodes);
+    }
+
+    private void createCallGraphNodesForComposableFunctions() {
+      ComposeReferences rewrittenComposeReferences =
+          appView
+              .getComposeReferences()
+              .rewrittenWithLens(appView.graphLens(), GraphLens.getIdentityLens());
+      for (DexProgramClass clazz : appView.appInfo().classes()) {
+        clazz.forEachProgramDirectMethodMatching(
+            method -> method.annotations().hasAnnotation(rewrittenComposeReferences.composableType),
+            method -> {
+              // TODO(b/302483644): Don't include kept @Composable functions, since we can't
+              //  optimize them anyway.
+              assert method.getAccessFlags().isStatic();
+              nodes.put(method, new ComposableCallGraphNode(method, true));
+            });
+      }
+    }
+
+    // TODO(b/302483644): Parallelize identification of @Composable call sites.
+    private void addCallEdgesToComposableFunctions() {
+      // Code is fully rewritten so no need to lens rewrite in registry.
+      assert appView.codeLens() == appView.graphLens();
+
+      for (DexProgramClass clazz : appView.appInfo().classes()) {
+        clazz.forEachProgramMethodMatching(
+            DexEncodedMethod::hasCode,
+            method -> {
+              Code code = method.getDefinition().getCode();
+
+              // TODO(b/302483644): Leverage LIR code constant pool for efficient checking.
+              // TODO(b/302483644): Maybe remove the possibility of CF/DEX at this point.
+              assert code.isLirCode()
+                  || code.isCfCode()
+                  || code.isDexCode()
+                  || code.isDefaultInstanceInitializerCode()
+                  || code.isThrowNullCode();
+
+              code.registerCodeReferences(
+                  method,
+                  new UseRegistry<>(appView, method) {
+
+                    private final AppView<AppInfoWithLiveness> appViewWithLiveness =
+                        appView.withLiveness();
+
+                    @Override
+                    public void registerInvokeStatic(DexMethod method) {
+                      ProgramMethod resolvedMethod =
+                          appViewWithLiveness
+                              .appInfo()
+                              .unsafeResolveMethodDueToDexFormat(method)
+                              .getResolvedProgramMethod();
+                      if (resolvedMethod == null) {
+                        return;
+                      }
+
+                      ComposableCallGraphNode callee = nodes.get(resolvedMethod);
+                      if (callee == null || !callee.isComposable()) {
+                        // Only record calls to Composable functions.
+                        return;
+                      }
+
+                      ComposableCallGraphNode caller =
+                          nodes.computeIfAbsent(
+                              getContext(), context -> new ComposableCallGraphNode(context, false));
+                      callee.addCaller(caller);
+                    }
+
+                    @Override
+                    public void registerInitClass(DexType type) {}
+
+                    @Override
+                    public void registerInvokeDirect(DexMethod method) {}
+
+                    @Override
+                    public void registerInvokeInterface(DexMethod method) {}
+
+                    @Override
+                    public void registerInvokeSuper(DexMethod method) {}
+
+                    @Override
+                    public void registerInvokeVirtual(DexMethod method) {}
+
+                    @Override
+                    public void registerInstanceFieldRead(DexField field) {}
+
+                    @Override
+                    public void registerInstanceFieldWrite(DexField field) {}
+
+                    @Override
+                    public void registerStaticFieldRead(DexField field) {}
+
+                    @Override
+                    public void registerStaticFieldWrite(DexField field) {}
+
+                    @Override
+                    public void registerTypeReference(DexType type) {}
+                  });
+            });
+      }
+    }
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/ComposableCallGraphNode.java b/src/main/java/com/android/tools/r8/optimize/compose/ComposableCallGraphNode.java
new file mode 100644
index 0000000..13ec14db
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/compose/ComposableCallGraphNode.java
@@ -0,0 +1,53 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.optimize.compose;
+
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.utils.SetUtils;
+import java.util.Set;
+import java.util.function.Consumer;
+
+public class ComposableCallGraphNode {
+
+  private final ProgramMethod method;
+  private final boolean isComposable;
+
+  private final Set<ComposableCallGraphNode> callers = SetUtils.newIdentityHashSet();
+  private final Set<ComposableCallGraphNode> callees = SetUtils.newIdentityHashSet();
+
+  ComposableCallGraphNode(ProgramMethod method, boolean isComposable) {
+    this.method = method;
+    this.isComposable = isComposable;
+  }
+
+  public void addCaller(ComposableCallGraphNode caller) {
+    callers.add(caller);
+    caller.callees.add(this);
+  }
+
+  public void forEachComposableCallee(Consumer<ComposableCallGraphNode> consumer) {
+    for (ComposableCallGraphNode callee : callees) {
+      if (callee.isComposable()) {
+        consumer.accept(callee);
+      }
+    }
+  }
+
+  public Set<ComposableCallGraphNode> getCallers() {
+    return callers;
+  }
+
+  public ProgramMethod getMethod() {
+    return method;
+  }
+
+  public boolean isComposable() {
+    return isComposable;
+  }
+
+  @Override
+  public String toString() {
+    return method.toString();
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/ComposableOptimizationPass.java b/src/main/java/com/android/tools/r8/optimize/compose/ComposableOptimizationPass.java
new file mode 100644
index 0000000..244dbb4
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/compose/ComposableOptimizationPass.java
@@ -0,0 +1,116 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.optimize.compose;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.ir.conversion.PrimaryR8IRConverter;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.InternalOptions;
+import com.android.tools.r8.utils.InternalOptions.TestingOptions;
+import com.android.tools.r8.utils.SetUtils;
+import com.android.tools.r8.utils.Timing;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Sets;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
+
+public class ComposableOptimizationPass {
+
+  private final AppView<AppInfoWithLiveness> appView;
+  private final PrimaryR8IRConverter converter;
+
+  private ComposableOptimizationPass(
+      AppView<AppInfoWithLiveness> appView, PrimaryR8IRConverter converter) {
+    this.appView = appView;
+    this.converter = converter;
+  }
+
+  public static void run(
+      AppView<AppInfoWithLiveness> appView, PrimaryR8IRConverter converter, Timing timing) {
+    InternalOptions options = appView.options();
+    if (!options.isOptimizing() || !options.isShrinking()) {
+      return;
+    }
+    TestingOptions testingOptions = options.getTestingOptions();
+    if (!testingOptions.enableComposableOptimizationPass
+        || !testingOptions.modelUnknownChangedAndDefaultArgumentsToComposableFunctions) {
+      return;
+    }
+    timing.time(
+        "ComposableOptimizationPass",
+        () -> new ComposableOptimizationPass(appView, converter).processWaves());
+  }
+
+  void processWaves() {
+    ComposableCallGraph callGraph = ComposableCallGraph.builder(appView).build();
+    ComposeMethodProcessor methodProcessor =
+        new ComposeMethodProcessor(appView, callGraph, converter);
+    Set<ComposableCallGraphNode> wave = createInitialWave(callGraph);
+    while (!wave.isEmpty()) {
+      Set<ComposableCallGraphNode> optimizedComposableFunctions = methodProcessor.processWave(wave);
+      wave = createNextWave(methodProcessor, optimizedComposableFunctions);
+    }
+  }
+
+  // TODO(b/302483644): Should we skip root @Composable functions that don't have any nested
+  //  @Composable functions (?).
+  private Set<ComposableCallGraphNode> computeComposableRoots(ComposableCallGraph callGraph) {
+    Set<ComposableCallGraphNode> composableRoots = Sets.newIdentityHashSet();
+    callGraph.forEachNode(
+        node -> {
+          if (!node.isComposable()
+              || Iterables.any(node.getCallers(), ComposableCallGraphNode::isComposable)) {
+            // This is not a @Composable root.
+            return;
+          }
+          if (node.getCallers().isEmpty()) {
+            // Don't include root @Composable functions that are never called. These are either kept
+            // or will be removed in tree shaking.
+            return;
+          }
+          composableRoots.add(node);
+        });
+    return composableRoots;
+  }
+
+  private Set<ComposableCallGraphNode> createInitialWave(ComposableCallGraph callGraph) {
+    Set<ComposableCallGraphNode> wave = Sets.newIdentityHashSet();
+    Set<ComposableCallGraphNode> composableRoots = computeComposableRoots(callGraph);
+    composableRoots.forEach(composableRoot -> wave.addAll(composableRoot.getCallers()));
+    return wave;
+  }
+
+  // TODO(b/302483644): Consider repeatedly extracting the roots from the graph similar to the way
+  //  we extract leaves in the primary optimization pass.
+  private static Set<ComposableCallGraphNode> createNextWave(
+      ComposeMethodProcessor methodProcessor,
+      Set<ComposableCallGraphNode> optimizedComposableFunctions) {
+    Set<ComposableCallGraphNode> nextWave =
+        SetUtils.newIdentityHashSet(optimizedComposableFunctions);
+
+    // If the new wave contains two @Composable functions where one calls the other, then defer the
+    // processing of the callee to a later wave, to ensure that we have seen all of its callers
+    // before processing the callee.
+    List<ComposableCallGraphNode> deferredComposableFunctions = new ArrayList<>();
+    nextWave.forEach(
+        node -> {
+          if (SetUtils.containsAnyOf(nextWave, node.getCallers())) {
+            deferredComposableFunctions.add(node);
+          }
+        });
+    deferredComposableFunctions.forEach(nextWave::remove);
+
+    // To optimize the @Composable functions that are called from the @Composable functions of the
+    // next wave in the wave after that, we need to include their callers in the next wave as well.
+    Set<ComposableCallGraphNode> callersOfCalledComposableFunctions = Sets.newIdentityHashSet();
+    nextWave.forEach(
+        node ->
+            node.forEachComposableCallee(
+                callee -> callersOfCalledComposableFunctions.addAll(callee.getCallers())));
+    nextWave.addAll(callersOfCalledComposableFunctions);
+    nextWave.removeIf(methodProcessor::isProcessed);
+    return nextWave;
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/ComposeMethodProcessor.java b/src/main/java/com/android/tools/r8/optimize/compose/ComposeMethodProcessor.java
new file mode 100644
index 0000000..8ee6237
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/compose/ComposeMethodProcessor.java
@@ -0,0 +1,171 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.optimize.compose;
+
+import com.android.tools.r8.contexts.CompilationContext.ProcessorContext;
+import com.android.tools.r8.errors.Unreachable;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.analysis.constant.SparseConditionalConstantPropagation;
+import com.android.tools.r8.ir.analysis.value.AbstractValue;
+import com.android.tools.r8.ir.code.AbstractValueSupplier;
+import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.ir.conversion.MethodConversionOptions;
+import com.android.tools.r8.ir.conversion.MethodProcessor;
+import com.android.tools.r8.ir.conversion.MethodProcessorEventConsumer;
+import com.android.tools.r8.ir.conversion.PrimaryR8IRConverter;
+import com.android.tools.r8.ir.conversion.callgraph.CallSiteInformation;
+import com.android.tools.r8.ir.optimize.info.OptimizationFeedback;
+import com.android.tools.r8.optimize.argumentpropagation.ArgumentPropagatorCodeScanner;
+import com.android.tools.r8.optimize.argumentpropagation.ArgumentPropagatorOptimizationInfoPopulator;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteMonomorphicMethodState;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteParameterState;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodState;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.ParameterState;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.LazyBox;
+import com.android.tools.r8.utils.Timing;
+import com.android.tools.r8.utils.collections.ProgramMethodSet;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Sets;
+import java.util.Map;
+import java.util.Set;
+
+public class ComposeMethodProcessor extends MethodProcessor {
+
+  private final AppView<AppInfoWithLiveness> appView;
+  private final ArgumentPropagatorCodeScanner codeScanner;
+  private final PrimaryR8IRConverter converter;
+
+  private final Set<ComposableCallGraphNode> processed = Sets.newIdentityHashSet();
+
+  public ComposeMethodProcessor(
+      AppView<AppInfoWithLiveness> appView,
+      ComposableCallGraph callGraph,
+      PrimaryR8IRConverter converter) {
+    this.appView = appView;
+    this.codeScanner = new ArgumentPropagatorCodeScannerForComposableFunctions(appView, callGraph);
+    this.converter = converter;
+  }
+
+  // TODO(b/302483644): Process wave concurrently.
+  public Set<ComposableCallGraphNode> processWave(Set<ComposableCallGraphNode> wave) {
+    ProcessorContext processorContext = appView.createProcessorContext();
+    wave.forEach(
+        node -> {
+          assert !processed.contains(node);
+          converter.processDesugaredMethod(
+              node.getMethod(),
+              OptimizationFeedback.getIgnoreFeedback(),
+              this,
+              processorContext.createMethodProcessingContext(node.getMethod()),
+              MethodConversionOptions.forLirPhase(appView));
+        });
+    processed.addAll(wave);
+    return optimizeComposableFunctionsCalledFromWave(wave);
+  }
+
+  private Set<ComposableCallGraphNode> optimizeComposableFunctionsCalledFromWave(
+      Set<ComposableCallGraphNode> wave) {
+    ArgumentPropagatorOptimizationInfoPopulator optimizationInfoPopulator =
+        new ArgumentPropagatorOptimizationInfoPopulator(appView, null, null, null);
+    Set<ComposableCallGraphNode> optimizedComposableFunctions = Sets.newIdentityHashSet();
+    wave.forEach(
+        node ->
+            node.forEachComposableCallee(
+                callee -> {
+                  if (Iterables.all(callee.getCallers(), this::isProcessed)) {
+                    optimizationInfoPopulator.setOptimizationInfo(
+                        callee.getMethod(), ProgramMethodSet.empty(), getMethodState(callee));
+                    // TODO(b/302483644): Only enqueue this callee if its optimization info changed.
+                    optimizedComposableFunctions.add(callee);
+                  }
+                }));
+    return optimizedComposableFunctions;
+  }
+
+  private MethodState getMethodState(ComposableCallGraphNode node) {
+    assert processed.containsAll(node.getCallers());
+    MethodState methodState = codeScanner.getMethodStates().get(node.getMethod());
+    return widenMethodState(methodState);
+  }
+
+  /**
+   * If a parameter state of the current method state encodes that it is greater than (lattice wise)
+   * than another parameter in the program, then widen the parameter state to unknown. This is
+   * needed since we are not guaranteed to have seen all possible call sites of the callers of this
+   * method.
+   */
+  private MethodState widenMethodState(MethodState methodState) {
+    assert !methodState.isBottom();
+    assert !methodState.isPolymorphic();
+    if (methodState.isMonomorphic()) {
+      ConcreteMonomorphicMethodState monomorphicMethodState = methodState.asMonomorphic();
+      for (int i = 0; i < monomorphicMethodState.size(); i++) {
+        if (monomorphicMethodState.getParameterState(i).isConcrete()) {
+          ConcreteParameterState concreteParameterState =
+              monomorphicMethodState.getParameterState(i).asConcrete();
+          if (concreteParameterState.hasInParameters()) {
+            monomorphicMethodState.setParameterState(i, ParameterState.unknown());
+          }
+        }
+      }
+    } else {
+      assert methodState.isUnknown();
+    }
+    return methodState;
+  }
+
+  public void scan(ProgramMethod method, IRCode code, Timing timing) {
+    LazyBox<Map<Value, AbstractValue>> abstractValues =
+        new LazyBox<>(() -> new SparseConditionalConstantPropagation(appView).analyze(code));
+    AbstractValueSupplier abstractValueSupplier =
+        value -> {
+          AbstractValue abstractValue = abstractValues.computeIfAbsent().get(value);
+          assert abstractValue != null;
+          return abstractValue;
+        };
+    codeScanner.scan(method, code, abstractValueSupplier, timing);
+  }
+
+  public boolean isProcessed(ComposableCallGraphNode node) {
+    return processed.contains(node);
+  }
+
+  @Override
+  public CallSiteInformation getCallSiteInformation() {
+    return CallSiteInformation.empty();
+  }
+
+  @Override
+  public MethodProcessorEventConsumer getEventConsumer() {
+    throw new Unreachable();
+  }
+
+  @Override
+  public boolean isComposeMethodProcessor() {
+    return true;
+  }
+
+  @Override
+  public ComposeMethodProcessor asComposeMethodProcessor() {
+    return this;
+  }
+
+  @Override
+  public boolean isProcessedConcurrently(ProgramMethod method) {
+    return false;
+  }
+
+  @Override
+  public void scheduleDesugaredMethodForProcessing(ProgramMethod method) {
+    throw new Unreachable();
+  }
+
+  @Override
+  public boolean shouldApplyCodeRewritings(ProgramMethod method) {
+    return false;
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/ComposeReferences.java b/src/main/java/com/android/tools/r8/optimize/compose/ComposeReferences.java
index 7a489a7..e441188 100644
--- a/src/main/java/com/android/tools/r8/optimize/compose/ComposeReferences.java
+++ b/src/main/java/com/android/tools/r8/optimize/compose/ComposeReferences.java
@@ -7,6 +7,7 @@
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexString;
 import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.lens.GraphLens;
 
 public class ComposeReferences {
 
@@ -31,4 +32,26 @@
             factory.createProto(factory.intType, factory.intType),
             "updateChangedFlags");
   }
+
+  public ComposeReferences(
+      DexString changedFieldName,
+      DexString defaultFieldName,
+      DexType composableType,
+      DexType composerType,
+      DexMethod updatedChangedFlagsMethod) {
+    this.changedFieldName = changedFieldName;
+    this.defaultFieldName = defaultFieldName;
+    this.composableType = composableType;
+    this.composerType = composerType;
+    this.updatedChangedFlagsMethod = updatedChangedFlagsMethod;
+  }
+
+  public ComposeReferences rewrittenWithLens(GraphLens graphLens, GraphLens codeLens) {
+    return new ComposeReferences(
+        changedFieldName,
+        defaultFieldName,
+        graphLens.lookupClassType(composableType, codeLens),
+        graphLens.lookupClassType(composerType, codeLens),
+        graphLens.getRenamedMethodSignature(updatedChangedFlagsMethod, codeLens));
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index ffae35b..17e1f0a 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -2410,6 +2410,9 @@
         System.getProperty("com.android.tools.r8.disableMarkingClassesFinal") != null;
     public boolean testEnableTestAssertions = false;
     public boolean keepMetadataInR8IfNotRewritten = true;
+    public boolean enableComposableOptimizationPass =
+        SystemPropertyUtils.parseSystemPropertyForDevelopmentOrDefault(
+            "com.android.tools.r8.enableComposableOptimizationPass", false);
     public boolean modelUnknownChangedAndDefaultArgumentsToComposableFunctions =
         SystemPropertyUtils.parseSystemPropertyForDevelopmentOrDefault(
             "com.android.tools.r8.modelUnknownChangedAndDefaultArgumentsToComposableFunctions",
diff --git a/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMemberMap.java b/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMemberMap.java
index bbc27a2..eb747e4 100644
--- a/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMemberMap.java
+++ b/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMemberMap.java
@@ -12,6 +12,7 @@
 import java.util.function.BiConsumer;
 import java.util.function.BiFunction;
 import java.util.function.BiPredicate;
+import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.Supplier;
 
@@ -47,6 +48,10 @@
     backing.forEach((wrapper, value) -> consumer.accept(wrapper.get(), value));
   }
 
+  public void forEachValue(Consumer<V> consumer) {
+    backing.values().forEach(consumer);
+  }
+
   public V get(K member) {
     return backing.get(wrap(member));
   }