Add support for abstract functions with multiple inputs

This is used in the InstanceFieldReadAbstractFunction to fallback to the context insensitive field value.

Bug: b/296030319
Change-Id: I1c85c4ed372397e9c5679050c03b56f24b173e5a
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/AbstractFunction.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/AbstractFunction.java
index 5d353f9..0893be8 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/AbstractFunction.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/AbstractFunction.java
@@ -14,20 +14,27 @@
   }
 
   /**
-   * Applies the current abstract function to the given {@param state}.
+   * Applies the current abstract function to its declared inputs (in {@link #getBaseInFlow()}).
    *
    * <p>It is guaranteed by the caller that the given {@param state} is the abstract state for the
-   * field or parameter this function depends on, i.e., the node returned by {@link
-   * #getBaseInFlow()}.
+   * field or parameter that caused this function to be reevaluated. If this abstract function takes
+   * a single input, then {@param state} is guaranteed to be the state for the node returned by
+   * {@link #getBaseInFlow()}, and {@param flowGraphStateProvider} should never be used.
+   *
+   * <p>Abstract functions that depend on multiple inputs can lookup the state for each input in
+   * {@param flowGraphStateProvider}. Attempting to lookup the state of a non-declared input is an
+   * error.
    */
-  NonEmptyValueState apply(ConcreteValueState state);
+  ValueState apply(FlowGraphStateProvider flowGraphStateProvider, ConcreteValueState state);
+
+  /** Returns true if the given {@param inFlow} is a declared input of this abstract function. */
+  boolean containsBaseInFlow(BaseInFlow inFlow);
 
   /**
-   * Returns the (single) program field or parameter graph node that this function depends on. Upon
-   * any change to the abstract state of this graph node this abstract function must be
-   * re-evaluated.
+   * Returns the program field or parameter graph nodes that this function depends on. Upon any
+   * change to the abstract state of any of these nodes this abstract function must be re-evaluated.
    */
-  InFlow getBaseInFlow();
+  Iterable<BaseInFlow> getBaseInFlow();
 
   @Override
   default boolean isAbstractFunction() {
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/FieldValue.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/FieldValue.java
index b049381..f7bd0ec 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/FieldValue.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/FieldValue.java
@@ -25,6 +25,11 @@
   }
 
   @Override
+  public boolean isFieldValue(DexField field) {
+    return this.field.isIdenticalTo(field);
+  }
+
+  @Override
   public FieldValue asFieldValue() {
     return this;
   }
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/FlowGraphStateProvider.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/FlowGraphStateProvider.java
new file mode 100644
index 0000000..5fe2970
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/FlowGraphStateProvider.java
@@ -0,0 +1,58 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.optimize.argumentpropagation.codescanner;
+
+import com.android.tools.r8.errors.Unreachable;
+import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.optimize.argumentpropagation.propagation.InFlowPropagator.FlowGraph;
+import com.android.tools.r8.utils.InternalOptions;
+import java.util.function.Supplier;
+
+public interface FlowGraphStateProvider {
+
+  static FlowGraphStateProvider create(FlowGraph flowGraph, AbstractFunction abstractFunction) {
+    if (!InternalOptions.assertionsEnabled()) {
+      return flowGraph;
+    }
+    // If the abstract function is a canonical function, or the abstract function has a single
+    // declared input, we should never perform any state lookups.
+    if (abstractFunction.isIdentity()
+        || abstractFunction.isUnknownAbstractFunction()
+        || abstractFunction.isUpdateChangedFlagsAbstractFunction()) {
+      return new FlowGraphStateProvider() {
+
+        @Override
+        public ValueState getState(DexField field) {
+          throw new Unreachable();
+        }
+
+        @Override
+        public ValueState getState(BaseInFlow inFlow, Supplier<ValueState> defaultStateProvider) {
+          throw new Unreachable();
+        }
+      };
+    }
+    // Otherwise, restrict state lookups to the declared base in flow. This is required for arriving
+    // at the correct fix point.
+    assert abstractFunction.isInstanceFieldReadAbstractFunction();
+    return new FlowGraphStateProvider() {
+
+      @Override
+      public ValueState getState(DexField field) {
+        assert abstractFunction.containsBaseInFlow(new FieldValue(field));
+        return flowGraph.getState(field);
+      }
+
+      @Override
+      public ValueState getState(BaseInFlow inFlow, Supplier<ValueState> defaultStateProvider) {
+        assert abstractFunction.containsBaseInFlow(inFlow);
+        return flowGraph.getState(inFlow, defaultStateProvider);
+      }
+    };
+  }
+
+  ValueState getState(DexField field);
+
+  ValueState getState(BaseInFlow inFlow, Supplier<ValueState> defaultStateProvider);
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/IdentityAbstractFunction.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/IdentityAbstractFunction.java
index de16de1..e3818a9 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/IdentityAbstractFunction.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/IdentityAbstractFunction.java
@@ -16,12 +16,17 @@
   }
 
   @Override
