Solve argument propagation flow constraints

Bug: 190154391
Change-Id: Id6ab669667ab4530871ee5ca33f3210ee927c00c
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java b/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java
index c0e5aa4..918db59 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CallSiteOptimizationInfoPropagator.java
@@ -14,19 +14,11 @@
 import com.android.tools.r8.graph.LookupResult;
 import com.android.tools.r8.graph.MethodResolutionResult.SingleResolutionResult;
 import com.android.tools.r8.graph.ProgramMethod;
-import com.android.tools.r8.ir.analysis.type.TypeAnalysis;
-import com.android.tools.r8.ir.analysis.type.TypeElement;
-import com.android.tools.r8.ir.analysis.value.AbstractValue;
-import com.android.tools.r8.ir.analysis.value.SingleValue;
-import com.android.tools.r8.ir.code.Assume;
-import com.android.tools.r8.ir.code.ConstNumber;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.Instruction;
-import com.android.tools.r8.ir.code.InstructionListIterator;
 import com.android.tools.r8.ir.code.InvokeCustom;
 import com.android.tools.r8.ir.code.InvokeMethod;
 import com.android.tools.r8.ir.code.InvokeMethodWithReceiver;
-import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.conversion.PostOptimization;
 import com.android.tools.r8.ir.optimize.info.CallSiteOptimizationInfo;
 import com.android.tools.r8.ir.optimize.info.ConcreteCallSiteOptimizationInfo;
@@ -41,8 +33,6 @@
 import com.android.tools.r8.utils.Timing;
 import com.android.tools.r8.utils.collections.ProgramMethodSet;
 import com.google.common.collect.Sets;
-import java.util.LinkedList;
-import java.util.List;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorOptimizationInfoPopulator.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorOptimizationInfoPopulator.java
index e5ee744..9449dda 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorOptimizationInfoPopulator.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorOptimizationInfoPopulator.java
@@ -7,7 +7,15 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.optimize.info.ConcreteCallSiteOptimizationInfo;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteMethodState;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteMonomorphicMethodState;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteParameterState;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodState;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodStateCollection;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.ParameterState;
+import com.android.tools.r8.optimize.argumentpropagation.propagation.InParameterFlowPropagator;
 import com.android.tools.r8.optimize.argumentpropagation.propagation.InterfaceMethodArgumentPropagator;
 import com.android.tools.r8.optimize.argumentpropagation.propagation.VirtualDispatchMethodArgumentPropagator;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
@@ -96,6 +104,13 @@
     //  that the method returns the constant.
     ThreadUtils.processItems(
         stronglyConnectedComponents, this::processStronglyConnectedComponent, executorService);
+
+    // Solve the parameter flow constraints.
+    new InParameterFlowPropagator(appView, methodStates).run(executorService);
+
+    // The information stored on each method is now sound, and can be used as optimization info.
+    setOptimizationInfo(executorService);
+
     assert methodStates.isEmpty();
   }
 
@@ -121,4 +136,46 @@
     new VirtualDispatchMethodArgumentPropagator(appView, immediateSubtypingInfo, methodStates)
         .run(stronglyConnectedComponent);
   }
