Avoid argument propagation to unoptimizable methods

Bug: 190154391
Change-Id: I91d38f4787205f9468edccbcb7f7f64f237364fe
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagator.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagator.java
index de1fa53..bf688ef 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagator.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagator.java
@@ -5,6 +5,7 @@
 package com.android.tools.r8.optimize.argumentpropagation;
 
 import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
 import com.android.tools.r8.graph.ProgramMethod;
@@ -24,6 +25,7 @@
 import java.util.concurrent.ExecutorService;
 
 /** Optimization that propagates information about arguments from call sites to method entries. */
+// TODO(b/190154391): Add timing information for performance tracking.
 public class ArgumentPropagator {
 
   private final AppView<AppInfoWithLiveness> appView;
@@ -53,6 +55,15 @@
    */
   public void initializeCodeScanner() {
     codeScanner = new ArgumentPropagatorCodeScanner(appView);
+
+    // Disable argument propagation for methods that should not be optimized.
+    ImmediateProgramSubtypingInfo immediateSubtypingInfo =
+        ImmediateProgramSubtypingInfo.create(appView);
+    // TODO(b/190154391): Consider computing the strongly connected components and running this in
+    //  parallel for each scc.
+    new ArgumentPropagatorUnoptimizableMethods(
+            appView, immediateSubtypingInfo, codeScanner.getMethodStates())
+        .disableArgumentPropagationForUnoptimizableMethods(appView.appInfo().classes());
   }
 
   /** Called by {@link IRConverter} prior to finalizing methods. */
@@ -75,7 +86,7 @@
       throws ExecutionException {
     // Unset the scanner since all code objects have been scanned at this point.
     assert appView.isAllCodeProcessed();
-    MethodStateCollection codeScannerResult = codeScanner.getResult();
+    MethodStateCollection codeScannerResult = codeScanner.getMethodStates();
     codeScanner = null;
 
     new ArgumentPropagatorOptimizationInfoPopulator(appView, codeScannerResult)
@@ -133,7 +144,8 @@
    */
   public void enqueueMethodsForProcessing(PostMethodProcessor.Builder postMethodProcessorBuilder) {
     for (DexProgramClass clazz : appView.appInfo().classes()) {
-      clazz.forEachProgramMethod(
+      clazz.forEachProgramMethodMatching(
+          DexEncodedMethod::hasCode,
           method -> {
             CallSiteOptimizationInfo callSiteOptimizationInfo =
                 method.getDefinition().getCallSiteOptimizationInfo();
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorCodeScanner.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorCodeScanner.java
index 23ae6ba..44c022d 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorCodeScanner.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorCodeScanner.java
@@ -4,9 +4,9 @@
 
 package com.android.tools.r8.optimize.argumentpropagation;
 
-import com.android.tools.r8.errors.Unimplemented;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexMethodHandle;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.MethodResolutionResult.SingleResolutionResult;
 import com.android.tools.r8.graph.ProgramMethod;
@@ -35,14 +35,13 @@
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodState;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodStateCollection;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.ParameterState;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.UnknownMethodState;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.google.common.collect.Iterables;
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.IdentityHashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
 
 /**
  * Analyzes each {@link IRCode} during the primary optimization to collect information about the
@@ -70,21 +69,10 @@
    */
   private final MethodStateCollection methodStates;
 
-  /**
-   * The methods that are not subject to argument propagation. This includes (i) methods that are
-   * not subject to optimization due to -keep rules, (ii) classpath/library method overrides, and
-   * (iii) methods that are unlikely to benefit from argument propagation according to heuristics.
-   *
-   * <p>Argument propagation must also be disabled for lambda implementation methods unless we model
-   * the calls from lambda main methods synthesized by the JVM.
-   */
-  private final Set<DexMethod> unoptimizableMethods;
-
   ArgumentPropagatorCodeScanner(AppView<AppInfoWithLiveness> appView) {
     this.appView = appView;
     this.classMethodRoots = computeClassMethodRoots();
     this.methodStates = computeInitialMethodStates();
-    this.unoptimizableMethods = computeUnoptimizableMethods();
   }
 
   private Map<DexMethod, DexMethod> computeClassMethodRoots() {
@@ -104,14 +92,7 @@
     return MethodStateCollection.createConcurrent();
   }
 
-  private Set<DexMethod> computeUnoptimizableMethods() {
-    // TODO(b/190154391): Ensure we don't store any information for kept methods and their
-    //  overrides.
-    // TODO(b/190154391): Consider bailing out for all classes that inherit from a missing class.
-    return Collections.emptySet();
-  }
-
-  MethodStateCollection getResult() {
+  MethodStateCollection getMethodStates() {
     return methodStates;
   }
 
@@ -335,7 +316,16 @@
   }
 
   private void scan(InvokeCustom invoke, ProgramMethod context) {
-    // TODO(b/190154391): Handle call to bootstrap method.
-    throw new Unimplemented();
+    // If the bootstrap method is program declared it will be called. The call is with runtime
+    // provided arguments so ensure that the argument information is unknown.
+    DexMethodHandle bootstrapMethod = invoke.getCallSite().bootstrapMethod;
+    SingleResolutionResult resolution =
+        appView
+            .appInfo()
+            .resolveMethod(bootstrapMethod.asMethod(), bootstrapMethod.isInterface)
+            .asSingleResolution();
+    if (resolution != null && resolution.getResolvedHolder().isProgramClass()) {
+      methodStates.set(resolution.getResolvedProgramMethod(), UnknownMethodState.get());
+    }
   }
 }
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorUnoptimizableMethods.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorUnoptimizableMethods.java
new file mode 100644
index 0000000..cc51bd8
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorUnoptimizableMethods.java
@@ -0,0 +1,245 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.optimize.argumentpropagation;
+
+import static com.android.tools.r8.utils.MapUtils.ignoreKey;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
+import com.android.tools.r8.graph.MethodResolutionResult.SingleResolutionResult;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodStateCollection;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.UnknownMethodState;
+import com.android.tools.r8.optimize.argumentpropagation.utils.DepthFirstTopDownClassHierarchyTraversal;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.InternalOptions;
+import com.android.tools.r8.utils.MethodSignatureEquivalence;
+import com.android.tools.r8.utils.collections.DexMethodSignatureSet;
+import com.android.tools.r8.utils.collections.ProgramMethodSet;
+import com.google.common.base.Equivalence.Wrapper;
+import com.google.common.collect.Sets;
+import java.util.Collection;
+import java.util.IdentityHashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Consumer;
+
+public class ArgumentPropagatorUnoptimizableMethods {
+
+  private static final MethodSignatureEquivalence equivalence = MethodSignatureEquivalence.get();
+
+  private final AppView<AppInfoWithLiveness> appView;
+  private final ImmediateProgramSubtypingInfo immediateSubtypingInfo;
+  private final MethodStateCollection methodStates;
+
+  public ArgumentPropagatorUnoptimizableMethods(
+      AppView<AppInfoWithLiveness> appView,
+      ImmediateProgramSubtypingInfo immediateSubtypingInfo,
+      MethodStateCollection methodStates) {
+    this.appView = appView;
+    this.immediateSubtypingInfo = immediateSubtypingInfo;
+    this.methodStates = methodStates;
+  }
+
+  // TODO(b/190154391): Consider if we should bail out for classes that inherit from a missing
+  //  class.
+  public void disableArgumentPropagationForUnoptimizableMethods(
+      Collection<DexProgramClass> stronglyConnectedComponent) {
+    ProgramMethodSet unoptimizableClassRootMethods = ProgramMethodSet.create();
+    ProgramMethodSet unoptimizableInterfaceRootMethods = ProgramMethodSet.create();
+    forEachUnoptimizableMethod(
+        stronglyConnectedComponent,
+        method -> {
+          if (method.getDefinition().belongsToVirtualPool()
+              && !method.getHolder().isFinal()
+              && !method.getAccessFlags().isFinal()) {
+            if (method.getHolder().isInterface()) {
+              unoptimizableInterfaceRootMethods.add(method);
+            } else {
+              unoptimizableClassRootMethods.add(method);
+            }
+          } else {
+            disableArgumentPropagationForMethod(method);
+          }
+        });
+
+    // Disable argument propagation for all overrides of the root methods. Since interface methods
+    // may be implemented by classes that are not a subtype of the interface that declares the
+    // interface method, we first mark the interface method overrides on such classes as ineligible
+    // for argument propagation.
+    if (!unoptimizableInterfaceRootMethods.isEmpty()) {
+      new UnoptimizableInterfaceMethodPropagator(
+              unoptimizableClassRootMethods, unoptimizableInterfaceRootMethods)
+          .run(stronglyConnectedComponent);
+    }
+
+    // At this point we can mark all overrides by a simple top-down traversal over the class
+    // hierarchy.
+    new UnoptimizableClassMethodPropagator(
+            unoptimizableClassRootMethods, unoptimizableInterfaceRootMethods)
+        .run(stronglyConnectedComponent);
+  }
+
+  private void disableArgumentPropagationForMethod(ProgramMethod method) {
+    methodStates.set(method, UnknownMethodState.get());
+  }
+
+  private void forEachUnoptimizableMethod(
+      Collection<DexProgramClass> stronglyConnectedComponent, Consumer<ProgramMethod> consumer) {
+    AppInfoWithLiveness appInfo = appView.appInfo();
+    InternalOptions options = appView.options();
+    for (DexProgramClass clazz : stronglyConnectedComponent) {
+      clazz.forEachProgramMethod(
+          method -> {
+            assert !method.getDefinition().isLibraryMethodOverride().isUnknown();
+            if (method.getDefinition().isLibraryMethodOverride().isPossiblyTrue()
+                || appInfo.isMethodTargetedByInvokeDynamic(method)
+                || !appInfo
+                    .getKeepInfo()
+                    .getMethodInfo(method)
+                    .isArgumentPropagationAllowed(options)) {
+              consumer.accept(method);
+            }
+          });
+    }
+  }
+
+  private class UnoptimizableInterfaceMethodPropagator
+      extends DepthFirstTopDownClassHierarchyTraversal {
+
+    private final ProgramMethodSet unoptimizableClassRootMethods;
+    private final Map<DexProgramClass, Set<Wrapper<DexMethod>>> unoptimizableInterfaceMethods =
+        new IdentityHashMap<>();
+
+    UnoptimizableInterfaceMethodPropagator(
+        ProgramMethodSet unoptimizableClassRootMethods,
+        ProgramMethodSet unoptimizableInterfaceRootMethods) {
+      super(
+          ArgumentPropagatorUnoptimizableMethods.this.appView,
+          ArgumentPropagatorUnoptimizableMethods.this.immediateSubtypingInfo);
+      this.unoptimizableClassRootMethods = unoptimizableClassRootMethods;
+      unoptimizableInterfaceRootMethods.forEach(this::addUnoptimizableRootMethod);
+    }
+
+    private void addUnoptimizableRootMethod(ProgramMethod method) {
+      unoptimizableInterfaceMethods
+          .computeIfAbsent(method.getHolder(), ignoreKey(Sets::newIdentityHashSet))
+          .add(equivalence.wrap(method.getReference()));
+    }
+
+    @Override
+    public void visit(DexProgramClass clazz) {
+      Set<Wrapper<DexMethod>> unoptimizableInterfaceMethodsForClass =
+          unoptimizableInterfaceMethods.computeIfAbsent(clazz, ignoreKey(Sets::newIdentityHashSet));
+
+      // Add the unoptimizable interface methods from the parent interfaces.
+      immediateSubtypingInfo.forEachImmediateSuperClassMatching(
+          clazz,
+          (supertype, superclass) -> superclass != null && superclass.isProgramClass(),
+          (supertype, superclass) ->
+              unoptimizableInterfaceMethodsForClass.addAll(
+                  unoptimizableInterfaceMethods.get(superclass.asProgramClass())));
+
+      // Propagate the unoptimizable interface methods of this interface to all immediate
+      // (non-interface) subclasses.
+      for (DexProgramClass implementer : immediateSubtypingInfo.getSubclasses(clazz)) {
+        if (implementer.isInterface()) {
+          continue;
+        }
+
+        for (Wrapper<DexMethod> unoptimizableInterfaceMethod :
+            unoptimizableInterfaceMethodsForClass) {
+          SingleResolutionResult resolutionResult =
+              appView
+                  .appInfo()
+                  .resolveMethodOnClass(unoptimizableInterfaceMethod.get(), implementer)
+                  .asSingleResolution();
+          if (resolutionResult == null || !resolutionResult.getResolvedHolder().isProgramClass()) {
+            continue;
+          }
+
+          ProgramMethod resolvedMethod = resolutionResult.getResolvedProgramMethod();
+          if (resolvedMethod.getHolder().isInterface()
+              || resolvedMethod.getHolder() == implementer) {
+            continue;
+          }
+
+          unoptimizableClassRootMethods.add(resolvedMethod);
+        }
+      }
+    }
+
+    @Override
+    public void prune(DexProgramClass clazz) {
+      unoptimizableInterfaceMethods.remove(clazz);
+    }
+
+    @Override
+    public boolean isRoot(DexProgramClass clazz) {
+      return clazz.isInterface() && super.isRoot(clazz);
+    }
+
+    @Override
+    public void forEachSubClass(DexProgramClass clazz, Consumer<DexProgramClass> consumer) {
+      for (DexProgramClass subclass : immediateSubtypingInfo.getSubclasses(clazz)) {
+        if (subclass.isInterface()) {
+          consumer.accept(subclass);
+        }
+      }
+    }
+  }
+
+  private class UnoptimizableClassMethodPropagator
+      extends DepthFirstTopDownClassHierarchyTraversal {
+
+    private final Map<DexProgramClass, DexMethodSignatureSet> unoptimizableMethods =
+        new IdentityHashMap<>();
+
+    UnoptimizableClassMethodPropagator(
+        ProgramMethodSet unoptimizableClassRootMethods,
+        ProgramMethodSet unoptimizableInterfaceRootMethods) {
+      super(
+          ArgumentPropagatorUnoptimizableMethods.this.appView,
+          ArgumentPropagatorUnoptimizableMethods.this.immediateSubtypingInfo);
+      unoptimizableClassRootMethods.forEach(this::addUnoptimizableRootMethod);
+      unoptimizableInterfaceRootMethods.forEach(this::addUnoptimizableRootMethod);
+    }
+
+    private void addUnoptimizableRootMethod(ProgramMethod method) {
+      unoptimizableMethods
+          .computeIfAbsent(method.getHolder(), ignoreKey(DexMethodSignatureSet::create))
+          .add(method);
+    }
+
+    @Override
+    public void visit(DexProgramClass clazz) {
+      DexMethodSignatureSet unoptimizableMethodsForClass =
+          unoptimizableMethods.computeIfAbsent(clazz, ignoreKey(DexMethodSignatureSet::create));
+
+      // Add the unoptimizable methods from the parent classes.
+      immediateSubtypingInfo.forEachImmediateSuperClassMatching(
+          clazz,
+          (supertype, superclass) -> superclass != null && superclass.isProgramClass(),
+          (supertype, superclass) ->
+              unoptimizableMethodsForClass.addAll(
+                  unoptimizableMethods.get(superclass.asProgramClass())));
+
+      // Disable argument propagation for the unoptimizable methods of this class.
+      clazz.forEachProgramVirtualMethod(
+          method -> {
+            if (unoptimizableMethodsForClass.contains(method)) {
+              disableArgumentPropagationForMethod(method);
+            }
+          });
+    }
+
+    @Override
+    public void prune(DexProgramClass clazz) {
+      unoptimizableMethods.remove(clazz);
+    }
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InterfaceMethodArgumentPropagator.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InterfaceMethodArgumentPropagator.java
index a1ac738..cfb7fff 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InterfaceMethodArgumentPropagator.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InterfaceMethodArgumentPropagator.java
@@ -7,16 +7,15 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexClass;
 import com.android.tools.r8.graph.DexProgramClass;
-import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
 import com.android.tools.r8.graph.MethodResolutionResult.SingleResolutionResult;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodState;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodStateCollection;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import java.util.Collection;
 import java.util.IdentityHashMap;
 import java.util.Map;
-import java.util.Set;
 import java.util.function.Consumer;
 
 /**
@@ -46,7 +45,7 @@
   }
 
   @Override
-  public void run(Set<DexProgramClass> stronglyConnectedComponent) {
+  public void run(Collection<DexProgramClass> stronglyConnectedComponent) {
     super.run(stronglyConnectedComponent);
     assert verifyAllInterfacesFinished(stronglyConnectedComponent);
   }
@@ -62,16 +61,7 @@
 
   @Override
   public boolean isRoot(DexProgramClass clazz) {
-    if (!clazz.isInterface()) {
-      return false;
-    }
-    for (DexType implementedType : clazz.getInterfaces()) {
-      DexClass implementedDefinition = appView.definitionFor(implementedType);
-      if (implementedDefinition != null && implementedDefinition.isProgramClass()) {
-        return false;
-      }
-    }
-    return true;
+    return clazz.isInterface() && super.isRoot(clazz);
   }
 
   @Override
@@ -151,7 +141,8 @@
                 }));
   }
 
-  private boolean verifyAllInterfacesFinished(Set<DexProgramClass> stronglyConnectedComponent) {
+  private boolean verifyAllInterfacesFinished(
+      Collection<DexProgramClass> stronglyConnectedComponent) {
     assert stronglyConnectedComponent.stream()
         .filter(DexClass::isInterface)
         .allMatch(this::isClassFinished);
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/MethodArgumentPropagator.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/MethodArgumentPropagator.java
index 20456c5..1554357 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/MethodArgumentPropagator.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/MethodArgumentPropagator.java
@@ -4,176 +4,21 @@
 
 package com.android.tools.r8.optimize.argumentpropagation.propagation;
 
-import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
-
 import com.android.tools.r8.graph.AppView;
-import com.android.tools.r8.graph.DexProgramClass;
-import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodStateCollection;
+import com.android.tools.r8.optimize.argumentpropagation.utils.DepthFirstTopDownClassHierarchyTraversal;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
-import java.util.ArrayDeque;
-import java.util.ArrayList;
-import java.util.Deque;
-import java.util.IdentityHashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.function.Consumer;
 
-public abstract class MethodArgumentPropagator {
+public abstract class MethodArgumentPropagator extends DepthFirstTopDownClassHierarchyTraversal {
 
-  // The state of a given class in the top-down traversal.
-  private enum TraversalState {
-    // Represents that a given class and all of its direct and indirect supertypes have been
-    // visited by the top-down traversal, but all of the direct and indirect subtypes are still
-    // not visited.
-    SEEN,
-    // Represents that a given class and all of its direct and indirect subtypes have been visited.
-    // Such nodes will never be seen again in the top-down traversal, and any state stored for
-    // such nodes can be pruned.
-    FINISHED
-  }
-
-  final AppView<AppInfoWithLiveness> appView;
-  final ImmediateProgramSubtypingInfo immediateSubtypingInfo;
   final MethodStateCollection methodStates;
 
-  // Contains the traversal state for each class. If a given class is not in the map the class is
-  // not yet seen.
-  private final Map<DexProgramClass, TraversalState> states = new IdentityHashMap<>();
-
-  // The class hierarchy is not a tree, thus for completeness we need to process all parent
-  // interfaces for a given class or interface before continuing the top-down traversal. When the
-  // top-down traversal for a given root returns, this means that there may be interfaces that are
-  // seen but not finished. These interfaces are added to this collection such that we can
-  // prioritize them over classes or interfaces that are yet not seen. This leads to more efficient
-  // state pruning, since the state for these interfaces can be pruned when they transition to being
-  // finished.
-  //
-  // See also prioritizeNewlySeenButNotFinishedRoots().
-  private final List<DexProgramClass> newlySeenButNotFinishedRoots = new ArrayList<>();
-
   public MethodArgumentPropagator(
       AppView<AppInfoWithLiveness> appView,
       ImmediateProgramSubtypingInfo immediateSubtypingInfo,
       MethodStateCollection methodStates) {
-    this.appView = appView;
-    this.immediateSubtypingInfo = immediateSubtypingInfo;
+    super(appView, immediateSubtypingInfo);
     this.methodStates = methodStates;
   }
-
-  public abstract void forEachSubClass(DexProgramClass clazz, Consumer<DexProgramClass> consumer);
-
-  public abstract boolean isRoot(DexProgramClass clazz);
-
-  public abstract void visit(DexProgramClass clazz);
-
-  public abstract void prune(DexProgramClass clazz);
-
-  public void run(Set<DexProgramClass> stronglyConnectedComponent) {
-    // Perform a top-down traversal from each root in the strongly connected component.
-    Deque<DexProgramClass> roots = computeRoots(stronglyConnectedComponent);
-    while (!roots.isEmpty()) {
-      DexProgramClass root = roots.removeLast();
-      traverse(root);
-      prioritizeNewlySeenButNotFinishedRoots(roots);
-    }
-  }
-
-  private Deque<DexProgramClass> computeRoots(Set<DexProgramClass> stronglyConnectedComponent) {
-    Deque<DexProgramClass> roots = new ArrayDeque<>();
-    for (DexProgramClass clazz : stronglyConnectedComponent) {
-      if (isRoot(clazz)) {
-        roots.add(clazz);
-      }
-    }
-    return roots;
-  }
-
-  private void prioritizeNewlySeenButNotFinishedRoots(Deque<DexProgramClass> roots) {
-    assert newlySeenButNotFinishedRoots.stream()
-        .allMatch(
-            newlySeenButNotFinishedRoot -> {
-              assert newlySeenButNotFinishedRoot.isInterface();
-              assert isRoot(newlySeenButNotFinishedRoot);
-              assert isClassSeenButNotFinished(newlySeenButNotFinishedRoot);
-              return true;
-            });
-    // Prioritize this interface over other not yet seen interfaces. This leads to more efficient
-    // state pruning.
-    roots.addAll(newlySeenButNotFinishedRoots);
-    newlySeenButNotFinishedRoots.clear();
-  }
-
-  private void traverse(DexProgramClass clazz) {
-    // Check it the class and all of its subtypes are already processed.
-    if (isClassFinished(clazz)) {
-      return;
-    }
-
-    // Before continuing the top-down traversal, ensure that all super interfaces are processed,
-    // but without visiting the entire subtree of each super interface.
-    if (!isClassSeenButNotFinished(clazz)) {
-      processImplementedInterfaces(clazz);
-      processClass(clazz);
-    }
-
-    processSubclasses(clazz);
-    markFinished(clazz);
-  }
-
-  private void processImplementedInterfaces(DexProgramClass interfaceDefinition) {
-    assert !isClassSeenButNotFinished(interfaceDefinition);
-    assert !isClassFinished(interfaceDefinition);
-    for (DexType implementedType : interfaceDefinition.getInterfaces()) {
-      DexProgramClass implementedDefinition =
-          asProgramClassOrNull(appView.definitionFor(implementedType));
-      if (implementedDefinition == null || isClassSeenButNotFinished(implementedDefinition)) {
-        continue;
-      }
-      assert isClassUnseen(implementedDefinition);
-      processImplementedInterfaces(implementedDefinition);
-      processClass(implementedDefinition);
-
-      // If this is a root, then record that this root is seen but not finished.
-      if (isRoot(implementedDefinition)) {
-        newlySeenButNotFinishedRoots.add(implementedDefinition);
-      }
-    }
-  }
-
-  private void processSubclasses(DexProgramClass clazz) {
-    forEachSubClass(clazz, this::traverse);
-  }
-
-  private void processClass(DexProgramClass interfaceDefinition) {
-    assert !isClassSeenButNotFinished(interfaceDefinition);
-    assert !isClassFinished(interfaceDefinition);
-    visit(interfaceDefinition);
-    markSeenButNotFinished(interfaceDefinition);
-  }
-
-  boolean isClassUnseen(DexProgramClass clazz) {
-    return !states.containsKey(clazz);
-  }
-
-  boolean isClassSeenButNotFinished(DexProgramClass clazz) {
-    return states.get(clazz) == TraversalState.SEEN;
-  }
-
-  boolean isClassFinished(DexProgramClass clazz) {
-    return states.get(clazz) == TraversalState.FINISHED;
-  }
-
-  private void markSeenButNotFinished(DexProgramClass clazz) {
-    assert isClassUnseen(clazz);
-    states.put(clazz, TraversalState.SEEN);
-  }
-
-  private void markFinished(DexProgramClass clazz) {
-    assert isClassSeenButNotFinished(clazz);
-    states.put(clazz, TraversalState.FINISHED);
-    prune(clazz);
-  }
 }
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/VirtualDispatchMethodArgumentPropagator.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/VirtualDispatchMethodArgumentPropagator.java
index db8a85d..eafcd99 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/VirtualDispatchMethodArgumentPropagator.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/VirtualDispatchMethodArgumentPropagator.java
@@ -4,7 +4,7 @@
 
 package com.android.tools.r8.optimize.argumentpropagation.propagation;
 
-import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
+import static com.android.tools.r8.ir.analysis.type.Nullability.maybeNull;
 import static com.android.tools.r8.utils.MapUtils.ignoreKey;
 
 import com.android.tools.r8.graph.AppView;
@@ -14,16 +14,16 @@
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.ir.analysis.type.ClassTypeElement;
 import com.android.tools.r8.ir.analysis.type.DynamicType;
+import com.android.tools.r8.ir.analysis.type.TypeElement;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteMethodState;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcretePolymorphicMethodState;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodState;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodStateCollection;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.IdentityHashMap;
 import java.util.Map;
-import java.util.Set;
-import java.util.function.Consumer;
 
 public class VirtualDispatchMethodArgumentPropagator extends MethodArgumentPropagator {
 
@@ -54,6 +54,9 @@
     //  memory usage, but would require visiting all transitive (program) super classes for each
     //  subclass.
     private void addParentState(DexProgramClass clazz, DexProgramClass superclass) {
+      ClassTypeElement classType =
+          TypeElement.fromDexType(clazz.getType(), maybeNull(), appView).asClassType();
+
       PropagationState parentState = propagationStates.get(superclass.asProgramClass());
       assert parentState != null;
 
@@ -78,7 +81,7 @@
       parentState.inactiveUntilUpperBound.forEach(
           (bounds, inactiveMethodState) -> {
             ClassTypeElement upperBound = bounds.getDynamicUpperBoundType().asClassType();
-            if (upperBound.getClassType() == clazz.getType()) {
+            if (upperBound.equalUpToNullability(classType)) {
               // The upper bound is the current class, thus this inactive information now becomes
               // active.
               if (bounds.hasDynamicLowerBoundType()) {
@@ -128,34 +131,13 @@
   }
 
   @Override
-  public void run(Set<DexProgramClass> stronglyConnectedComponent) {
+  public void run(Collection<DexProgramClass> stronglyConnectedComponent) {
     super.run(stronglyConnectedComponent);
     assert verifyAllClassesFinished(stronglyConnectedComponent);
     assert verifyStatePruned();
   }
 
   @Override
-  public void forEachSubClass(DexProgramClass clazz, Consumer<DexProgramClass> consumer) {
-    immediateSubtypingInfo.getSubclasses(clazz).forEach(consumer);
-  }
-
-  @Override
-  public boolean isRoot(DexProgramClass clazz) {
-    DexProgramClass superclass = asProgramClassOrNull(appView.definitionFor(clazz.getSuperType()));
-    if (superclass != null) {
-      return false;
-    }
-    for (DexType implementedType : clazz.getInterfaces()) {
-      DexProgramClass implementedClass =
-          asProgramClassOrNull(appView.definitionFor(implementedType));
-      if (implementedClass != null) {
-        return false;
-      }
-    }
-    return true;
-  }
-
-  @Override
   public void visit(DexProgramClass clazz) {
     assert !propagationStates.containsKey(clazz);
     PropagationState propagationState = computePropagationState(clazz);
@@ -163,6 +145,8 @@
   }
 
   private PropagationState computePropagationState(DexProgramClass clazz) {
+    ClassTypeElement classType =
+        TypeElement.fromDexType(clazz.getType(), maybeNull(), appView).asClassType();
     PropagationState propagationState = new PropagationState(clazz);
 
     // Join the argument information from the methods of the current class.
@@ -199,7 +183,7 @@
                   // TODO(b/190154391): Verify that the bounds are not trivial according to the
                   //  static receiver type.
                   ClassTypeElement upperBound = bounds.getDynamicUpperBoundType().asClassType();
-                  if (upperBound.getClassType() == clazz.getType()) {
+                  if (upperBound.equalUpToNullability(classType)) {
                     if (bounds.hasDynamicLowerBoundType()) {
                       // TODO(b/190154391): Verify that the lower bound is a subtype of the current
                       //  class.
@@ -214,7 +198,7 @@
                           appView, method.getReference(), methodStateForBounds);
                     }
                   } else {
-                    assert !appView.appInfo().isSubtype(clazz.getType(), upperBound.getClassType());
+                    assert !classType.lessThanOrEqualUpToNullability(upperBound, appView);
                     propagationState
                         .inactiveUntilUpperBound
                         .computeIfAbsent(bounds, ignoreKey(MethodStateCollection::create))
@@ -233,6 +217,11 @@
   }
 
   private void computeFinalMethodState(ProgramMethod method, PropagationState propagationState) {
+    if (!method.getDefinition().hasCode()) {
+      methodStates.remove(method);
+      return;
+    }
+
     MethodState methodState = methodStates.get(method);
 
     // If this is a polymorphic method, we need to compute the method state to account for dynamic
@@ -249,7 +238,7 @@
     propagationStates.remove(clazz);
   }
 
-  private boolean verifyAllClassesFinished(Set<DexProgramClass> stronglyConnectedComponent) {
+  private boolean verifyAllClassesFinished(Collection<DexProgramClass> stronglyConnectedComponent) {
     for (DexProgramClass clazz : stronglyConnectedComponent) {
       assert isClassFinished(clazz);
     }
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/utils/DepthFirstTopDownClassHierarchyTraversal.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/utils/DepthFirstTopDownClassHierarchyTraversal.java
new file mode 100644
index 0000000..379efaf
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/utils/DepthFirstTopDownClassHierarchyTraversal.java
@@ -0,0 +1,190 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.optimize.argumentpropagation.utils;
+
+import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Deque;
+import java.util.IdentityHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Consumer;
+
+public abstract class DepthFirstTopDownClassHierarchyTraversal {
+
+  // The state of a given class in the top-down traversal.
+  private enum TraversalState {
+    // Represents that a given class and all of its direct and indirect supertypes have been
+    // visited by the top-down traversal, but all of the direct and indirect subtypes are still
+    // not visited.
+    SEEN,
+    // Represents that a given class and all of its direct and indirect subtypes have been visited.
+    // Such nodes will never be seen again in the top-down traversal, and any state stored for
+    // such nodes can be pruned.
+    FINISHED
+  }
+
+  protected final AppView<AppInfoWithLiveness> appView;
+  protected final ImmediateProgramSubtypingInfo immediateSubtypingInfo;
+
+  // Contains the traversal state for each class. If a given class is not in the map the class is
+  // not yet seen.
+  private final Map<DexProgramClass, TraversalState> states = new IdentityHashMap<>();
+
+  // The class hierarchy is not a tree, thus for completeness we need to process all parent
+  // interfaces for a given class or interface before continuing the top-down traversal. When the
+  // top-down traversal for a given root returns, this means that there may be interfaces that are
+  // seen but not finished. These interfaces are added to this collection such that we can
+  // prioritize them over classes or interfaces that are yet not seen. This leads to more efficient
+  // state pruning, since the state for these interfaces can be pruned when they transition to being
+  // finished.
+  //
+  // See also prioritizeNewlySeenButNotFinishedRoots().
+  private final List<DexProgramClass> newlySeenButNotFinishedRoots = new ArrayList<>();
+
+  public DepthFirstTopDownClassHierarchyTraversal(
+      AppView<AppInfoWithLiveness> appView, ImmediateProgramSubtypingInfo immediateSubtypingInfo) {
+    this.appView = appView;
+    this.immediateSubtypingInfo = immediateSubtypingInfo;
+  }
+
+  public abstract void visit(DexProgramClass clazz);
+
+  public abstract void prune(DexProgramClass clazz);
+
+  public void run(Collection<DexProgramClass> stronglyConnectedComponent) {
+    // Perform a top-down traversal from each root in the strongly connected component.
+    Deque<DexProgramClass> roots = computeRoots(stronglyConnectedComponent);
+    while (!roots.isEmpty()) {
+      DexProgramClass root = roots.removeLast();
+      traverse(root);
+      prioritizeNewlySeenButNotFinishedRoots(roots);
+    }
+  }
+
+  private Deque<DexProgramClass> computeRoots(
+      Collection<DexProgramClass> stronglyConnectedComponent) {
+    Deque<DexProgramClass> roots = new ArrayDeque<>();
+    for (DexProgramClass clazz : stronglyConnectedComponent) {
+      if (isRoot(clazz)) {
+        roots.add(clazz);
+      }
+    }
+    return roots;
+  }
+
+  public boolean isRoot(DexProgramClass clazz) {
+    DexProgramClass superclass = asProgramClassOrNull(appView.definitionFor(clazz.getSuperType()));
+    if (superclass != null) {
+      return false;
+    }
+    for (DexType implementedType : clazz.getInterfaces()) {
+      DexProgramClass implementedClass =
+          asProgramClassOrNull(appView.definitionFor(implementedType));
+      if (implementedClass != null) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  private void prioritizeNewlySeenButNotFinishedRoots(Deque<DexProgramClass> roots) {
+    assert newlySeenButNotFinishedRoots.stream()
+        .allMatch(
+            newlySeenButNotFinishedRoot -> {
+              assert newlySeenButNotFinishedRoot.isInterface();
+              assert isRoot(newlySeenButNotFinishedRoot);
+              assert isClassSeenButNotFinished(newlySeenButNotFinishedRoot);
+              return true;
+            });
+    // Prioritize this interface over other not yet seen interfaces. This leads to more efficient
+    // state pruning.
+    roots.addAll(newlySeenButNotFinishedRoots);
+    newlySeenButNotFinishedRoots.clear();
+  }
+
+  private void traverse(DexProgramClass clazz) {
+    // Check it the class and all of its subtypes are already processed.
+    if (isClassFinished(clazz)) {
+      return;
+    }
+
+    // Before continuing the top-down traversal, ensure that all super interfaces are processed,
+    // but without visiting the entire subtree of each super interface.
+    if (!isClassSeenButNotFinished(clazz)) {
+      processImplementedInterfaces(clazz);
+      processClass(clazz);
+    }
+
+    processSubclasses(clazz);
+    markFinished(clazz);
+  }
+
+  private void processImplementedInterfaces(DexProgramClass interfaceDefinition) {
+    assert !isClassSeenButNotFinished(interfaceDefinition);
+    assert !isClassFinished(interfaceDefinition);
+    for (DexType implementedType : interfaceDefinition.getInterfaces()) {
+      DexProgramClass implementedDefinition =
+          asProgramClassOrNull(appView.definitionFor(implementedType));
+      if (implementedDefinition == null || isClassSeenButNotFinished(implementedDefinition)) {
+        continue;
+      }
+      assert isClassUnseen(implementedDefinition);
+      processImplementedInterfaces(implementedDefinition);
+      processClass(implementedDefinition);
+
+      // If this is a root, then record that this root is seen but not finished.
+      if (isRoot(implementedDefinition)) {
+        newlySeenButNotFinishedRoots.add(implementedDefinition);
+      }
+    }
+  }
+
+  private void processSubclasses(DexProgramClass clazz) {
+    forEachSubClass(clazz, this::traverse);
+  }
+
+  public void forEachSubClass(DexProgramClass clazz, Consumer<DexProgramClass> consumer) {
+    immediateSubtypingInfo.getSubclasses(clazz).forEach(consumer);
+  }
+
+  private void processClass(DexProgramClass interfaceDefinition) {
+    assert !isClassSeenButNotFinished(interfaceDefinition);
+    assert !isClassFinished(interfaceDefinition);
+    visit(interfaceDefinition);
+    markSeenButNotFinished(interfaceDefinition);
+  }
+
+  protected boolean isClassUnseen(DexProgramClass clazz) {
+    return !states.containsKey(clazz);
+  }
+
+  protected boolean isClassSeenButNotFinished(DexProgramClass clazz) {
+    return states.get(clazz) == TraversalState.SEEN;
+  }
+
+  protected boolean isClassFinished(DexProgramClass clazz) {
+    return states.get(clazz) == TraversalState.FINISHED;
+  }
+
+  private void markSeenButNotFinished(DexProgramClass clazz) {
+    assert isClassUnseen(clazz);
+    states.put(clazz, TraversalState.SEEN);
+  }
+
+  private void markFinished(DexProgramClass clazz) {
+    assert isClassSeenButNotFinished(clazz);
+    states.put(clazz, TraversalState.FINISHED);
+    prune(clazz);
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
index de7c43d..aef58b7 100644
--- a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
+++ b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
@@ -603,6 +603,10 @@
     return methodsTargetedByInvokeDynamic.contains(method);
   }
 
+  public boolean isMethodTargetedByInvokeDynamic(ProgramMethod method) {
+    return isMethodTargetedByInvokeDynamic(method.getReference());
+  }
+
   public Set<DexMethod> getVirtualMethodsTargetedByInvokeDirect() {
     return virtualMethodsTargetedByInvokeDirect;
   }
diff --git a/src/main/java/com/android/tools/r8/shaking/KeepMethodInfo.java b/src/main/java/com/android/tools/r8/shaking/KeepMethodInfo.java
index 1669a40..dd4226c 100644
--- a/src/main/java/com/android/tools/r8/shaking/KeepMethodInfo.java
+++ b/src/main/java/com/android/tools/r8/shaking/KeepMethodInfo.java
@@ -35,6 +35,10 @@
     return new Builder(this);
   }
 
+  public boolean isArgumentPropagationAllowed(GlobalKeepInfoConfiguration configuration) {
+    return isOptimizationAllowed(configuration);
+  }
+
   public Joiner joiner() {
     assert !isTop();
     return new Joiner(this);
diff --git a/src/main/java/com/android/tools/r8/utils/collections/DexMethodSignatureSet.java b/src/main/java/com/android/tools/r8/utils/collections/DexMethodSignatureSet.java
index 8ef5506..5122253 100644
--- a/src/main/java/com/android/tools/r8/utils/collections/DexMethodSignatureSet.java
+++ b/src/main/java/com/android/tools/r8/utils/collections/DexMethodSignatureSet.java
@@ -94,6 +94,10 @@
     return backing.contains(signature);
   }
 
+  public boolean contains(DexClassAndMethod method) {
+    return contains(method.getMethodSignature());
+  }
+
   @Override
   public boolean containsAll(Collection<?> collection) {
     return backing.containsAll(collection);
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationLibraryLambdaPropagationTest.java b/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationLibraryLambdaPropagationTest.java
index 774267d..d53a470 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationLibraryLambdaPropagationTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationLibraryLambdaPropagationTest.java
@@ -8,8 +8,9 @@
 import com.android.tools.r8.NeverInline;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
-import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.utils.AndroidApiLevel;
+import com.android.tools.r8.utils.BooleanUtils;
+import java.util.List;
 import java.util.function.Consumer;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -19,18 +20,23 @@
 @RunWith(Parameterized.class)
 public class CallSiteOptimizationLibraryLambdaPropagationTest extends TestBase {
 
+  private final boolean enableExperimentalArgumentPropagation;
   private final TestParameters parameters;
 
-  @Parameters(name = "{0}")
-  public static TestParametersCollection data() {
-    return getTestParameters()
-        .withCfRuntimes()
-        .withDexRuntimes()
-        .withApiLevelsStartingAtIncluding(AndroidApiLevel.N)
-        .build();
+  @Parameters(name = "{1}, experimental: {0}")
+  public static List<Object[]> data() {
+    return buildParameters(
+        BooleanUtils.values(),
+        getTestParameters()
+            .withCfRuntimes()
+            .withDexRuntimes()
+            .withApiLevelsStartingAtIncluding(AndroidApiLevel.N)
+            .build());
   }
 
-  public CallSiteOptimizationLibraryLambdaPropagationTest(TestParameters parameters) {
+  public CallSiteOptimizationLibraryLambdaPropagationTest(
+      boolean enableExperimentalArgumentPropagation, TestParameters parameters) {
+    this.enableExperimentalArgumentPropagation = enableExperimentalArgumentPropagation;
     this.parameters = parameters;
   }
 
@@ -39,6 +45,14 @@
     testForR8(parameters.getBackend())
         .addInnerClasses(CallSiteOptimizationLibraryLambdaPropagationTest.class)
         .addKeepMainRule(TestClass.class)
+        .applyIf(
+            enableExperimentalArgumentPropagation,
+            builder ->
+                builder.addOptionsModification(
+                    options ->
+                        options
+                            .callSiteOptimizationOptions()
+                            .setEnableExperimentalArgumentPropagation()))
         .enableInliningAnnotations()
         .enableNeverClassInliningAnnotations()
         .setMinApi(parameters.getApiLevel())
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationPinnedMethodOverridePropagationTest.java b/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationPinnedMethodOverridePropagationTest.java
index 9e42eb6..2bf8f8d 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationPinnedMethodOverridePropagationTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationPinnedMethodOverridePropagationTest.java
@@ -11,9 +11,9 @@
 import com.android.tools.r8.R8TestCompileResult;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
-import com.android.tools.r8.TestParametersCollection;
-import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.BooleanUtils;
 import com.google.common.collect.ImmutableList;
+import java.util.List;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -24,14 +24,19 @@
 
   private static final String CLASS_PREFIX =
       "com.android.tools.r8.ir.optimize.callsites.CallSiteOptimizationPinnedMethodOverridePropagationTest$";
+
+  private final boolean enableExperimentalArgumentPropagation;
   private final TestParameters parameters;
 
-  @Parameters(name = "{0}")
-  public static TestParametersCollection data() {
-    return getTestParameters().withDexRuntimes().withAllApiLevels().build();
+  @Parameters(name = "{1}, experimental: {0}")
+  public static List<Object[]> data() {
+    return buildParameters(
+        BooleanUtils.values(), getTestParameters().withDexRuntimes().withAllApiLevels().build());
   }
 
-  public CallSiteOptimizationPinnedMethodOverridePropagationTest(TestParameters parameters) {
+  public CallSiteOptimizationPinnedMethodOverridePropagationTest(
+      boolean enableExperimentalArgumentPropagation, TestParameters parameters) {
+    this.enableExperimentalArgumentPropagation = enableExperimentalArgumentPropagation;
     this.parameters = parameters;
   }
 
@@ -58,13 +63,20 @@
                         + "Arg getArg2(); \npublic static "
                         + CLASS_PREFIX
                         + "Call getCaller(); \n}"))
+            .applyIf(
+                enableExperimentalArgumentPropagation,
+                builder ->
+                    builder.addOptionsModification(
+                        options ->
+                            options
+                                .callSiteOptimizationOptions()
+                                .setEnableExperimentalArgumentPropagation()))
             .enableNoVerticalClassMergingAnnotations()
             .enableNoHorizontalClassMergingAnnotations()
             .enableInliningAnnotations()
             .enableMemberValuePropagationAnnotations()
             .setMinApi(parameters.getApiLevel())
             .compile();
-    CodeInspector inspector = compiled.inspector();
     compiled.run(parameters.getRuntime(), Main2.class).assertSuccessWithOutputLines("Arg1");
     testForD8()
         .addProgramClasses(Main.class)
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationProgramLambdaPropagationTest.java b/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationProgramLambdaPropagationTest.java
index 86145c6..6d8e9d4 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationProgramLambdaPropagationTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationProgramLambdaPropagationTest.java
@@ -8,7 +8,8 @@
 import com.android.tools.r8.NeverInline;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
-import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.BooleanUtils;
+import java.util.List;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -17,14 +18,18 @@
 @RunWith(Parameterized.class)
 public class CallSiteOptimizationProgramLambdaPropagationTest extends TestBase {
 
+  private final boolean enableExperimentalArgumentPropagation;
   private final TestParameters parameters;
 
-  @Parameters(name = "{0}")
-  public static TestParametersCollection data() {
-    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  @Parameters(name = "{1}, experimental: {0}")
+  public static List<Object[]> data() {
+    return buildParameters(
+        BooleanUtils.values(), getTestParameters().withAllRuntimesAndApiLevels().build());
   }
 
-  public CallSiteOptimizationProgramLambdaPropagationTest(TestParameters parameters) {
+  public CallSiteOptimizationProgramLambdaPropagationTest(
+      boolean enableExperimentalArgumentPropagation, TestParameters parameters) {
+    this.enableExperimentalArgumentPropagation = enableExperimentalArgumentPropagation;
     this.parameters = parameters;
   }
 
@@ -33,6 +38,14 @@
     testForR8(parameters.getBackend())
         .addInnerClasses(CallSiteOptimizationProgramLambdaPropagationTest.class)
         .addKeepMainRule(TestClass.class)
+        .applyIf(
+            enableExperimentalArgumentPropagation,
+            builder ->
+                builder.addOptionsModification(
+                    options ->
+                        options
+                            .callSiteOptimizationOptions()
+                            .setEnableExperimentalArgumentPropagation()))
         .enableInliningAnnotations()
         .enableNeverClassInliningAnnotations()
         .setMinApi(parameters.getApiLevel())
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationWithInvokeCustomTargetTest.java b/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationWithInvokeCustomTargetTest.java
index 1a2d667..d0d2d22 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationWithInvokeCustomTargetTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/callsites/CallSiteOptimizationWithInvokeCustomTargetTest.java
@@ -12,7 +12,7 @@
 import com.android.tools.r8.NeverInline;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
-import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.BooleanUtils;
 import com.android.tools.r8.utils.StringUtils;
 import com.android.tools.r8.utils.codeinspector.ClassSubject;
 import com.google.common.collect.ImmutableList;
@@ -33,18 +33,23 @@
 
   private static final String EXPECTED = StringUtils.lines("Hello world!");
 
+  private final boolean enableExperimentalArgumentPropagation;
   private final TestParameters parameters;
 
-  @Parameters(name = "{0}")
-  public static TestParametersCollection data() {
-    return getTestParameters()
-        .withAllRuntimes()
-        // Only works when invoke-custom/dynamic are supported and ConstantCallSite defined.
-        .withApiLevelsStartingAtIncluding(apiLevelWithInvokeCustomSupport())
-        .build();
+  @Parameters(name = "{1}, experimental: {0}")
+  public static List<Object[]> data() {
+    return buildParameters(
+        BooleanUtils.values(),
+        getTestParameters()
+            .withAllRuntimes()
+            // Only works when invoke-custom/dynamic are supported and ConstantCallSite defined.
+            .withApiLevelsStartingAtIncluding(apiLevelWithInvokeCustomSupport())
+            .build());
   }
 
-  public CallSiteOptimizationWithInvokeCustomTargetTest(TestParameters parameters) {
+  public CallSiteOptimizationWithInvokeCustomTargetTest(
+      boolean enableExperimentalArgumentPropagation, TestParameters parameters) {
+    this.enableExperimentalArgumentPropagation = enableExperimentalArgumentPropagation;
     this.parameters = parameters;
   }
 
@@ -61,9 +66,17 @@
     testForR8(parameters.getBackend())
         .addProgramClassFileData(getProgramClassFileData())
         .addKeepMainRule(TestClass.class)
-        .setMinApi(parameters.getApiLevel())
         .addKeepMethodRules(methodFromMethod(TestClass.class.getDeclaredMethod("bar", int.class)))
+        .applyIf(
+            enableExperimentalArgumentPropagation,
+            builder ->
+                builder.addOptionsModification(
+                    options ->
+                        options
+                            .callSiteOptimizationOptions()
+                            .setEnableExperimentalArgumentPropagation()))
         .enableInliningAnnotations()
+        .setMinApi(parameters.getApiLevel())
         .run(parameters.getRuntime(), TestClass.class)
         .assertSuccessWithOutput(EXPECTED)
         .inspect(
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/callsites/KeptMethodTest.java b/src/test/java/com/android/tools/r8/ir/optimize/callsites/KeptMethodTest.java
index 8d405c4..b579686 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/callsites/KeptMethodTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/callsites/KeptMethodTest.java
@@ -11,29 +11,34 @@
 import com.android.tools.r8.NeverInline;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
-import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.utils.BooleanUtils;
 import com.android.tools.r8.utils.codeinspector.ClassSubject;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
 import com.android.tools.r8.utils.codeinspector.InstructionSubject;
 import com.android.tools.r8.utils.codeinspector.MethodSubject;
 import java.lang.reflect.Method;
+import java.util.List;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
 
 @RunWith(Parameterized.class)
 public class KeptMethodTest extends TestBase {
   private static final Class<?> MAIN = Main.class;
 
-  @Parameterized.Parameters(name = "{0}")
-  public static TestParametersCollection data() {
-    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  @Parameters(name = "{1}, experimental: {0}")
+  public static List<Object[]> data() {
+    return buildParameters(
+        BooleanUtils.values(), getTestParameters().withAllRuntimesAndApiLevels().build());
   }
 
+  private final boolean enableExperimentalArgumentPropagation;
   private final TestParameters parameters;
 
-  public KeptMethodTest(TestParameters parameters) {
+  public KeptMethodTest(boolean enableExperimentalArgumentPropagation, TestParameters parameters) {
+    this.enableExperimentalArgumentPropagation = enableExperimentalArgumentPropagation;
     this.parameters = parameters;
   }
 
@@ -43,12 +48,19 @@
         .addInnerClasses(KeptMethodTest.class)
         .addKeepMainRule(MAIN)
         .addKeepClassAndMembersRules(A.class)
+        .applyIf(
+            enableExperimentalArgumentPropagation,
+            builder ->
+                builder.addOptionsModification(
+                    options ->
+                        options
+                            .callSiteOptimizationOptions()
+                            .setEnableExperimentalArgumentPropagation()))
         .enableNeverClassInliningAnnotations()
         .enableInliningAnnotations()
         .addOptionsModification(
-            o -> {
-              o.testing.callSiteOptimizationInfoInspector = this::callSiteOptimizationInfoInspect;
-            })
+            o ->
+                o.testing.callSiteOptimizationInfoInspector = this::callSiteOptimizationInfoInspect)
         .setMinApi(parameters.getApiLevel())
         .run(parameters.getRuntime(), MAIN)
         .assertSuccessWithOutputLines("non-null", "non-null")
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/callsites/LibraryMethodOverridesTest.java b/src/test/java/com/android/tools/r8/ir/optimize/callsites/LibraryMethodOverridesTest.java
index 8a057dd..ea7641f 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/callsites/LibraryMethodOverridesTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/callsites/LibraryMethodOverridesTest.java
@@ -11,13 +11,14 @@
 import com.android.tools.r8.R8TestCompileResult;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
-import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.ToolHelper.DexVm.Version;
 import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.utils.BooleanUtils;
 import com.android.tools.r8.utils.codeinspector.ClassSubject;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
 import com.android.tools.r8.utils.codeinspector.InstructionSubject;
 import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import java.util.List;
 import java.util.function.Predicate;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -27,18 +28,23 @@
 public class LibraryMethodOverridesTest extends TestBase {
   private static final Class<?> MAIN = TestClass.class;
 
-  @Parameterized.Parameters(name = "{0}")
-  public static TestParametersCollection data() {
-    return getTestParameters()
-        .withCfRuntimes()
-        // java.util.function.Predicate is not available prior to API level 24 (V7.0).
-        .withDexRuntimesStartingFromIncluding(Version.V7_0_0)
-        .build();
+  @Parameterized.Parameters(name = "{1}, experimental: {0}")
+  public static List<Object[]> data() {
+    return buildParameters(
+        BooleanUtils.values(),
+        getTestParameters()
+            .withCfRuntimes()
+            // java.util.function.Predicate is not available prior to API level 24 (V7.0).
+            .withDexRuntimesStartingFromIncluding(Version.V7_0_0)
+            .build());
   }
 
+  private final boolean enableExperimentalArgumentPropagation;
   private final TestParameters parameters;
 
-  public LibraryMethodOverridesTest(TestParameters parameters) {
+  public LibraryMethodOverridesTest(
+      boolean enableExperimentalArgumentPropagation, TestParameters parameters) {
+    this.enableExperimentalArgumentPropagation = enableExperimentalArgumentPropagation;
     this.parameters = parameters;
   }
 
@@ -54,10 +60,18 @@
         .addProgramClasses(TestClass.class, CustomPredicate.class)
         .addClasspathClasses(LibClass.class)
         .addKeepMainRule(MAIN)
+        .addOptionsModification(
+            o ->
+                o.testing.callSiteOptimizationInfoInspector = this::callSiteOptimizationInfoInspect)
+        .applyIf(
+            enableExperimentalArgumentPropagation,
+            builder ->
+                builder.addOptionsModification(
+                    options ->
+                        options
+                            .callSiteOptimizationOptions()
+                            .setEnableExperimentalArgumentPropagation()))
         .enableInliningAnnotations()
-        .addOptionsModification(o -> {
-          o.testing.callSiteOptimizationInfoInspector = this::callSiteOptimizationInfoInspect;
-        })
         .setMinApi(parameters.getRuntime())
         .compile()
         .addRunClasspathFiles(libraryCompileResult.writeToZip())