diff --git a/src/main/java/com/android/tools/r8/graph/ProgramField.java b/src/main/java/com/android/tools/r8/graph/ProgramField.java
index adea688..60d3b0e 100644
--- a/src/main/java/com/android/tools/r8/graph/ProgramField.java
+++ b/src/main/java/com/android/tools/r8/graph/ProgramField.java
@@ -17,6 +17,10 @@
     super(holder, field);
   }
 
+  public static ProgramField asProgramFieldOrNull(DexClassAndField field) {
+    return field != null ? field.asProgramField() : null;
+  }
+
   public void collectIndexedItems(AppView<?> appView, IndexedItemCollection indexedItems) {
     getReference().collectIndexedItems(appView, indexedItems);
     DexEncodedField definition = getDefinition();
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/fieldvalueanalysis/ConcreteMutableFieldSet.java b/src/main/java/com/android/tools/r8/ir/analysis/fieldvalueanalysis/ConcreteMutableFieldSet.java
index 901c082..5c44ef7 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/fieldvalueanalysis/ConcreteMutableFieldSet.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/fieldvalueanalysis/ConcreteMutableFieldSet.java
@@ -40,7 +40,7 @@
     return this;
   }
 
-  Set<DexEncodedField> getFields() {
+  public Set<DexEncodedField> getFields() {
     if (InternalOptions.assertionsEnabled()) {
       return Collections.unmodifiableSet(fields);
     }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java
index 71f3df9..7de5536 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java
@@ -224,7 +224,7 @@
       identifierNameStringMarker.decoupleIdentifierNameStringsInFields(executorService);
     }
 
-    ComposableOptimizationPass.run(appView, this, timing);
+    ComposableOptimizationPass.run(appView, this, executorService, timing);
 
     // Assure that no more optimization feedback left after post processing.
     assert feedback.noUpdatesLeft();
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/ComposableOptimizationPass.java b/src/main/java/com/android/tools/r8/optimize/compose/ComposableOptimizationPass.java
index 244dbb4..1c23e06 100644
--- a/src/main/java/com/android/tools/r8/optimize/compose/ComposableOptimizationPass.java
+++ b/src/main/java/com/android/tools/r8/optimize/compose/ComposableOptimizationPass.java
@@ -15,6 +15,8 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
 
 public class ComposableOptimizationPass {
 
@@ -28,7 +30,11 @@
   }
 
   public static void run(
-      AppView<AppInfoWithLiveness> appView, PrimaryR8IRConverter converter, Timing timing) {
+      AppView<AppInfoWithLiveness> appView,
+      PrimaryR8IRConverter converter,
+      ExecutorService executorService,
+      Timing timing)
+      throws ExecutionException {
     InternalOptions options = appView.options();
     if (!options.isOptimizing() || !options.isShrinking()) {
       return;
@@ -40,16 +46,17 @@
     }
     timing.time(
         "ComposableOptimizationPass",
-        () -> new ComposableOptimizationPass(appView, converter).processWaves());
+        () -> new ComposableOptimizationPass(appView, converter).processWaves(executorService));
   }
 