+
+  private void setOptimizationInfo(ExecutorService executorService) throws ExecutionException {
+    ThreadUtils.processItems(
+        appView.appInfo().classes(), this::setOptimizationInfo, executorService);
+  }
+
+  private void setOptimizationInfo(DexProgramClass clazz) {
+    clazz.forEachProgramMethod(this::setOptimizationInfo);
+  }
+
+  private void setOptimizationInfo(ProgramMethod method) {
+    MethodState methodState = methodStates.remove(method);
+    if (methodState.isBottom()) {
+      // TODO(b/190154391): This should only happen if the method is never called. Consider removing
+      //  the method in this case.
+      return;
+    }
+
+    if (methodState.isUnknown()) {
+      // Nothing is known about the arguments to this method.
+      return;
+    }
+
+    ConcreteMethodState concreteMethodState = methodState.asConcrete();
+    if (concreteMethodState.isPolymorphic()) {
+      assert false;
+      return;
+    }
+
+    ConcreteMonomorphicMethodState monomorphicMethodState = concreteMethodState.asMonomorphic();
+    assert monomorphicMethodState.getParameterStates().stream()
+        .filter(ParameterState::isConcrete)
+        .map(ParameterState::asConcrete)
+        .noneMatch(ConcreteParameterState::hasInParameters);
+
+    method
+        .getDefinition()
+        .joinCallSiteOptimizationInfo(
+            ConcreteCallSiteOptimizationInfo.fromMethodState(
+                appView, method, monomorphicMethodState),
+            appView);
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteArrayTypeParameterState.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteArrayTypeParameterState.java
index 7615aaf..e7533e5 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteArrayTypeParameterState.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteArrayTypeParameterState.java
@@ -6,6 +6,7 @@
 
 import com.android.tools.r8.ir.analysis.type.Nullability;
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
+import com.android.tools.r8.utils.Action;
 
 public class ConcreteArrayTypeParameterState extends ConcreteParameterState {
 
@@ -21,13 +22,17 @@
     this.nullability = nullability;
   }
 
-  public ParameterState mutableJoin(ConcreteArrayTypeParameterState parameterState) {
+  public ParameterState mutableJoin(
+      ConcreteArrayTypeParameterState parameterState, Action onChangedAction) {
     assert !nullability.isMaybeNull();
     assert !parameterState.nullability.isMaybeNull();
-    mutableJoinInParameters(parameterState);
+    boolean inParametersChanged = mutableJoinInParameters(parameterState);
     if (widenInParameters()) {
       return unknown();
     }
+    if (inParametersChanged) {
+      onChangedAction.execute();
+    }
     return this;
   }
 
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteClassTypeParameterState.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteClassTypeParameterState.java
index 9574d08..3c59a07 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteClassTypeParameterState.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteClassTypeParameterState.java
@@ -8,6 +8,7 @@
 import com.android.tools.r8.ir.analysis.type.DynamicType;
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.Action;
 
 public class ConcreteClassTypeParameterState extends ConcreteParameterState {
 
@@ -26,9 +27,12 @@
   }
 
   public ParameterState mutableJoin(
-      AppView<AppInfoWithLiveness> appView, ConcreteClassTypeParameterState parameterState) {
+      AppView<AppInfoWithLiveness> appView,
+      ConcreteClassTypeParameterState parameterState,
+      Action onChangedAction) {
     boolean allowNullOrAbstractValue = true;
     boolean allowNonConstantNumbers = false;
+    AbstractValue oldAbstractValue = abstractValue;
     abstractValue =
         abstractValue.join(
             parameterState.abstractValue,
@@ -38,15 +42,19 @@
     // TODO(b/190154391): Join the dynamic types using SubtypingInfo.
     // TODO(b/190154391): Take in the static type as an argument, and unset the dynamic type if it
     //  equals the static type.
+    DynamicType oldDynamicType = dynamicType;
     dynamicType =
         dynamicType.equals(parameterState.dynamicType) ? dynamicType : DynamicType.unknown();
     if (abstractValue.isUnknown() && dynamicType.isUnknown()) {
       return unknown();
     }
-    mutableJoinInParameters(parameterState);
+    boolean inParametersChanged = mutableJoinInParameters(parameterState);
     if (widenInParameters()) {
       return unknown();
     }
+    if (abstractValue != oldAbstractValue || dynamicType != oldDynamicType || inParametersChanged) {
+      onChangedAction.execute();
+    }
     return this;
   }
 
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteMonomorphicMethodState.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteMonomorphicMethodState.java
index c93f4cf..fc8fe25 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteMonomorphicMethodState.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteMonomorphicMethodState.java
@@ -24,6 +24,10 @@
     return parameterStates.get(index);
   }
 
+  public List<ParameterState> getParameterStates() {
+    return parameterStates;
+  }
+
   public ConcreteMonomorphicMethodStateOrUnknown mutableJoin(
       AppView<AppInfoWithLiveness> appView, ConcreteMonomorphicMethodState methodState) {
     if (size() != methodState.size()) {
@@ -53,6 +57,10 @@
     return this;
   }
 
+  public void setParameterState(int index, ParameterState parameterState) {
+    parameterStates.set(index, parameterState);
+  }
+
   public int size() {
     return parameterStates.size();
   }
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteParameterState.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteParameterState.java
index 1452512..1e5ebae 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteParameterState.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteParameterState.java
@@ -6,6 +6,7 @@
 
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.Action;
 import com.android.tools.r8.utils.SetUtils;
 import com.google.common.collect.Sets;
 import java.util.Collections;
@@ -30,6 +31,18 @@
     this.inParameters = SetUtils.newHashSet(inParameter);
   }
 
+  public void clearInParameters() {
+    inParameters.clear();
+  }
+
+  public boolean hasInParameters() {
+    return !inParameters.isEmpty();
+  }
+
+  public Set<MethodParameter> getInParameters() {
+    return inParameters;
+  }
+
   public abstract ConcreteParameterStateKind getKind();
 
   public boolean isArrayParameter() {
@@ -76,7 +89,7 @@
 
   @Override
   public ParameterState mutableJoin(
-      AppView<AppInfoWithLiveness> appView, ParameterState parameterState) {
+      AppView<AppInfoWithLiveness> appView, ParameterState parameterState, Action onChangedAction) {
     if (parameterState.isUnknown()) {
       return parameterState;
     }
@@ -85,16 +98,19 @@
     if (kind == otherKind) {
       switch (getKind()) {
         case ARRAY:
-          return asArrayParameter().mutableJoin(parameterState.asConcrete().asArrayParameter());
+          return asArrayParameter()
+              .mutableJoin(parameterState.asConcrete().asArrayParameter(), onChangedAction);
         case CLASS:
           return asClassParameter()
-              .mutableJoin(appView, parameterState.asConcrete().asClassParameter());
+              .mutableJoin(
+                  appView, parameterState.asConcrete().asClassParameter(), onChangedAction);
         case PRIMITIVE:
           return asPrimitiveParameter()
-              .mutableJoin(appView, parameterState.asConcrete().asPrimitiveParameter());
+              .mutableJoin(
+                  appView, parameterState.asConcrete().asPrimitiveParameter(), onChangedAction);
         case RECEIVER:
           return asReceiverParameter()
-              .mutableJoin(parameterState.asConcrete().asReceiverParameter());
+              .mutableJoin(parameterState.asConcrete().asReceiverParameter(), onChangedAction);
         default:
           // Dead.
       }
@@ -104,15 +120,15 @@
     return unknown();
   }
 
-  void mutableJoinInParameters(ConcreteParameterState parameterState) {
+  boolean mutableJoinInParameters(ConcreteParameterState parameterState) {
     if (parameterState.inParameters.isEmpty()) {
-      return;
+      return false;
     }
     if (inParameters.isEmpty()) {
       assert inParameters == Collections.<MethodParameter>emptySet();
       inParameters = Sets.newIdentityHashSet();
     }
-    inParameters.addAll(parameterState.inParameters);
+    return inParameters.addAll(parameterState.inParameters);
   }
 
   boolean widenInParameters() {
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcretePrimitiveTypeParameterState.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcretePrimitiveTypeParameterState.java
index dc50932..0b85fb6 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcretePrimitiveTypeParameterState.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcretePrimitiveTypeParameterState.java
@@ -7,6 +7,7 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.Action;
 
 public class ConcretePrimitiveTypeParameterState extends ConcreteParameterState {
 
@@ -23,9 +24,12 @@
   }
 
   public ParameterState mutableJoin(
-      AppView<AppInfoWithLiveness> appView, ConcretePrimitiveTypeParameterState parameterState) {
+      AppView<AppInfoWithLiveness> appView,
+      ConcretePrimitiveTypeParameterState parameterState,
+      Action onChangedAction) {
     boolean allowNullOrAbstractValue = false;
     boolean allowNonConstantNumbers = false;
+    AbstractValue oldAbstractValue = abstractValue;
     abstractValue =
         abstractValue.join(
             parameterState.abstractValue,
@@ -35,10 +39,13 @@
     if (abstractValue.isUnknown()) {
       return unknown();
     }
-    mutableJoinInParameters(parameterState);
+    boolean inParametersChanged = mutableJoinInParameters(parameterState);
     if (widenInParameters()) {
       return unknown();
     }
+    if (abstractValue != oldAbstractValue || inParametersChanged) {
+      onChangedAction.execute();
+    }
     return this;
   }
 
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteReceiverParameterState.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteReceiverParameterState.java
index 9a529b2..145d29b 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteReceiverParameterState.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ConcreteReceiverParameterState.java
@@ -6,6 +6,7 @@
 
 import com.android.tools.r8.ir.analysis.type.DynamicType;
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
+import com.android.tools.r8.utils.Action;
 
 public class ConcreteReceiverParameterState extends ConcreteParameterState {
 
@@ -15,19 +16,24 @@
     this.dynamicType = dynamicType;
   }
 
-  public ParameterState mutableJoin(ConcreteReceiverParameterState parameterState) {
+  public ParameterState mutableJoin(
+      ConcreteReceiverParameterState parameterState, Action onChangedAction) {
     // TODO(b/190154391): Join the dynamic types using SubtypingInfo.
     // TODO(b/190154391): Take in the static type as an argument, and unset the dynamic type if it
     //  equals the static type.
+    DynamicType oldDynamicType = dynamicType;
     dynamicType =
         dynamicType.equals(parameterState.dynamicType) ? dynamicType : DynamicType.unknown();
     if (dynamicType.isUnknown()) {
       return unknown();
     }
-    mutableJoinInParameters(parameterState);
+    boolean inParametersChanged = mutableJoinInParameters(parameterState);
     if (widenInParameters()) {
       return unknown();
     }
+    if (dynamicType != oldDynamicType || inParametersChanged) {
+      onChangedAction.execute();
+    }
     return this;
   }
 
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/MethodParameter.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/MethodParameter.java
index 7fd63da..9f7434b 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/MethodParameter.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/MethodParameter.java
@@ -17,6 +17,14 @@
     this.index = index;
   }
 
+  public DexMethod getMethod() {
+    return method;
+  }
+
+  public int getIndex() {
+    return index;
+  }
+
   @Override
   public boolean equals(Object obj) {
     if (obj == null || getClass() != obj.getClass()) {
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/MethodStateCollection.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/MethodStateCollection.java
index 7155df2..9c699ae 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/MethodStateCollection.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/MethodStateCollection.java
@@ -85,4 +85,8 @@
     MethodState removed = methodStates.remove(method.getReference());
     return removed != null ? removed : MethodState.bottom();
   }
+
+  public void set(ProgramMethod method, MethodState methodState) {
+    methodStates.put(method.getReference(), methodState);
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ParameterState.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ParameterState.java
index 8f759cf..2c9a205 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ParameterState.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ParameterState.java
@@ -7,6 +7,7 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.Action;
 
 public abstract class ParameterState {
 
@@ -28,6 +29,11 @@
     return false;
   }
 
+  public final ParameterState mutableJoin(
+      AppView<AppInfoWithLiveness> appView, ParameterState parameterState) {
+    return mutableJoin(appView, parameterState, Action.empty());
+  }
+
   public abstract ParameterState mutableJoin(
-      AppView<AppInfoWithLiveness> appView, ParameterState parameterState);
+      AppView<AppInfoWithLiveness> appView, ParameterState parameterState, Action onChangedAction);
 }
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/UnknownParameterState.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/UnknownParameterState.java
index 8993b5c..bfc835d 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/UnknownParameterState.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/UnknownParameterState.java
@@ -7,6 +7,7 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.Action;
 
 public class UnknownParameterState extends ParameterState {
 
@@ -30,7 +31,7 @@
 
   @Override
   public ParameterState mutableJoin(
-      AppView<AppInfoWithLiveness> appView, ParameterState parameterState) {
+      AppView<AppInfoWithLiveness> appView, ParameterState parameterState, Action onChangedAction) {
     return this;
   }
 }
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InParameterFlowPropagator.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InParameterFlowPropagator.java
new file mode 100644
index 0000000..a4e4877
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InParameterFlowPropagator.java
@@ -0,0 +1,302 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.optimize.argumentpropagation.propagation;
+
+import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
+import static com.android.tools.r8.utils.MapUtils.ignoreKey;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteMethodState;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteMonomorphicMethodState;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteParameterState;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodParameter;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodState;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodStateCollection;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.ParameterState;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.Action;
+import com.android.tools.r8.utils.ThreadUtils;
+import com.google.common.collect.Sets;
+import it.unimi.dsi.fastutil.ints.Int2ReferenceMap;
+import it.unimi.dsi.fastutil.ints.Int2ReferenceOpenHashMap;
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.IdentityHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.function.Consumer;
+
+public class InParameterFlowPropagator {
+
+  final AppView<AppInfoWithLiveness> appView;
+  final MethodStateCollection methodStates;
+
+  public InParameterFlowPropagator(
+      AppView<AppInfoWithLiveness> appView, MethodStateCollection methodStates) {
+    this.appView = appView;
+    this.methodStates = methodStates;
+  }
+
+  public void run(ExecutorService executorService) throws ExecutionException {
+    // Build a graph with an edge from parameter p -> parameter p' if all argument information for p
+    // must be included in the argument information for p'.
+    FlowGraph flowGraph = new FlowGraph(appView.appInfo().classes());
+
+    // Build a worklist containing all the parameter nodes.
+    Deque<ParameterNode> worklist = new ArrayDeque<>();
+    flowGraph.forEachNode(worklist::add);
+
+    // Repeatedly propagate argument information through edges in the flow graph until there are no
+    // more changes.
+    // TODO(b/190154391): Consider parallelizing the flow propagation. There are a few scenarios
+    //  that need to be covered, such as (i) two threads could race to update the same parameter
+    //  state, (ii) a thread may try to propagate a parameter state to its successors while
+    //  another thread is trying to update the state of the parameter itself.
+    // TODO(b/190154391): Consider a path p1 -> p2 -> p3 in the graph. If we process p2 first, then
+    //  p3, and then p1, then the processing of p1 could cause p2 to change, which means that we
+    //  need to reprocess p2 and then p3. If we always process leaves in the graph first, we would
+    //  process p1, then p2, then p3, and then be done.
+    // TODO(b/190154391): Prune the graph on-the-fly. If the argument information for a parameter
+    //  becomes unknown, we could consider clearing its predecessors since none of the predecessors
+    //  could contribute any information even if they change.
+    while (!worklist.isEmpty()) {
+      ParameterNode parameterNode = worklist.removeLast();
+      parameterNode.unsetPending();
+      propagate(
+          parameterNode,
+          affectedNode -> {
+            // No need to enqueue the affected node if it is already in the worklist or if it does
+            // not have any successors (i.e., the successor is a leaf).
+            if (!affectedNode.isPending() && affectedNode.hasSuccessors()) {
+              worklist.add(affectedNode);
+              affectedNode.setPending();
+            }
+          });
+    }
+
+    // The algorithm only changes the parameter states of each monomorphic method state. In case any
+    // of these method states have effectively become unknown, we replace them by the canonicalized
+    // unknown method state.
+    postProcessMethodStates(executorService);
+  }
+
+  private void propagate(
+      ParameterNode parameterNode, Consumer<ParameterNode> affectedNodeConsumer) {
+    ParameterState parameterState = parameterNode.getState();
+    for (ParameterNode successorNode : parameterNode.getSuccessors()) {
+      successorNode.addState(
+          appView, parameterState, () -> affectedNodeConsumer.accept(successorNode));
+    }
+  }
+
+  private void postProcessMethodStates(ExecutorService executorService) throws ExecutionException {
+    ThreadUtils.processItems(
+        appView.appInfo().classes(), this::postProcessMethodStates, executorService);
+  }
+
+  private void postProcessMethodStates(DexProgramClass clazz) {
+    clazz.forEachProgramMethod(this::postProcessMethodState);
+  }
+
+  private void postProcessMethodState(ProgramMethod method) {
+    ConcreteMethodState methodState = methodStates.get(method).asConcrete();
+    if (methodState == null) {
+      return;
+    }
+    assert methodState.isMonomorphic();
+    for (ParameterState parameterState : methodState.asMonomorphic().getParameterStates()) {
+      if (!parameterState.isUnknown()) {
+        return;
+      }
+    }
+    methodStates.set(method, MethodState.unknown());
+  }
+
+  private class FlowGraph {
+
+    private final Map<DexMethod, Int2ReferenceMap<ParameterNode>> nodes = new IdentityHashMap<>();
+
+    public FlowGraph(Iterable<DexProgramClass> classes) {
+      classes.forEach(this::add);
+    }
+
+    void forEachNode(Consumer<? super ParameterNode> consumer) {
+      nodes.values().forEach(nodesForMethod -> nodesForMethod.values().forEach(consumer));
+    }
+
+    private void add(DexProgramClass clazz) {
+      clazz.forEachProgramMethod(this::add);
+    }
+
+    private void add(ProgramMethod method) {
+      MethodState methodState = methodStates.get(method);
+
+      // No need to create nodes for parameters with no in-flow or no useful information.
+      if (methodState.isBottom() || methodState.isUnknown()) {
+        return;
+      }
+
+      // Add nodes for the parameters for which we have non-trivial information.
+      ConcreteMonomorphicMethodState monomorphicMethodState =
+          methodState.asConcrete().asMonomorphic();
+      List<ParameterState> parameterStates = monomorphicMethodState.getParameterStates();
+      for (int parameterIndex = 0; parameterIndex < parameterStates.size(); parameterIndex++) {
+        ParameterState parameterState = parameterStates.get(parameterIndex);
+        add(method, parameterIndex, monomorphicMethodState, parameterState);
+      }
+    }
+
+    private void add(
+        ProgramMethod method,
+        int parameterIndex,
+        ConcreteMonomorphicMethodState methodState,
+        ParameterState parameterState) {
+      // No need to create nodes for parameters we don't know anything about.
+      if (parameterState.isUnknown()) {
+        return;
+      }
+
+      ConcreteParameterState concreteParameterState = parameterState.asConcrete();
+
+      // No need to create a node for a parameter that doesn't depend on any other parameters
+      // (unless some other parameter depends on this parameter).
+      if (!concreteParameterState.hasInParameters()) {
+        return;
+      }
+
+      ParameterNode node =
+          getOrCreateParameterNode(method.getReference(), parameterIndex, methodState);
+      for (MethodParameter inParameter : concreteParameterState.getInParameters()) {
+        MethodState enclosingMethodState = getEnclosingMethodStateForParameter(inParameter);
+        if (enclosingMethodState.isBottom()) {
+          // The current method is called from a dead method; no need to propagate any information
+          // from the dead call site.
+          continue;
+        }
+
+        if (enclosingMethodState.isUnknown()) {
+          // The parameter depends on another parameter for which we don't know anything.
+          node.clearPredecessors();
+          node.setState(ParameterState.unknown());
+          break;
+        }
+
+        assert enclosingMethodState.isConcrete();
+        assert enclosingMethodState.asConcrete().isMonomorphic();
+
+        ParameterNode predecessor =
+            getOrCreateParameterNode(
+                inParameter.getMethod(),
+                inParameter.getIndex(),
+                enclosingMethodState.asConcrete().asMonomorphic());
+        node.addPredecessor(predecessor);
+      }
+      concreteParameterState.clearInParameters();
+    }
+
+    private ParameterNode getOrCreateParameterNode(
+        DexMethod key, int parameterIndex, ConcreteMonomorphicMethodState methodState) {
+      Int2ReferenceMap<ParameterNode> parameterNodesForMethod =
+          nodes.computeIfAbsent(key, ignoreKey(Int2ReferenceOpenHashMap::new));
+      return parameterNodesForMethod.compute(
+          parameterIndex,
+          (ignore, parameterNode) ->
+              parameterNode != null
+                  ? parameterNode
+                  : new ParameterNode(methodState, parameterIndex));
+    }
+
+    private MethodState getEnclosingMethodStateForParameter(MethodParameter methodParameter) {
+      DexMethod methodReference = methodParameter.getMethod();
+      ProgramMethod method =
+          methodReference.lookupOnProgramClass(
+              asProgramClassOrNull(
+                  appView.definitionFor(methodParameter.getMethod().getHolderType())));
+      if (method == null) {
+        // Conservatively return unknown if for some reason we can't find the method.
+        assert false;
+        return MethodState.unknown();
+      }
+      return methodStates.get(method);
+    }
+  }
+
+  static class ParameterNode {
+
+    private final ConcreteMonomorphicMethodState methodState;
+    private final int parameterIndex;
+
+    private final Set<ParameterNode> predecessors = Sets.newIdentityHashSet();
+    private final Set<ParameterNode> successors = Sets.newIdentityHashSet();
+
+    private boolean pending = true;
+
+    ParameterNode(ConcreteMonomorphicMethodState methodState, int parameterIndex) {
+      this.methodState = methodState;
+      this.parameterIndex = parameterIndex;
+    }
+
+    void addPredecessor(ParameterNode predecessor) {
+      predecessor.successors.add(this);
+      predecessors.add(predecessor);
+    }
+
+    void clearPredecessors() {
+      for (ParameterNode predecessor : predecessors) {
+        predecessor.successors.remove(this);
+      }
+      predecessors.clear();
+    }
+
+    ParameterState getState() {
+      return methodState.getParameterState(parameterIndex);
+    }
+
+    Set<ParameterNode> getSuccessors() {
+      return successors;
+    }
+
+    boolean hasSuccessors() {
+      return !successors.isEmpty();
+    }
+
+    boolean isPending() {
+      return pending;
+    }
+
+    void addState(
+        AppView<AppInfoWithLiveness> appView,
+        ParameterState parameterStateToAdd,
+        Action onChangedAction) {
+      ParameterState oldParameterState = getState();
+      ParameterState newParameterState =
+          oldParameterState.mutableJoin(appView, parameterStateToAdd, onChangedAction);
+      if (newParameterState != oldParameterState) {
+        setState(newParameterState);
+        onChangedAction.execute();
+      }
+    }
+
+    void setPending() {
+      assert !isPending();
+      pending = true;
+    }
+
+    void setState(ParameterState parameterState) {
+      methodState.setParameterState(parameterIndex, parameterState);
+    }
+
+    void unsetPending() {
+      assert pending;
+      pending = false;
+    }
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/VirtualDispatchMethodArgumentPropagator.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/VirtualDispatchMethodArgumentPropagator.java
index 0e51d66..db8a85d 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/VirtualDispatchMethodArgumentPropagator.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/VirtualDispatchMethodArgumentPropagator.java
@@ -8,14 +8,12 @@
 import static com.android.tools.r8.utils.MapUtils.ignoreKey;
 
 import com.android.tools.r8.graph.AppView;
-import com.android.tools.r8.graph.DexClass;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.ir.analysis.type.ClassTypeElement;
 import com.android.tools.r8.ir.analysis.type.DynamicType;
-import com.android.tools.r8.ir.optimize.info.ConcreteCallSiteOptimizationInfo;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteMethodState;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcretePolymorphicMethodState;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodState;
@@ -26,7 +24,6 @@
 import java.util.Map;
 import java.util.Set;
 import java.util.function.Consumer;
-import java.util.stream.Collectors;
 
 public class VirtualDispatchMethodArgumentPropagator extends MethodArgumentPropagator {
 
@@ -134,7 +131,7 @@
   public void run(Set<DexProgramClass> stronglyConnectedComponent) {
     super.run(stronglyConnectedComponent);
     assert verifyAllClassesFinished(stronglyConnectedComponent);
-    assert verifyStatePruned(stronglyConnectedComponent);
+    assert verifyStatePruned();
   }
 
   @Override
@@ -162,7 +159,7 @@
   public void visit(DexProgramClass clazz) {
     assert !propagationStates.containsKey(clazz);
     PropagationState propagationState = computePropagationState(clazz);
-    setOptimizationInfo(clazz, propagationState);
+    computeFinalMethodStates(clazz, propagationState);
   }
 
   private PropagationState computePropagationState(DexProgramClass clazz) {
@@ -231,44 +228,20 @@
     return propagationState;
   }
 
-  private void setOptimizationInfo(DexProgramClass clazz, PropagationState propagationState) {
-    clazz.forEachProgramMethod(method -> setOptimizationInfo(method, propagationState));
+  private void computeFinalMethodStates(DexProgramClass clazz, PropagationState propagationState) {
+    clazz.forEachProgramMethod(method -> computeFinalMethodState(method, propagationState));
   }
 
-  private void setOptimizationInfo(ProgramMethod method, PropagationState propagationState) {
-    MethodState methodState = methodStates.remove(method);
+  private void computeFinalMethodState(ProgramMethod method, PropagationState propagationState) {
+    MethodState methodState = methodStates.get(method);
 
     // If this is a polymorphic method, we need to compute the method state to account for dynamic
     // dispatch.
     if (methodState.isConcrete() && methodState.asConcrete().isPolymorphic()) {
       methodState = propagationState.computeMethodStateForPolymorhicMethod(method);
+      assert !methodState.isConcrete() || methodState.asConcrete().isMonomorphic();
+      methodStates.set(method, methodState);
     }
-
-    if (methodState.isBottom()) {
-      // TODO(b/190154391): This should only happen if the method is never called. Consider removing
-      //  the method in this case.
-      return;
-    }
-
-    if (methodState.isUnknown()) {
-      // Nothing is known about the arguments to this method.
-      return;
-    }
-
-    ConcreteMethodState concreteMethodState = methodState.asConcrete();
-    if (concreteMethodState.isPolymorphic()) {
-      assert false;
-      return;
-    }
-
-    // TODO(b/190154391): We need to resolve the flow constraints before this is guaranteed to be
-    //  sound.
-    method
-        .getDefinition()
-        .joinCallSiteOptimizationInfo(
-            ConcreteCallSiteOptimizationInfo.fromMethodState(
-                appView, method, concreteMethodState.asMonomorphic()),
-            appView);
   }
 
   @Override
@@ -283,13 +256,7 @@
     return true;
   }
 
-  private boolean verifyStatePruned(Set<DexProgramClass> stronglyConnectedComponent) {
-    Set<DexType> types =
-        stronglyConnectedComponent.stream().map(DexClass::getType).collect(Collectors.toSet());
-    methodStates.forEach(
-        (method, methodState) -> {
-          assert !types.contains(method.getHolderType());
-        });
+  private boolean verifyStatePruned() {
     assert propagationStates.isEmpty();
     return true;
   }
diff --git a/src/main/java/com/android/tools/r8/utils/Action.java b/src/main/java/com/android/tools/r8/utils/Action.java
index d69b130..5ab0177 100644
--- a/src/main/java/com/android/tools/r8/utils/Action.java
+++ b/src/main/java/com/android/tools/r8/utils/Action.java
@@ -6,5 +6,12 @@
 
 @FunctionalInterface
 public interface Action {
+
+  Action EMPTY = () -> {};
+
+  static Action empty() {
+    return EMPTY;
+  }
+
   void execute();
 }
diff --git a/src/test/java/com/android/tools/r8/optimize/argumentpropagation/StaticMethodWithConstantArgumentThroughCallChainTest.java b/src/test/java/com/android/tools/r8/optimize/argumentpropagation/StaticMethodWithConstantArgumentThroughCallChainTest.java
new file mode 100644
index 0000000..b0387b2
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/optimize/argumentpropagation/StaticMethodWithConstantArgumentThroughCallChainTest.java
@@ -0,0 +1,125 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.optimize.argumentpropagation;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isAbsent;
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static junit.framework.TestCase.assertTrue;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.InternalOptions.CallSiteOptimizationOptions;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class StaticMethodWithConstantArgumentThroughCallChainTest extends TestBase {
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection parameters() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .addOptionsModification(
+            options -> {
+              CallSiteOptimizationOptions callSiteOptimizationOptions =
+                  options.callSiteOptimizationOptions();
+              callSiteOptimizationOptions.setEnableExperimentalArgumentPropagation();
+              callSiteOptimizationOptions.setEnableConstantPropagation();
+            })
+        .enableInliningAnnotations()
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .inspect(
+            inspector -> {
+              ClassSubject mainClassSubject = inspector.clazz(Main.class);
+              assertThat(mainClassSubject, isPresent());
+
+              // The test1(), test2(), and test3() methods have been optimized.
+              for (int i = 1; i <= 3; i++) {
+                MethodSubject testMethodSubject = mainClassSubject.uniqueMethodWithName("test" + i);
+                assertThat(testMethodSubject, isPresent());
+                // TODO(b/190154391): The parameter x should be removed.
+                assertEquals(1, testMethodSubject.getProgramMethod().getParameters().size());
+                assertTrue(
+                    testMethodSubject.streamInstructions().noneMatch(InstructionSubject::isIf));
+              }
+
+              assertThat(mainClassSubject.uniqueMethodWithName("dead"), isAbsent());
+            })
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines(
+            "Hello from test1()",
+            "Hello from test2()",
+            "Hello from test3()",
+            "Hello from test1()",
+            "Hello from test2()",
+            "Hello from test3()",
+            "Hello from test1()",
+            "Hello from test2()",
+            "Hello from test3()");
+  }
+
+  static class Main {
+
+    public static void main(String[] args) {
+      test1(42);
+      test1(42);
+      test1(42);
+    }
+
+    @NeverInline
+    static void test1(int x) {
+      if (x == 42) {
+        System.out.println("Hello from test1()");
+      } else {
+        dead();
+      }
+      test2(x);
+    }
+
+    @NeverInline
+    static void test2(int x) {
+      if (x == 42) {
+        System.out.println("Hello from test2()");
+      } else {
+        dead();
+      }
+      test3(x);
+    }
+
+    @NeverInline
+    static void test3(int x) {
+      if (x == 42) {
+        System.out.println("Hello from test3()");
+      } else {
+        dead();
+      }
+    }
+
+    @NeverInline
+    static void dead() {
+      System.out.println("Unreachable");
+    }
+  }
+}