-  public NonEmptyValueState apply(ConcreteValueState state) {
+  public ValueState apply(FlowGraphStateProvider flowGraphStateProvider, ConcreteValueState state) {
     return state;
   }
 
   @Override
-  public InFlow getBaseInFlow() {
+  public boolean containsBaseInFlow(BaseInFlow inFlow) {
+    throw new Unreachable();
+  }
+
+  @Override
+  public Iterable<BaseInFlow> getBaseInFlow() {
     throw new Unreachable();
   }
 
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/InFlow.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/InFlow.java
index a305998..1f28543 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/InFlow.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/InFlow.java
@@ -3,6 +3,7 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.optimize.argumentpropagation.codescanner;
 
+import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.optimize.compose.UpdateChangedFlagsAbstractFunction;
 
 public interface InFlow {
@@ -27,6 +28,10 @@
     return false;
   }
 
+  default boolean isFieldValue(DexField field) {
+    return false;
+  }
+
   default FieldValue asFieldValue() {
     return null;
   }
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/InstanceFieldReadAbstractFunction.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/InstanceFieldReadAbstractFunction.java
index 9a8cf97..59a3a7f 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/InstanceFieldReadAbstractFunction.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/InstanceFieldReadAbstractFunction.java
@@ -5,6 +5,7 @@
 
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
+import com.google.common.collect.Lists;
 
 public class InstanceFieldReadAbstractFunction implements AbstractFunction {
 
@@ -16,35 +17,45 @@
     this.field = field;
   }
 
-  // TODO(b/296030319): Instead of returning unknown from here, we should fallback to the state of
-  //  the instance field node in the graph. A prerequisite for this is the ability to express
-  //  multiple inputs to abstract functions.
   @Override
-  public NonEmptyValueState apply(ConcreteValueState state) {
+  public ValueState apply(
+      FlowGraphStateProvider flowGraphStateProvider, ConcreteValueState predecessorState) {
+    ValueState state = flowGraphStateProvider.getState(receiver, () -> ValueState.bottom(field));
+    if (state.isBottom()) {
+      return ValueState.bottom(field);
+    }
     if (!state.isClassState()) {
-      return ValueState.unknown();
+      return getFallbackState(flowGraphStateProvider);
     }
     ConcreteClassTypeValueState classState = state.asClassState();
     if (classState.getNullability().isDefinitelyNull()) {
-      // TODO(b/296030319): This should be rare, but we should really return bottom here, since
-      //  reading a field from the the null value throws an exception, meaning no flow should be
-      //  propagated.
-      return ValueState.unknown();
+      return ValueState.bottom(field);
     }
     AbstractValue abstractValue = state.getAbstractValue(null);
     if (!abstractValue.hasObjectState()) {
-      return ValueState.unknown();
+      return getFallbackState(flowGraphStateProvider);
     }
     AbstractValue fieldValue = abstractValue.getObjectState().getAbstractFieldValue(field);
     if (fieldValue.isUnknown()) {
-      return ValueState.unknown();
+      return getFallbackState(flowGraphStateProvider);
     }
     return ConcreteValueState.create(field.getType(), fieldValue);
   }
 
   @Override
-  public InFlow getBaseInFlow() {
-    return receiver;
+  public boolean containsBaseInFlow(BaseInFlow inFlow) {
+    return inFlow.equals(receiver) || inFlow.isFieldValue(field);
+  }
+
+  @Override
+  public Iterable<BaseInFlow> getBaseInFlow() {
+    return Lists.newArrayList(receiver, new FieldValue(field));
+  }
+
+  private ValueState getFallbackState(FlowGraphStateProvider flowGraphStateProvider) {
+    ValueState valueState = flowGraphStateProvider.getState(new FieldValue(field), null);
+    assert !valueState.isConcrete() || !valueState.asConcrete().hasInFlow();
+    return valueState;
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/OrAbstractFunction.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/OrAbstractFunction.java
index e450a76..93ac577 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/OrAbstractFunction.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/OrAbstractFunction.java
@@ -3,6 +3,7 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.optimize.argumentpropagation.codescanner;
 
+import com.android.tools.r8.utils.IterableUtils;
 import java.util.Objects;
 
 /**
@@ -12,28 +13,36 @@
  */
 public class OrAbstractFunction implements AbstractFunction {
 
-  public final InFlow inFlow;
+  public final BaseInFlow inFlow;
   public final long constant;
 
-  public OrAbstractFunction(InFlow inFlow, long constant) {
+  public OrAbstractFunction(BaseInFlow inFlow, long constant) {
     this.inFlow = inFlow;
     this.constant = constant;
   }
 
   @Override
-  public NonEmptyValueState apply(ConcreteValueState state) {
+  public ValueState apply(FlowGraphStateProvider flowGraphStateProvider, ConcreteValueState state) {
     // TODO(b/302483644): Implement this abstract function to allow correct value propagation of
     //  updateChangedFlags(x | 1).
     return state;
   }
 
   @Override
-  public InFlow getBaseInFlow() {
+  public boolean containsBaseInFlow(BaseInFlow otherInFlow) {
+    if (inFlow.isAbstractFunction()) {
+      return inFlow.asAbstractFunction().containsBaseInFlow(otherInFlow);
+    }
+    assert inFlow.isBaseInFlow();
+    return inFlow.equals(otherInFlow);
+  }
+
+  @Override
+  public Iterable<BaseInFlow> getBaseInFlow() {
     if (inFlow.isAbstractFunction()) {
       return inFlow.asAbstractFunction().getBaseInFlow();
     }
-    assert inFlow.isFieldValue() || inFlow.isMethodParameter();
-    return inFlow;
+    return IterableUtils.singleton(inFlow);
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/UnknownAbstractFunction.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/UnknownAbstractFunction.java
index ccddf4d..7ddc238 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/UnknownAbstractFunction.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/UnknownAbstractFunction.java
@@ -16,12 +16,17 @@
   }
 
   @Override
-  public NonEmptyValueState apply(ConcreteValueState state) {
+  public ValueState apply(FlowGraphStateProvider flowGraphStateProvider, ConcreteValueState state) {
     return ValueState.unknown();
   }
 
   @Override
-  public InFlow getBaseInFlow() {
+  public boolean containsBaseInFlow(BaseInFlow inFlow) {
+    throw new Unreachable();
+  }
+
+  @Override
+  public Iterable<BaseInFlow> getBaseInFlow() {
     throw new Unreachable();
   }
 
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ValueState.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ValueState.java
index 4cb4c2b..7e91edf 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ValueState.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/ValueState.java
@@ -5,6 +5,7 @@
 package com.android.tools.r8.optimize.argumentpropagation.codescanner;
 
 import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ProgramField;
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
@@ -14,6 +15,10 @@
 public abstract class ValueState {
 
   public static BottomValueState bottom(ProgramField field) {
+    return bottom(field.getReference());
+  }
+
+  public static BottomValueState bottom(DexField field) {
     DexType fieldType = field.getType();
     if (fieldType.isArrayType()) {
       return bottomArrayTypeParameter();
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/DefaultFieldValueJoiner.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/DefaultFieldValueJoiner.java
index 542b17e..834f186 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/DefaultFieldValueJoiner.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/DefaultFieldValueJoiner.java
@@ -17,6 +17,7 @@
 import com.android.tools.r8.utils.IterableUtils;
 import com.android.tools.r8.utils.ListUtils;
 import com.android.tools.r8.utils.MapUtils;
+import com.android.tools.r8.utils.Pair;
 import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.collections.ProgramFieldSet;
 import com.google.common.collect.Lists;
@@ -42,7 +43,7 @@
     this.flowGraphs = flowGraphs;
   }
 
-  public Collection<Deque<Node>> joinDefaultFieldValuesForFieldsWithReadBeforeWrite(
+  public Map<FlowGraph, Deque<Node>> joinDefaultFieldValuesForFieldsWithReadBeforeWrite(
       ExecutorService executorService) throws ExecutionException {
     // Find all the fields where we need to determine if each field read is guaranteed to be
     // dominated by a write.
@@ -191,31 +192,34 @@
     }
   }
 
-  private Collection<Deque<Node>> updateFlowGraphs(
+  private Map<FlowGraph, Deque<Node>> updateFlowGraphs(
       ProgramFieldSet fieldsWithLiveDefaultValue, ExecutorService executorService)
       throws ExecutionException {
-    return ThreadUtils.processItemsWithResultsThatMatches(
-        flowGraphs,
-        flowGraph -> {
-          Deque<Node> worklist = new ArrayDeque<>();
-          flowGraph.forEachFieldNode(
-              node -> {
-                ProgramField field = node.getField();
-                if (fieldsWithLiveDefaultValue.contains(field)) {
-                  node.addDefaultValue(
-                      appView,
-                      () -> {
-                        if (node.isUnknown()) {
-                          node.clearPredecessors();
-                        }
-                        node.addToWorkList(worklist);
-                      });
-                }
-              });
-          return worklist;
-        },
-        worklist -> !worklist.isEmpty(),
-        appView.options().getThreadingModule(),
-        executorService);
+    Collection<Pair<FlowGraph, Deque<Node>>> worklists =
+        ThreadUtils.processItemsWithResultsThatMatches(
+            flowGraphs,
+            flowGraph -> {
+              Deque<Node> worklist = new ArrayDeque<>();
+              flowGraph.forEachFieldNode(
+                  node -> {
+                    ProgramField field = node.getField();
+                    if (fieldsWithLiveDefaultValue.contains(field)) {
+                      node.addDefaultValue(
+                          appView,
+                          () -> {
+                            if (node.isUnknown()) {
+                              node.clearPredecessors();
+                            }
+                            node.addToWorkList(worklist);
+                          });
+                    }
+                  });
+              return new Pair<>(flowGraph, worklist);
+            },
+            pair -> !pair.getSecond().isEmpty(),
+            appView.options().getThreadingModule(),
+            executorService);
+    return MapUtils.newIdentityHashMap(
+        builder -> worklists.forEach(pair -> builder.put(pair.getFirst(), pair.getSecond())));
   }
 }
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InFlowPropagator.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InFlowPropagator.java
index 27fcb12..5b1ccf2 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InFlowPropagator.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/propagation/InFlowPropagator.java
@@ -9,6 +9,7 @@
 
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexType;
@@ -21,6 +22,7 @@
 import com.android.tools.r8.ir.analysis.value.AbstractValueFactory;
 import com.android.tools.r8.ir.conversion.IRConverter;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.AbstractFunction;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.BaseInFlow;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteArrayTypeValueState;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteClassTypeValueState;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteMethodState;
@@ -29,6 +31,7 @@
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteValueState;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.FieldStateCollection;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.FieldValue;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.FlowGraphStateProvider;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.InFlow;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodParameter;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodState;
@@ -42,9 +45,9 @@
 import com.android.tools.r8.utils.ListUtils;
 import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.TraversalContinuation;
-import com.android.tools.r8.utils.collections.ProgramFieldMap;
 import com.google.common.collect.Sets;
 import it.unimi.dsi.fastutil.ints.Int2ReferenceMap;
+import it.unimi.dsi.fastutil.ints.Int2ReferenceMaps;
 import it.unimi.dsi.fastutil.ints.Int2ReferenceOpenHashMap;
 import java.util.ArrayDeque;
 import java.util.Collection;
@@ -59,6 +62,7 @@
 import java.util.function.BiConsumer;
 import java.util.function.BiPredicate;
 import java.util.function.Consumer;
+import java.util.function.Supplier;
 
 public class InFlowPropagator {
 
@@ -90,7 +94,7 @@
     // perform this analysis after having computed the initial fixpoint(s). The hypothesis is that
     // many fields will have reached the unknown state after the initial fixpoint, meaning there is
     // fewer fields to analyze.
-    Collection<Deque<Node>> worklists =
+    Map<FlowGraph, Deque<Node>> worklists =
         includeDefaultValuesInFieldStates(flowGraphs, executorService);
 
     // Since the inclusion of default values changes the flow graphs, we need to repeat the
@@ -111,7 +115,7 @@
     return ListUtils.map(stronglyConnectedComponents, FlowGraph::new);
   }
 
-  private Collection<Deque<Node>> includeDefaultValuesInFieldStates(
+  private Map<FlowGraph, Deque<Node>> includeDefaultValuesInFieldStates(
       List<FlowGraph> flowGraphs, ExecutorService executorService) throws ExecutionException {
     DefaultFieldValueJoiner joiner = new DefaultFieldValueJoiner(appView, flowGraphs);
     return joiner.joinDefaultFieldValuesForFieldsWithReadBeforeWrite(executorService);
@@ -123,9 +127,10 @@
         flowGraphs, this::process, appView.options().getThreadingModule(), executorService);
   }
 
-  private void processWorklists(Collection<Deque<Node>> worklists, ExecutorService executorService)
+  private void processWorklists(
+      Map<FlowGraph, Deque<Node>> worklists, ExecutorService executorService)
       throws ExecutionException {
-    ThreadUtils.processItems(
+    ThreadUtils.processMap(
         worklists, this::process, appView.options().getThreadingModule(), executorService);
   }
 
@@ -133,10 +138,10 @@
     // Build a worklist containing all the nodes.
     Deque<Node> worklist = new ArrayDeque<>();
     flowGraph.forEachNode(worklist::add);
-    process(worklist);
+    process(flowGraph, worklist);
   }
 
-  private void process(Deque<Node> worklist) {
+  private void process(FlowGraph flowGraph, Deque<Node> worklist) {
     // Repeatedly propagate argument information through edges in the flow graph until there are no
     // more changes.
     // TODO(b/190154391): Consider a path p1 -> p2 -> p3 in the graph. If we process p2 first, then
@@ -146,11 +151,11 @@
     while (!worklist.isEmpty()) {
       Node node = worklist.removeLast();
       node.unsetInWorklist();
-      propagate(node, worklist);
+      propagate(flowGraph, node, worklist);
     }
   }
 
-  private void propagate(Node node, Deque<Node> worklist) {
+  private void propagate(FlowGraph flowGraph, Node node, Deque<Node> worklist) {
     if (node.isBottom()) {
       return;
     }
@@ -164,18 +169,22 @@
       }
       node.clearDanglingSuccessors();
     } else {
-      propagateNode(node, worklist);
+      propagateNode(flowGraph, node, worklist);
     }
   }
 
-  private void propagateNode(Node node, Deque<Node> worklist) {
+  private void propagateNode(FlowGraph flowGraph, Node node, Deque<Node> worklist) {
     ConcreteValueState state = node.getState().asConcrete();
     node.removeSuccessorIf(
         (successorNode, transferFunctions) -> {
           assert !successorNode.isUnknown();
           for (AbstractFunction transferFunction : transferFunctions) {
-            NonEmptyValueState transferState = transferFunction.apply(state);
-            if (transferState.isUnknown()) {
+            FlowGraphStateProvider flowGraphStateProvider =
+                FlowGraphStateProvider.create(flowGraph, transferFunction);
+            ValueState transferState = transferFunction.apply(flowGraphStateProvider, state);
+            if (transferState.isBottom()) {
+              // Nothing to propagate.
+            } else if (transferState.isUnknown()) {
               successorNode.setStateToUnknown();
               successorNode.addToWorkList(worklist);
             } else {
@@ -222,9 +231,9 @@
     }
   }
 
-  public class FlowGraph extends BidirectedGraph<Node> {
+  public class FlowGraph extends BidirectedGraph<Node> implements FlowGraphStateProvider {
 
-    private final ProgramFieldMap<FieldNode> fieldNodes = ProgramFieldMap.create();
+    private final Map<DexField, FieldNode> fieldNodes = new IdentityHashMap<>();
     private final Map<DexMethod, Int2ReferenceMap<ParameterNode>> parameterNodes =
         new IdentityHashMap<>();
 
@@ -236,7 +245,7 @@
       for (Node node : nodes) {
         if (node.isFieldNode()) {
           FieldNode fieldNode = node.asFieldNode();
-          fieldNodes.put(fieldNode.getField(), fieldNode);
+          fieldNodes.put(fieldNode.getField().getReference(), fieldNode);
         } else {
           ParameterNode parameterNode = node.asParameterNode();
           parameterNodes
@@ -365,13 +374,19 @@
     }
 
     private TraversalContinuation<?, ?> addInFlow(AbstractFunction inFlow, Node node) {
-      InFlow baseInFlow = inFlow.getBaseInFlow();
-      if (baseInFlow.isFieldValue()) {
-        return addInFlow(baseInFlow.asFieldValue(), node, inFlow);
-      } else {
-        assert baseInFlow.isMethodParameter();
-        return addInFlow(baseInFlow.asMethodParameter(), node, inFlow);
+      for (BaseInFlow baseInFlow : inFlow.getBaseInFlow()) {
+        TraversalContinuation<?, ?> traversalContinuation;
+        if (baseInFlow.isFieldValue()) {
+          traversalContinuation = addInFlow(baseInFlow.asFieldValue(), node, inFlow);
+        } else {
+          assert baseInFlow.isMethodParameter();
+          traversalContinuation = addInFlow(baseInFlow.asMethodParameter(), node, inFlow);
+        }
+        if (traversalContinuation.shouldBreak()) {
+          return traversalContinuation;
+        }
       }
+      return TraversalContinuation.doContinue();
     }
 
     private TraversalContinuation<?, ?> addInFlow(FieldValue inFlow, Node node) {
@@ -380,6 +395,8 @@
 
     private TraversalContinuation<?, ?> addInFlow(
         FieldValue inFlow, Node node, AbstractFunction transferFunction) {
+      assert !node.isUnknown();
+
       ProgramField field = asProgramFieldOrNull(appView.definitionFor(inFlow.getField()));
       if (field == null) {
         assert false;
@@ -441,7 +458,8 @@
     }
 
     private FieldNode getOrCreateFieldNode(ProgramField field, ValueState fieldState) {
-      return fieldNodes.computeIfAbsent(field, f -> new FieldNode(f, fieldState));
+      return fieldNodes.computeIfAbsent(
+          field.getReference(), ignoreKey(() -> new FieldNode(field, fieldState)));
     }
 
     private ParameterNode getOrCreateParameterNode(
@@ -481,6 +499,43 @@
       }
       return methodStates.get(method);
     }
+
+    @Override
+    public ValueState getState(DexField field) {
+      return fieldNodes.get(field).getState();
+    }
+
+    @Override
+    public ValueState getState(BaseInFlow inFlow, Supplier<ValueState> defaultStateProvider) {
+      if (inFlow.isFieldValue()) {
+        FieldValue fieldValue = inFlow.asFieldValue();
+        return getState(fieldValue.getField());
+      } else {
+        assert inFlow.isMethodParameter();
+        MethodParameter methodParameter = inFlow.asMethodParameter();
+        ParameterNode parameterNode =
+            parameterNodes
+                .getOrDefault(methodParameter.getMethod(), Int2ReferenceMaps.emptyMap())
+                .get(methodParameter.getIndex());
+        if (parameterNode != null) {
+          return parameterNode.getState();
+        }
+        assert verifyMissingParameterStateIsBottom(methodParameter);
+        return defaultStateProvider.get();
+      }
+    }
+
+    private boolean verifyMissingParameterStateIsBottom(MethodParameter methodParameter) {
+      ProgramMethod enclosingMethod = getEnclosingMethod(methodParameter);
+      if (enclosingMethod == null) {
+        assert converter
+            .getInliner()
+            .verifyIsPrunedDueToSingleCallerInlining(methodParameter.getMethod());
+        return true;
+      }
+      MethodState enclosingMethodState = getMethodState(enclosingMethod);
+      return enclosingMethodState.isBottom();
+    }
   }
 
   public abstract static class Node {
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/ComposeMethodProcessor.java b/src/main/java/com/android/tools/r8/optimize/compose/ComposeMethodProcessor.java
index 1dc5061..6b63d9c 100644
--- a/src/main/java/com/android/tools/r8/optimize/compose/ComposeMethodProcessor.java
+++ b/src/main/java/com/android/tools/r8/optimize/compose/ComposeMethodProcessor.java
@@ -24,17 +24,18 @@
 import com.android.tools.r8.ir.optimize.info.OptimizationFeedback;
 import com.android.tools.r8.optimize.argumentpropagation.ArgumentPropagatorCodeScanner;
 import com.android.tools.r8.optimize.argumentpropagation.ArgumentPropagatorOptimizationInfoPopulator;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.BaseInFlow;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteMonomorphicMethodState;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcretePrimitiveTypeValueState;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteValueState;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.FieldStateCollection;
-import com.android.tools.r8.optimize.argumentpropagation.codescanner.InFlow;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodParameter;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodState;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodStateCollectionByReference;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.ValueState;
 import com.android.tools.r8.optimize.argumentpropagation.propagation.InFlowPropagator;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.IterableUtils;
 import com.android.tools.r8.utils.LazyBox;
 import com.android.tools.r8.utils.Timing;
 import com.android.tools.r8.utils.collections.ProgramMethodSet;
@@ -162,11 +163,13 @@
     }
 
     // This is a call to a composable function from a restart function.
-    InFlow baseInFlow = transferFunction.getBaseInFlow();
-    assert baseInFlow.isFieldValue();
+    Iterable<BaseInFlow> baseInFlow = transferFunction.getBaseInFlow();
+    assert Iterables.size(baseInFlow) == 1;
+    BaseInFlow singleBaseInFlow = IterableUtils.first(baseInFlow);
+    assert singleBaseInFlow.isFieldValue();
 
     ProgramField field =
-        asProgramFieldOrNull(appView.definitionFor(baseInFlow.asFieldValue().getField()));
+        asProgramFieldOrNull(appView.definitionFor(singleBaseInFlow.asFieldValue().getField()));
     assert field != null;
 
     codeScanner
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/UpdateChangedFlagsAbstractFunction.java b/src/main/java/com/android/tools/r8/optimize/compose/UpdateChangedFlagsAbstractFunction.java
index c3dbf97..8bf3c0d 100644
--- a/src/main/java/com/android/tools/r8/optimize/compose/UpdateChangedFlagsAbstractFunction.java
+++ b/src/main/java/com/android/tools/r8/optimize/compose/UpdateChangedFlagsAbstractFunction.java
@@ -4,9 +4,12 @@
 package com.android.tools.r8.optimize.compose;
 
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.AbstractFunction;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.BaseInFlow;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteValueState;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.FlowGraphStateProvider;
 import com.android.tools.r8.optimize.argumentpropagation.codescanner.InFlow;
-import com.android.tools.r8.optimize.argumentpropagation.codescanner.NonEmptyValueState;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.ValueState;
+import com.android.tools.r8.utils.IterableUtils;
 import java.util.Objects;
 
 public class UpdateChangedFlagsAbstractFunction implements AbstractFunction {
@@ -18,19 +21,28 @@
   }
 
   @Override
-  public NonEmptyValueState apply(ConcreteValueState state) {
+  public ValueState apply(FlowGraphStateProvider flowGraphStateProvider, ConcreteValueState state) {
     // TODO(b/302483644): Implement this abstract function to allow correct value propagation of
     //  updateChangedFlags(x | 1).
     return state;
   }
 
   @Override
-  public InFlow getBaseInFlow() {
+  public boolean containsBaseInFlow(BaseInFlow otherInFlow) {
+    if (inFlow.isAbstractFunction()) {
+      return inFlow.asAbstractFunction().containsBaseInFlow(otherInFlow);
+    }
+    assert inFlow.isBaseInFlow();
+    return inFlow.equals(otherInFlow);
+  }
+
+  @Override
+  public Iterable<BaseInFlow> getBaseInFlow() {
     if (inFlow.isAbstractFunction()) {
       return inFlow.asAbstractFunction().getBaseInFlow();
     }
-    assert inFlow.isFieldValue() || inFlow.isMethodParameter();
-    return inFlow;
+    assert inFlow.isBaseInFlow();
+    return IterableUtils.singleton(inFlow.asBaseInFlow());
   }
 
   @Override