Refactor argument propagation to allow subsequent partial runs

Change-Id: I239367f6193c2bb87834768fa8077dec25c36cf7
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagator.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagator.java
index 2f90ff5..97ead9d 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagator.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagator.java
@@ -9,6 +9,7 @@
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
 import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.code.AbstractValueSupplier;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.conversion.IRConverter;
 import com.android.tools.r8.ir.conversion.MethodProcessor;
@@ -112,7 +113,9 @@
       ProgramMethod method, IRCode code, MethodProcessor methodProcessor, Timing timing) {
     if (codeScanner != null) {
       assert methodProcessor.isPrimaryMethodProcessor();
-      codeScanner.scan(method, code, timing);
+      AbstractValueSupplier abstractValueSupplier =
+          value -> value.getAbstractValue(appView, method);
+      codeScanner.scan(method, code, abstractValueSupplier, timing);
 
       assert effectivelyUnusedArgumentsAnalysis != null;
       effectivelyUnusedArgumentsAnalysis.scan(method, code);
@@ -226,14 +229,16 @@
     postMethodProcessorBuilder.rewrittenWithLens(appView);
 
     timing.begin("Compute optimization info");
-    new ArgumentPropagatorOptimizationInfoPopulator(
+    new ArgumentPropagatorOptimizationInfoPropagator(
             appView,
             converter,
             immediateSubtypingInfo,
             codeScannerResult,
-            postMethodProcessorBuilder,
             stronglyConnectedProgramComponents,
             interfaceDispatchOutsideProgram)
+        .propagateOptimizationInfo(executorService, timing);
+    new ArgumentPropagatorOptimizationInfoPopulator(
+            appView, converter, codeScannerResult, postMethodProcessorBuilder)
         .populateOptimizationInfo(executorService, timing);
     timing.end();
 
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorCodeScanner.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorCodeScanner.java
index 49d6400..0dc55d3 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorCodeScanner.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorCodeScanner.java
@@ -17,6 +17,7 @@
 import com.android.tools.r8.ir.analysis.type.Nullability;
 import com.android.tools.r8.ir.analysis.type.TypeElement;
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
+import com.android.tools.r8.ir.code.AbstractValueSupplier;
 import com.android.tools.r8.ir.code.AliasedValueConfiguration;
 import com.android.tools.r8.ir.code.AssumeAndCheckCastAliasedValueConfiguration;
 import com.android.tools.r8.ir.code.IRCode;
@@ -89,6 +90,10 @@
   private final MethodStateCollectionByReference methodStates =
       MethodStateCollectionByReference.createConcurrent();
 
+  public ArgumentPropagatorCodeScanner(AppView<AppInfoWithLiveness> appView) {
+    this(appView, new ArgumentPropagatorReprocessingCriteriaCollection(appView));
+  }
+
   ArgumentPropagatorCodeScanner(
       AppView<AppInfoWithLiveness> appView,
       ArgumentPropagatorReprocessingCriteriaCollection reprocessingCriteriaCollection) {
@@ -105,7 +110,7 @@
     virtualRootMethods.putAll(extension);
   }
 
-  MethodStateCollectionByReference getMethodStates() {
+  public MethodStateCollectionByReference getMethodStates() {
     return methodStates;
   }
 
@@ -113,7 +118,8 @@
     return virtualRootMethods.get(method.getReference());
   }
 
-  boolean isMethodParameterAlreadyUnknown(MethodParameter methodParameter, ProgramMethod method) {
+  protected boolean isMethodParameterAlreadyUnknown(
+      MethodParameter methodParameter, ProgramMethod method) {
     MethodState methodState =
         methodStates.get(
             method.getDefinition().belongsToDirectPool() || isMonomorphicVirtualMethod(method)
@@ -141,19 +147,27 @@
     return monomorphicVirtualMethods.contains(method);
   }
 
-  void scan(ProgramMethod method, IRCode code, Timing timing) {
+  public void scan(
+      ProgramMethod method,
+      IRCode code,
+      AbstractValueSupplier abstractValueSupplier,
+      Timing timing) {
     timing.begin("Argument propagation scanner");
     for (Invoke invoke : code.<Invoke>instructions(Instruction::isInvoke)) {
       if (invoke.isInvokeMethod()) {
-        scan(invoke.asInvokeMethod(), method, timing);
+        scan(invoke.asInvokeMethod(), abstractValueSupplier, method, timing);
       } else if (invoke.isInvokeCustom()) {
-        scan(invoke.asInvokeCustom(), method);
+        scan(invoke.asInvokeCustom());
       }
     }
     timing.end();
   }
 
-  private void scan(InvokeMethod invoke, ProgramMethod context, Timing timing) {
+  private void scan(
+      InvokeMethod invoke,
+      AbstractValueSupplier abstractValueSupplier,
+      ProgramMethod context,
+      Timing timing) {
     DexMethod invokedMethod = invoke.getInvokedMethod();
     if (invokedMethod.getHolderType().isArrayType()) {
       // Nothing to propagate; the targeted method is not a program method.
@@ -231,13 +245,27 @@
     // possible dispatch targets and propagate the information to these methods (this is expensive).
     // Instead we record the information in one place and then later propagate the information to
     // all dispatch targets.
-    ProgramMethod finalResolvedMethod = resolvedMethod;
+    addTemporaryMethodState(invoke, resolvedMethod, abstractValueSupplier, context, timing);
+  }
+
+  protected void addTemporaryMethodState(
+      InvokeMethod invoke,
+      ProgramMethod resolvedMethod,
+      AbstractValueSupplier abstractValueSupplier,
+      ProgramMethod context,
+      Timing timing) {
     timing.begin("Add method state");
     methodStates.addTemporaryMethodState(
         appView,
         getRepresentative(invoke, resolvedMethod),
         existingMethodState ->
-            computeMethodState(invoke, finalResolvedMethod, context, existingMethodState, timing),
+            computeMethodState(
+                invoke,
+                resolvedMethod,
+                abstractValueSupplier,
+                context,
+                existingMethodState,
+                timing),
         timing);
     timing.end();
   }
@@ -245,6 +273,7 @@
   private MethodState computeMethodState(
       InvokeMethod invoke,
       ProgramMethod resolvedMethod,
+      AbstractValueSupplier abstractValueSupplier,
       ProgramMethod context,
       MethodState existingMethodState,
       Timing timing) {
@@ -262,6 +291,7 @@
           computePolymorphicMethodState(
               invoke.asInvokeMethodWithReceiver(),
               resolvedMethod,
+              abstractValueSupplier,
               context,
               existingMethodState.asPolymorphicOrBottom());
     } else {
@@ -271,6 +301,7 @@
               invoke,
               resolvedMethod,
               invoke.lookupSingleProgramTarget(appView, context),
+              abstractValueSupplier,
               context,
               existingMethodState.asMonomorphicOrBottom());
     }
@@ -285,6 +316,7 @@
   private MethodState computePolymorphicMethodState(
       InvokeMethodWithReceiver invoke,
       ProgramMethod resolvedMethod,
+      AbstractValueSupplier abstractValueSupplier,
       ProgramMethod context,
       ConcretePolymorphicMethodStateOrBottom existingMethodState) {
     DynamicTypeWithUpperBound dynamicReceiverType = invoke.getReceiver().getDynamicType(appView);
@@ -320,6 +352,7 @@
             invoke,
             resolvedMethod,
             singleTarget,
+            abstractValueSupplier,
             context,
             existingMethodStateForBounds.asMonomorphicOrBottom(),
             dynamicReceiverType);
@@ -363,12 +396,14 @@
       InvokeMethod invoke,
       ProgramMethod resolvedMethod,
       ProgramMethod singleTarget,
+      AbstractValueSupplier abstractValueSupplier,
       ProgramMethod context,
       ConcreteMonomorphicMethodStateOrBottom existingMethodState) {
     return computeMonomorphicMethodState(
         invoke,
         resolvedMethod,
         singleTarget,
+        abstractValueSupplier,
         context,
         existingMethodState,
         invoke.isInvokeMethodWithReceiver()
@@ -381,6 +416,7 @@
       InvokeMethod invoke,
       ProgramMethod resolvedMethod,
       ProgramMethod singleTarget,
+      AbstractValueSupplier abstractValueSupplier,
       ProgramMethod context,
       ConcreteMonomorphicMethodStateOrBottom existingMethodState,
       DynamicType dynamicReceiverType) {
@@ -410,6 +446,7 @@
               singleTarget,
               argumentIndex,
               invoke.getArgument(argumentIndex),
+              abstractValueSupplier,
               context,
               existingMethodState));
     }
@@ -454,11 +491,12 @@
       ProgramMethod singleTarget,
       int argumentIndex,
       Value argument,
+      AbstractValueSupplier abstractValueSupplier,
       ProgramMethod context,
       ConcreteMonomorphicMethodStateOrBottom existingMethodState) {
     ParameterState modeledState =
         modeling.modelParameterStateForArgumentToFunction(
-            invoke, singleTarget, argumentIndex, argument);
+            invoke, singleTarget, argumentIndex, argument, context);
     if (modeledState != null) {
       return modeledState;
     }
@@ -504,7 +542,7 @@
           : new ConcreteArrayTypeParameterState(nullability);
     }
 
-    AbstractValue abstractValue = argument.getAbstractValue(appView, context);
+    AbstractValue abstractValue = abstractValueSupplier.getAbstractValue(argument);
 
     // For class types, we track both the abstract value and the dynamic type. If both are unknown,
     // then use UnknownParameterState.
@@ -555,8 +593,7 @@
         && !isMonomorphicVirtualMethod(getRepresentative(invoke, resolvedMethod));
   }
 
-  @SuppressWarnings("UnusedVariable")
-  private void scan(InvokeCustom invoke, ProgramMethod context) {
+  private void scan(InvokeCustom invoke) {
     // If the bootstrap method is program declared it will be called. The call is with runtime
     // provided arguments so ensure that the argument information is unknown.
     DexMethodHandle bootstrapMethod = invoke.getCallSite().bootstrapMethod;
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorCodeScannerModeling.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorCodeScannerModeling.java
index 7ab506e..7931180 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorCodeScannerModeling.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorCodeScannerModeling.java
@@ -23,10 +23,14 @@
   }
 
   ParameterState modelParameterStateForArgumentToFunction(
-      InvokeMethod invoke, ProgramMethod singleTarget, int argumentIndex, Value argument) {
+      InvokeMethod invoke,
+      ProgramMethod singleTarget,
+      int argumentIndex,
+      Value argument,
+      ProgramMethod context) {
     if (composeModeling != null) {
       return composeModeling.modelParameterStateForChangedOrDefaultArgumentToComposableFunction(
-          invoke, singleTarget, argumentIndex, argument);
+          invoke, singleTarget, argumentIndex, argument, context);
     }
     return null;
   }
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorOptimizationInfoPopulator.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorOptimizationInfoPopulator.java
index c626347..ea9d69a 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorOptimizationInfoPopulator.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorOptimizationInfoPopulator.java
@@ -7,10 +7,8 @@
 import static com.android.tools.r8.ir.optimize.info.OptimizationFeedback.getSimpleFeedback;
 
 import com.android.tools.r8.graph.AppView;
-import com.android.tools.r8.graph.DexMethodSignature;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexType;
-import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.ir.analysis.type.DynamicType;
 import com.android.tools.r8.ir.analysis.type.TypeElement;
@@ -28,9 +26,6 @@
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodStateCollectionByReference;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.ParameterState;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.StateCloner;
-import com.android.tools.r8.optimize.argumentpropagation.propagation.InParameterFlowPropagator;
-import com.android.tools.r8.optimize.argumentpropagation.propagation.InterfaceMethodArgumentPropagator;
-import com.android.tools.r8.optimize.argumentpropagation.propagation.VirtualDispatchMethodArgumentPropagator;
 import com.android.tools.r8.optimize.argumentpropagation.utils.WideningUtils;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.InternalOptions;
@@ -40,10 +35,8 @@
 import com.android.tools.r8.utils.Timing;
 import com.android.tools.r8.utils.collections.ProgramMethodSet;
 import java.util.List;
-import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
-import java.util.function.BiConsumer;
 
 /**
  * Propagates the argument flow information collected by the {@link ArgumentPropagatorCodeScanner}.
@@ -58,28 +51,16 @@
   private final InternalOptions options;
   private final PostMethodProcessor.Builder postMethodProcessorBuilder;
 
-  private final ImmediateProgramSubtypingInfo immediateSubtypingInfo;
-  private final List<Set<DexProgramClass>> stronglyConnectedProgramComponents;
-
-  private final BiConsumer<Set<DexProgramClass>, DexMethodSignature>
-      interfaceDispatchOutsideProgram;
-
-  ArgumentPropagatorOptimizationInfoPopulator(
+  public ArgumentPropagatorOptimizationInfoPopulator(
       AppView<AppInfoWithLiveness> appView,
       PrimaryR8IRConverter converter,
-      ImmediateProgramSubtypingInfo immediateSubtypingInfo,
       MethodStateCollectionByReference methodStates,
-      PostMethodProcessor.Builder postMethodProcessorBuilder,
-      List<Set<DexProgramClass>> stronglyConnectedProgramComponents,
-      BiConsumer<Set<DexProgramClass>, DexMethodSignature> interfaceDispatchOutsideProgram) {
+      PostMethodProcessor.Builder postMethodProcessorBuilder) {
     this.appView = appView;
     this.converter = converter;
-    this.immediateSubtypingInfo = immediateSubtypingInfo;
     this.methodStates = methodStates;
     this.options = appView.options();
     this.postMethodProcessorBuilder = postMethodProcessorBuilder;
-    this.stronglyConnectedProgramComponents = stronglyConnectedProgramComponents;
-    this.interfaceDispatchOutsideProgram = interfaceDispatchOutsideProgram;
   }
 
   /**
@@ -88,24 +69,6 @@
    */
   void populateOptimizationInfo(ExecutorService executorService, Timing timing)
       throws ExecutionException {
-    // TODO(b/190154391): Propagate argument information to handle virtual dispatch.
-    // TODO(b/190154391): To deal with arguments that are themselves passed as arguments to invoke
-    //  instructions, build a flow graph where nodes are parameters and there is an edge from a
-    //  parameter p1 to p2 if the value of p2 is at least the value of p1. Then propagate the
-    //  collected argument information throughout the flow graph.
-    timing.begin("Propagate argument information for virtual methods");
-    ThreadUtils.processItems(
-        stronglyConnectedProgramComponents,
-        this::processStronglyConnectedComponent,
-        appView.options().getThreadingModule(),
-        executorService);
-    timing.end();
-
-    // Solve the parameter flow constraints.
-    timing.begin("Solve flow constraints");
-    new InParameterFlowPropagator(appView, converter, methodStates).run(executorService);
-    timing.end();
-
     // The information stored on each method is now sound, and can be used as optimization info.
     timing.begin("Set optimization info");
     setOptimizationInfo(executorService);
@@ -114,41 +77,6 @@
     assert methodStates.isEmpty();
   }
 
-  private void processStronglyConnectedComponent(Set<DexProgramClass> stronglyConnectedComponent) {
-    // Invoke instructions that target interface methods may dispatch to methods that are not
-    // defined on a subclass of the interface method holder.
-    //
-    // Example: Calling I.m() will dispatch to A.m(), but A is not a subtype of I.
-    //
-    //   class A { public void m() {} }
-    //   interface I { void m(); }
-    //   class B extends A implements I {}
-    //
-    // To handle this we first propagate any argument information stored for I.m() to A.m() by doing
-    // a top-down traversal over the interfaces in the strongly connected component.
-    new InterfaceMethodArgumentPropagator(
-            appView,
-            immediateSubtypingInfo,
-            methodStates,
-            signature ->
-                interfaceDispatchOutsideProgram.accept(stronglyConnectedComponent, signature))
-        .run(stronglyConnectedComponent);
-
-    // Now all the argument information for a given method is guaranteed to be stored on a supertype
-    // of the method's holder. All that remains is to propagate the information downwards in the
-    // class hierarchy to propagate the argument information for a non-private virtual method to its
-    // overrides.
-    // TODO(b/190154391): Before running the top-down traversal, consider lowering the argument
-    //  information for non-private virtual methods. If we have some argument information with upper
-    //  bound=B, which is stored on a method on class A, we could move this argument information
-    //  from class A to B. This way we could potentially get rid of the "inactive argument
-    //  information" during the depth-first class hierarchy traversal, since the argument
-    //  information would be active by construction when it is first seen during the top-down class
-    //  hierarchy traversal.
-    new VirtualDispatchMethodArgumentPropagator(appView, immediateSubtypingInfo, methodStates)
-        .run(stronglyConnectedComponent);
-  }
-
   private void setOptimizationInfo(ExecutorService executorService) throws ExecutionException {
     ProgramMethodSet prunedMethods = ProgramMethodSet.createConcurrent();
     ThreadUtils.processItems(
@@ -170,8 +98,12 @@
     return prunedMethods;
   }
 
-  private void setOptimizationInfo(ProgramMethod method, ProgramMethodSet prunedMethods) {
-    MethodState methodState = methodStates.remove(method);
+  public void setOptimizationInfo(ProgramMethod method, ProgramMethodSet prunedMethods) {
+    setOptimizationInfo(method, prunedMethods, methodStates.remove(method));
+  }
+
+  public void setOptimizationInfo(
+      ProgramMethod method, ProgramMethodSet prunedMethods, MethodState methodState) {
     if (methodState.isBottom()) {
       if (method.getDefinition().isClassInitializer()) {
         return;
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorOptimizationInfoPropagator.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorOptimizationInfoPropagator.java
new file mode 100644
index 0000000..2c8f5e4
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorOptimizationInfoPropagator.java
@@ -0,0 +1,108 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.optimize.argumentpropagation;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexMethodSignature;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
+import com.android.tools.r8.ir.conversion.PrimaryR8IRConverter;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodStateCollectionByReference;
+import com.android.tools.r8.optimize.argumentpropagation.propagation.InParameterFlowPropagator;
+import com.android.tools.r8.optimize.argumentpropagation.propagation.InterfaceMethodArgumentPropagator;
+import com.android.tools.r8.optimize.argumentpropagation.propagation.VirtualDispatchMethodArgumentPropagator;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.ThreadUtils;
+import com.android.tools.r8.utils.Timing;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.function.BiConsumer;
+
+/**
+ * Propagates the argument flow information collected by the {@link ArgumentPropagatorCodeScanner}.
+ * This is needed to propagate argument information from call sites to all possible dispatch
+ * targets.
+ */
+public class ArgumentPropagatorOptimizationInfoPropagator {
+
+  private final AppView<AppInfoWithLiveness> appView;
+  private final PrimaryR8IRConverter converter;
+  private final MethodStateCollectionByReference methodStates;
+
+  private final ImmediateProgramSubtypingInfo immediateSubtypingInfo;
+  private final List<Set<DexProgramClass>> stronglyConnectedProgramComponents;
+
+  private final BiConsumer<Set<DexProgramClass>, DexMethodSignature>
+      interfaceDispatchOutsideProgram;
+
+  ArgumentPropagatorOptimizationInfoPropagator(
+      AppView<AppInfoWithLiveness> appView,
+      PrimaryR8IRConverter converter,
+      ImmediateProgramSubtypingInfo immediateSubtypingInfo,
+      MethodStateCollectionByReference methodStates,
+      List<Set<DexProgramClass>> stronglyConnectedProgramComponents,
+      BiConsumer<Set<DexProgramClass>, DexMethodSignature> interfaceDispatchOutsideProgram) {
+    this.appView = appView;
+    this.converter = converter;
+    this.immediateSubtypingInfo = immediateSubtypingInfo;
+    this.methodStates = methodStates;
+    this.stronglyConnectedProgramComponents = stronglyConnectedProgramComponents;
+    this.interfaceDispatchOutsideProgram = interfaceDispatchOutsideProgram;
+  }
+
+  /** Computes an over-approximation of each parameter's value and type. */
+  void propagateOptimizationInfo(ExecutorService executorService, Timing timing)
+      throws ExecutionException {
+    timing.begin("Propagate argument information for virtual methods");
+    ThreadUtils.processItems(
+        stronglyConnectedProgramComponents,
+        this::processStronglyConnectedComponent,
+        appView.options().getThreadingModule(),
+        executorService);
+    timing.end();
+
+    // Solve the parameter flow constraints.
+    timing.begin("Solve flow constraints");
+    new InParameterFlowPropagator(appView, converter, methodStates).run(executorService);
+    timing.end();
+  }
+
+  private void processStronglyConnectedComponent(Set<DexProgramClass> stronglyConnectedComponent) {
+    // Invoke instructions that target interface methods may dispatch to methods that are not
+    // defined on a subclass of the interface method holder.
+    //
+    // Example: Calling I.m() will dispatch to A.m(), but A is not a subtype of I.
+    //
+    //   class A { public void m() {} }
+    //   interface I { void m(); }
+    //   class B extends A implements I {}
+    //
+    // To handle this we first propagate any argument information stored for I.m() to A.m() by doing
+    // a top-down traversal over the interfaces in the strongly connected component.
+    new InterfaceMethodArgumentPropagator(
+            appView,
+            immediateSubtypingInfo,
+            methodStates,
+            signature ->
+                interfaceDispatchOutsideProgram.accept(stronglyConnectedComponent, signature))
+        .run(stronglyConnectedComponent);
+
+    // Now all the argument information for a given method is guaranteed to be stored on a supertype
+    // of the method's holder. All that remains is to propagate the information downwards in the
+    // class hierarchy to propagate the argument information for a non-private virtual method to its
+    // overrides.
+    // TODO(b/190154391): Before running the top-down traversal, consider lowering the argument
+    //  information for non-private virtual methods. If we have some argument information with upper
+    //  bound=B, which is stored on a method on class A, we could move this argument information
+    //  from class A to B. This way we could potentially get rid of the "inactive argument
+    //  information" during the depth-first class hierarchy traversal, since the argument
+    //  information would be active by construction when it is first seen during the top-down class
+    //  hierarchy traversal.
+    new VirtualDispatchMethodArgumentPropagator(appView, immediateSubtypingInfo, methodStates)
+        .run(stronglyConnectedComponent);
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/ArgumentPropagatorComposeModeling.java b/src/main/java/com/android/tools/r8/optimize/compose/ArgumentPropagatorComposeModeling.java
index 49cd7d5..c4fd469 100644
--- a/src/main/java/com/android/tools/r8/optimize/compose/ArgumentPropagatorComposeModeling.java
+++ b/src/main/java/com/android/tools/r8/optimize/compose/ArgumentPropagatorComposeModeling.java
@@ -67,8 +67,14 @@
    *   }
    * </pre>
    */
+  // TODO(b/302483644): Only apply modeling when the context is recognized as being a restart
+  //  lambda.
   public ParameterState modelParameterStateForChangedOrDefaultArgumentToComposableFunction(
-      InvokeMethod invoke, ProgramMethod singleTarget, int argumentIndex, Value argument) {
+      InvokeMethod invoke,
+      ProgramMethod singleTarget,
+      int argumentIndex,
+      Value argument,
+      ProgramMethod context) {
     // First check if this is an invoke to a @Composable function.
     if (singleTarget == null
         || !singleTarget