-  void processWaves() {
+  void processWaves(ExecutorService executorService) throws ExecutionException {
     ComposableCallGraph callGraph = ComposableCallGraph.builder(appView).build();
     ComposeMethodProcessor methodProcessor =
         new ComposeMethodProcessor(appView, callGraph, converter);
     Set<ComposableCallGraphNode> wave = createInitialWave(callGraph);
     while (!wave.isEmpty()) {
-      Set<ComposableCallGraphNode> optimizedComposableFunctions = methodProcessor.processWave(wave);
+      Set<ComposableCallGraphNode> optimizedComposableFunctions =
+          methodProcessor.processWave(wave, executorService);
       wave = createNextWave(methodProcessor, optimizedComposableFunctions);
     }
   }
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/ComposeMethodProcessor.java b/src/main/java/com/android/tools/r8/optimize/compose/ComposeMethodProcessor.java
index 9a57173..258b69b 100644
--- a/src/main/java/com/android/tools/r8/optimize/compose/ComposeMethodProcessor.java
+++ b/src/main/java/com/android/tools/r8/optimize/compose/ComposeMethodProcessor.java
@@ -32,6 +32,8 @@
 import com.google.common.collect.Sets;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
 
 public class ComposeMethodProcessor extends MethodProcessor {
 
@@ -51,7 +53,9 @@
   }
 
   // TODO(b/302483644): Process wave concurrently.
-  public Set<ComposableCallGraphNode> processWave(Set<ComposableCallGraphNode> wave) {
+  public Set<ComposableCallGraphNode> processWave(
+      Set<ComposableCallGraphNode> wave, ExecutorService executorService)
+      throws ExecutionException {
     ProcessorContext processorContext = appView.createProcessorContext();
     wave.forEach(
         node -> {
@@ -64,11 +68,13 @@
               MethodConversionOptions.forLirPhase(appView));
         });
     processed.addAll(wave);
-    return optimizeComposableFunctionsCalledFromWave(wave);
+    return optimizeComposableFunctionsCalledFromWave(wave, executorService);
   }
 
+  @SuppressWarnings("UnusedVariable")
   private Set<ComposableCallGraphNode> optimizeComposableFunctionsCalledFromWave(
-      Set<ComposableCallGraphNode> wave) {
+      Set<ComposableCallGraphNode> wave, ExecutorService executorService)
+      throws ExecutionException {
     ArgumentPropagatorOptimizationInfoPopulator optimizationInfoPopulator =
         new ArgumentPropagatorOptimizationInfoPopulator(appView, null, null, null);
     Set<ComposableCallGraphNode> optimizedComposableFunctions = Sets.newIdentityHashSet();
diff --git a/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMemberMap.java b/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMemberMap.java
index 840ecca..35023fc 100644
--- a/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMemberMap.java
+++ b/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMemberMap.java
@@ -196,7 +196,7 @@
 
   @Override
   public Collection<V> values() {
-    throw new Unimplemented();
+    return backing.values();
   }
 
   protected abstract Wrapper<K> wrap(K member);
diff --git a/src/main/java/com/android/tools/r8/utils/collections/ProgramFieldMap.java b/src/main/java/com/android/tools/r8/utils/collections/ProgramFieldMap.java
index a8f41e2..3b2dc23 100644
--- a/src/main/java/com/android/tools/r8/utils/collections/ProgramFieldMap.java
+++ b/src/main/java/com/android/tools/r8/utils/collections/ProgramFieldMap.java
@@ -9,6 +9,7 @@
 import com.google.common.collect.ImmutableMap;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Supplier;
 
 public class ProgramFieldMap<V> extends DexClassAndFieldMapBase<ProgramField, V> {
@@ -23,6 +24,10 @@
     return new ProgramFieldMap<>(HashMap::new);
   }
 
+  public static <V> ProgramFieldMap<V> createConcurrent() {
+    return new ProgramFieldMap<>(ConcurrentHashMap::new);
+  }
+
   @SuppressWarnings("unchecked")
   public static <V> ProgramFieldMap<V> empty() {
     return (ProgramFieldMap<V>) EMPTY;
diff --git a/src/main/java/com/android/tools/r8/utils/collections/ProgramFieldSet.java b/src/main/java/com/android/tools/r8/utils/collections/ProgramFieldSet.java
index 93d53f9..2952b30 100644
--- a/src/main/java/com/android/tools/r8/utils/collections/ProgramFieldSet.java
+++ b/src/main/java/com/android/tools/r8/utils/collections/ProgramFieldSet.java
@@ -14,6 +14,8 @@
 import java.util.Iterator;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Predicate;
 import java.util.stream.Stream;
 
 public class ProgramFieldSet implements Iterable<ProgramField> {
@@ -30,6 +32,10 @@
     return new ProgramFieldSet(new IdentityHashMap<>());
   }
 
+  public static ProgramFieldSet createConcurrent() {
+    return new ProgramFieldSet(new ConcurrentHashMap<>());
+  }
+
   public static ProgramFieldSet empty() {
     return EMPTY;
   }
@@ -82,6 +88,10 @@
     return remove(field.getReference());
   }
 
+  public boolean removeIf(Predicate<? super ProgramField> predicate) {
+    return backing.values().removeIf(predicate);
+  }
+
   public int size() {
     return backing.size();
   }
