[Retrace] Add finished callbacks to mapping supplier

Bug: b/274735214
Change-Id: Id528e2552c7f95837645704986a4eb4584514c8d
diff --git a/src/main/java/com/android/tools/r8/retrace/FinishedPartitionMappingCallback.java b/src/main/java/com/android/tools/r8/retrace/FinishedPartitionMappingCallback.java
new file mode 100644
index 0000000..6c4c906
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/retrace/FinishedPartitionMappingCallback.java
@@ -0,0 +1,24 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.retrace;
+
+import com.android.tools.r8.DiagnosticsHandler;
+import com.android.tools.r8.Keep;
+
+/***
+ * Interface for registering a callback when a retracing operation is finished.
+ */
+@FunctionalInterface
+@Keep
+public interface FinishedPartitionMappingCallback {
+
+  FinishedPartitionMappingCallback EMPTY_INSTANCE = diagnosticsHandler -> {};
+
+  static FinishedPartitionMappingCallback empty() {
+    return EMPTY_INSTANCE;
+  }
+
+  void finished(DiagnosticsHandler handler);
+}
diff --git a/src/main/java/com/android/tools/r8/retrace/MappingSupplierBase.java b/src/main/java/com/android/tools/r8/retrace/MappingSupplierBase.java
index accecec..53ed349 100644
--- a/src/main/java/com/android/tools/r8/retrace/MappingSupplierBase.java
+++ b/src/main/java/com/android/tools/r8/retrace/MappingSupplierBase.java
@@ -5,13 +5,14 @@
 package com.android.tools.r8.retrace;
 
 import com.android.tools.r8.DiagnosticsHandler;
+import com.android.tools.r8.Finishable;
 import com.android.tools.r8.Keep;
 import com.android.tools.r8.references.ClassReference;
 import com.android.tools.r8.references.FieldReference;
 import com.android.tools.r8.references.MethodReference;
 
 @Keep
-public interface MappingSupplierBase<T extends MappingSupplierBase<T>> {
+public interface MappingSupplierBase<T extends MappingSupplierBase<T>> extends Finishable {
 
   /***
    * Register an allowed mapping lookup to allow for prefetching of resources.
diff --git a/src/main/java/com/android/tools/r8/retrace/PartitionMappingSupplier.java b/src/main/java/com/android/tools/r8/retrace/PartitionMappingSupplier.java
index ebcfaf9..d89ce17 100644
--- a/src/main/java/com/android/tools/r8/retrace/PartitionMappingSupplier.java
+++ b/src/main/java/com/android/tools/r8/retrace/PartitionMappingSupplier.java
@@ -22,10 +22,17 @@
       RegisterMappingPartitionCallback registerCallback,
       PrepareMappingPartitionsCallback prepareCallback,
       MappingPartitionFromKeySupplier partitionSupplier,
+      FinishedPartitionMappingCallback finishedCallback,
       boolean allowExperimental,
       byte[] metadata,
       MapVersion fallbackMapVersion) {
-    super(registerCallback, prepareCallback, allowExperimental, metadata, fallbackMapVersion);
+    super(
+        registerCallback,
+        prepareCallback,
+        finishedCallback,
+        allowExperimental,
+        metadata,
+        fallbackMapVersion);
     this.partitionSupplier = partitionSupplier;
   }
 
@@ -70,6 +77,15 @@
     return createRetracerFromPartitionSupplier(diagnosticsHandler, partitionSupplier);
   }
 
+  public MappingPartitionFromKeySupplier getMappingPartitionFromKeySupplier() {
+    return partitionSupplier;
+  }
+
+  @Override
+  public PartitionMappingSupplier getPartitionMappingSupplier() {
+    return this;
+  }
+
   @Override
   public PartitionMappingSupplier self() {
     return this;
@@ -124,6 +140,7 @@
           registerCallback,
           prepareCallback,
           partitionSupplier,
+          finishedCallback,
           allowExperimental,
           null,
           fallbackMapVersion);
@@ -165,6 +182,7 @@
           registerCallback,
           prepareCallback,
           partitionSupplier,
+          finishedCallback,
           allowExperimental,
           metadata,
           fallbackMapVersion);
diff --git a/src/main/java/com/android/tools/r8/retrace/PartitionMappingSupplierAsync.java b/src/main/java/com/android/tools/r8/retrace/PartitionMappingSupplierAsync.java
index 41acee5..9e0f71b 100644
--- a/src/main/java/com/android/tools/r8/retrace/PartitionMappingSupplierAsync.java
+++ b/src/main/java/com/android/tools/r8/retrace/PartitionMappingSupplierAsync.java
@@ -21,10 +21,17 @@
   private PartitionMappingSupplierAsync(
       RegisterMappingPartitionCallback registerCallback,
       PrepareMappingPartitionsCallback prepareCallback,
+      FinishedPartitionMappingCallback finishedCallback,
       boolean allowExperimental,
       byte[] metadata,
       MapVersion fallbackMapVersion) {
-    super(registerCallback, prepareCallback, allowExperimental, metadata, fallbackMapVersion);
+    super(
+        registerCallback,
+        prepareCallback,
+        finishedCallback,
+        allowExperimental,
+        metadata,
+        fallbackMapVersion);
   }
 
   /***
@@ -107,7 +114,12 @@
         throw new RuntimeException("Cannot build without providing metadata.");
       }
       return new PartitionMappingSupplierAsync(
-          registerCallback, prepareCallback, allowExperimental, metadata, fallbackMapVersion);
+          registerCallback,
+          prepareCallback,
+          finishedCallback,
+          allowExperimental,
+          metadata,
+          fallbackMapVersion);
     }
   }
 }
diff --git a/src/main/java/com/android/tools/r8/retrace/PartitionMappingSupplierBuilderBase.java b/src/main/java/com/android/tools/r8/retrace/PartitionMappingSupplierBuilderBase.java
index f5b020b..1dc2e14 100644
--- a/src/main/java/com/android/tools/r8/retrace/PartitionMappingSupplierBuilderBase.java
+++ b/src/main/java/com/android/tools/r8/retrace/PartitionMappingSupplierBuilderBase.java
@@ -15,6 +15,8 @@
       RegisterMappingPartitionCallback.empty();
   protected PrepareMappingPartitionsCallback prepareCallback =
       PrepareMappingPartitionsCallback.empty();
+  protected FinishedPartitionMappingCallback finishedCallback =
+      FinishedPartitionMappingCallback.empty();
   protected final MapVersion fallbackMapVersion;
   protected boolean allowExperimental = false;
 
@@ -27,6 +29,11 @@
     return self();
   }
 
+  public T setFinishedPartitionMappingCallback(FinishedPartitionMappingCallback finishedCallback) {
+    this.finishedCallback = finishedCallback;
+    return self();
+  }
+
   public T setPrepareMappingPartitionsCallback(PrepareMappingPartitionsCallback prepareCallback) {
     this.prepareCallback = prepareCallback;
     return self();
diff --git a/src/main/java/com/android/tools/r8/retrace/PartitionedToProguardMappingConverter.java b/src/main/java/com/android/tools/r8/retrace/PartitionedToProguardMappingConverter.java
index 4a4d593..d55881cf 100644
--- a/src/main/java/com/android/tools/r8/retrace/PartitionedToProguardMappingConverter.java
+++ b/src/main/java/com/android/tools/r8/retrace/PartitionedToProguardMappingConverter.java
@@ -9,12 +9,11 @@
 import com.android.tools.r8.DiagnosticsHandler;
 import com.android.tools.r8.Finishable;
 import com.android.tools.r8.StringConsumer;
-import com.android.tools.r8.dex.CompatByteBuffer;
 import com.android.tools.r8.naming.ClassNameMapper;
 import com.android.tools.r8.naming.LineReader;
-import com.android.tools.r8.naming.MapVersion;
 import com.android.tools.r8.retrace.internal.MappingPartitionMetadataInternal;
 import com.android.tools.r8.retrace.internal.MetadataAdditionalInfo;
+import com.android.tools.r8.retrace.internal.PartitionMappingSupplierBase;
 import com.android.tools.r8.retrace.internal.ProguardMapReaderWithFiltering.ProguardMapReaderWithFilteringInputBuffer;
 import com.android.tools.r8.utils.ChainableStringConsumer;
 import java.io.ByteArrayInputStream;
@@ -23,30 +22,37 @@
 public class PartitionedToProguardMappingConverter {
 
   private final StringConsumer consumer;
-  private final MappingPartitionFromKeySupplier partitionSupplier;
-  private final byte[] metadata;
+  private final PartitionMappingSupplierBase<?> partitionMappingSupplier;
   private final DiagnosticsHandler diagnosticsHandler;
 
   private PartitionedToProguardMappingConverter(
       StringConsumer consumer,
-      MappingPartitionFromKeySupplier partitionSupplier,
-      byte[] metadata,
+      PartitionMappingSupplierBase<?> partitionMappingSupplier,
       DiagnosticsHandler diagnosticsHandler) {
     this.consumer = consumer;
-    this.partitionSupplier = partitionSupplier;
-    this.metadata = metadata;
+    this.partitionMappingSupplier = partitionMappingSupplier;
     this.diagnosticsHandler = diagnosticsHandler;
   }
 
-  public void run() throws RetracePartitionException {
+  private MappingPartitionMetadataInternal getMetadata() {
     MappingPartitionMetadataInternal metadataInternal =
-        MappingPartitionMetadataInternal.deserialize(
-            CompatByteBuffer.wrapOrNull(metadata),
-            MapVersion.MAP_VERSION_UNKNOWN,
-            diagnosticsHandler);
-    if (!metadataInternal.canGetPartitionKeys()) {
+        partitionMappingSupplier.getMetadata(diagnosticsHandler);
+    if (metadataInternal == null || !metadataInternal.canGetPartitionKeys()) {
       throw new RetracePartitionException("Cannot obtain all partition keys from metadata");
     }
+    return metadataInternal;
+  }
+
+  private void requestKeys(MappingPartitionMetadataInternal metadataInternal) {
+    for (String partitionKey : metadataInternal.getPartitionKeys()) {
+      partitionMappingSupplier.registerKeyUse(partitionKey);
+    }
+  }
+
+  private void run(
+      MappingPartitionMetadataInternal metadataInternal,
+      MappingPartitionFromKeySupplier partitionSupplier)
+      throws RetracePartitionException {
     ProguardMapWriter consumer = new ProguardMapWriter(this.consumer, diagnosticsHandler);
     if (metadataInternal.canGetAdditionalInfo()) {
       MetadataAdditionalInfo additionalInfo = metadataInternal.getAdditionalInfo();
@@ -72,6 +78,25 @@
       }
     }
     consumer.finished(diagnosticsHandler);
+    partitionMappingSupplier.finished(diagnosticsHandler);
+  }
+
+  public void run() throws RetracePartitionException {
+    MappingPartitionMetadataInternal metadata = getMetadata();
+    PartitionMappingSupplier syncSupplier = partitionMappingSupplier.getPartitionMappingSupplier();
+    if (syncSupplier == null) {
+      throw new RetracePartitionException(
+          "Running synchronously requires a synchronous partition mapping provider. Use runAsync()"
+              + " if you have an asynchronous provider.");
+    }
+    requestKeys(metadata);
+    run(metadata, syncSupplier.getMappingPartitionFromKeySupplier());
+  }
+
+  public RetraceAsyncAction runAsync() throws RetracePartitionException {
+    MappingPartitionMetadataInternal metadata = getMetadata();
+    requestKeys(metadata);
+    return supplier -> run(metadata, supplier);
   }
 
   private static class ProguardMapWriter implements ChainableStringConsumer, Finishable {
@@ -103,8 +128,7 @@
   public static class Builder {
 
     private StringConsumer consumer;
-    private MappingPartitionFromKeySupplier partitionSupplier;
-    private byte[] metadata;
+    private PartitionMappingSupplierBase<?> partitionSupplier;
     private DiagnosticsHandler diagnosticsHandler;
 
     public Builder setConsumer(StringConsumer consumer) {
@@ -112,16 +136,11 @@
       return this;
     }
 
-    public Builder setPartitionSupplier(MappingPartitionFromKeySupplier partitionSupplier) {
+    public Builder setPartitionMappingSupplier(PartitionMappingSupplierBase<?> partitionSupplier) {
       this.partitionSupplier = partitionSupplier;
       return this;
     }
 
-    public Builder setMetadata(byte[] metadata) {
-      this.metadata = metadata;
-      return this;
-    }
-
     public Builder setDiagnosticsHandler(DiagnosticsHandler diagnosticsHandler) {
       this.diagnosticsHandler = diagnosticsHandler;
       return this;
@@ -129,7 +148,7 @@
 
     public PartitionedToProguardMappingConverter build() {
       return new PartitionedToProguardMappingConverter(
-          consumer, partitionSupplier, metadata, diagnosticsHandler);
+          consumer, partitionSupplier, diagnosticsHandler);
     }
   }
 }
diff --git a/src/main/java/com/android/tools/r8/retrace/Retrace.java b/src/main/java/com/android/tools/r8/retrace/Retrace.java
index 0150f83..312cb2b 100644
--- a/src/main/java/com/android/tools/r8/retrace/Retrace.java
+++ b/src/main/java/com/android/tools/r8/retrace/Retrace.java
@@ -297,6 +297,7 @@
                       RetraceUnknownMapVersionDiagnostic.create(mapVersionInfo.getValue()));
                 }
               });
+      mappingSupplier.finished(diagnosticsHandler);
     } catch (InvalidMappingFileException e) {
       command.getOptions().getDiagnosticsHandler().error(new ExceptionDiagnostic(e));
       throw e;
diff --git a/src/main/java/com/android/tools/r8/retrace/RetraceAsyncAction.java b/src/main/java/com/android/tools/r8/retrace/RetraceAsyncAction.java
new file mode 100644
index 0000000..a6088e8
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/retrace/RetraceAsyncAction.java
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.retrace;
+
+import com.android.tools.r8.Keep;
+
+@Keep
+public interface RetraceAsyncAction {
+
+  void execute(MappingPartitionFromKeySupplier supplier);
+}
diff --git a/src/main/java/com/android/tools/r8/retrace/StringRetrace.java b/src/main/java/com/android/tools/r8/retrace/StringRetrace.java
index 2c16ea1..df888dc 100644
--- a/src/main/java/com/android/tools/r8/retrace/StringRetrace.java
+++ b/src/main/java/com/android/tools/r8/retrace/StringRetrace.java
@@ -4,14 +4,13 @@
 
 package com.android.tools.r8.retrace;
 
-import static com.android.tools.r8.retrace.internal.RetraceUtils.firstNonWhiteSpaceCharacterFromIndex;
-
 import com.android.tools.r8.DiagnosticsHandler;
 import com.android.tools.r8.Keep;
 import com.android.tools.r8.retrace.internal.RetraceStackFrameResultWithContextImpl;
 import com.android.tools.r8.retrace.internal.StackTraceElementStringProxy;
 import com.android.tools.r8.utils.StringUtils;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
@@ -78,24 +77,9 @@
    */
   public RetraceStackFrameResultWithContext<String> retrace(
       List<String> stackTrace, RetraceStackTraceContext context) {
-    RetraceStackTraceResult<String> listRetraceStackTraceResult =
-        retraceStackTrace(stackTrace, context);
-    List<String> retracedStrings = new ArrayList<>();
-    listRetraceStackTraceResult.forEach(
-        newLines ->
-            newLines.forEachWithIndex(
-                (inlineFrames, ambiguousIndex) -> {
-                  for (int i = 0; i < inlineFrames.size(); i++) {
-                    String stackTraceLine = inlineFrames.get(i);
-                    if (i == 0 && ambiguousIndex > 0) {
-                      insertOrIntoStackTraceLine(stackTraceLine, retracedStrings);
-                    } else {
-                      retracedStrings.add(stackTraceLine);
-                    }
-                  }
-                }));
+    RetraceStackTraceResult<String> result = retraceStackTrace(stackTrace, context);
     return RetraceStackFrameResultWithContextImpl.create(
-        retracedStrings, listRetraceStackTraceResult.getContext());
+        joinAmbiguousLines(result.getResult()), result.getContext());
   }
 
   /**
@@ -108,38 +92,9 @@
    */
   public RetraceStackFrameResultWithContext<String> retraceParsed(
       List<StackTraceElementStringProxy> stackTrace, RetraceStackTraceContext context) {
-    RetraceStackTraceResult<String> listRetraceStackTraceResult =
-        retraceStackTraceParsed(stackTrace, context);
-    List<String> retracedStrings = new ArrayList<>();
-    listRetraceStackTraceResult.forEach(
-        newLines ->
-            newLines.forEachWithIndex(
-                (inlineFrames, ambiguousIndex) -> {
-                  for (int i = 0; i < inlineFrames.size(); i++) {
-                    String stackTraceLine = inlineFrames.get(i);
-                    if (i == 0 && ambiguousIndex > 0) {
-                      insertOrIntoStackTraceLine(stackTraceLine, retracedStrings);
-                    } else {
-                      retracedStrings.add(stackTraceLine);
-                    }
-                  }
-                }));
+    RetraceStackTraceResult<String> result = retraceStackTraceParsed(stackTrace, context);
     return RetraceStackFrameResultWithContextImpl.create(
-        retracedStrings, listRetraceStackTraceResult.getContext());
-  }
-
-  private void insertOrIntoStackTraceLine(String stackTraceLine, List<String> retracedStrings) {
-    // We are reporting an ambiguous frame. To support retracing tools that
-    // retrace line by line we have to emit <OR> at the point of the first 'at '
-    // if we can find it.
-    int indexToInsertOr = stackTraceLine.indexOf("at ");
-    if (indexToInsertOr < 0) {
-      indexToInsertOr = Math.max(StringUtils.firstNonWhitespaceCharacter(stackTraceLine), 0);
-    }
-    retracedStrings.add(
-        stackTraceLine.substring(0, indexToInsertOr)
-            + "<OR> "
-            + stackTraceLine.substring(indexToInsertOr));
+        joinAmbiguousLines(result.getResult()), result.getContext());
   }
 
   /**
@@ -153,10 +108,9 @@
       String stackTraceLine, RetraceStackTraceContext context) {
     RetraceStackFrameAmbiguousResultWithContext<String> listRetraceStackTraceResult =
         retraceFrame(stackTraceLine, context);
-    List<String> result = new ArrayList<>();
-    joinAmbiguousLines(listRetraceStackTraceResult.getAmbiguousResult(), result::add);
     return RetraceStackFrameResultWithContextImpl.create(
-        result, listRetraceStackTraceResult.getContext());
+        joinAmbiguousLines(Collections.singletonList(listRetraceStackTraceResult)),
+        listRetraceStackTraceResult.getContext());
   }
 
   /**
@@ -176,32 +130,42 @@
     }
   }
 
-  private void joinAmbiguousLines(
-      List<RetraceStackFrameResult<String>> retracedResult, Consumer<String> joinedConsumer) {
-    if (retracedResult.isEmpty()) {
-      // The result is empty, likely it maps to compiler synthesized items.
-      return;
-    }
-    Set<String> reportedFrames = new HashSet<>();
+  private List<String> joinAmbiguousLines(
+      List<RetraceStackFrameAmbiguousResult<String>> retracedResult) {
+    List<String> result = new ArrayList<>();
     retracedResult.forEach(
         potentialResults -> {
-          assert !potentialResults.isEmpty();
-          // Check if we already reported position.
-          if (reportedFrames.add(potentialResults.get(0))) {
-            boolean isAmbiguous = potentialResults != retracedResult.get(0);
-            potentialResults.forEach(
-                retracedString -> {
-                  if (isAmbiguous) {
-                    int firstCharIndex = firstNonWhiteSpaceCharacterFromIndex(retracedString, 0);
-                    joinedConsumer.accept(
-                        retracedString.substring(0, firstCharIndex)
-                            + "<OR> "
-                            + retracedString.substring(firstCharIndex));
-                  } else {
-                    joinedConsumer.accept(retracedString);
-                  }
-                });
-          }
+          Set<String> reportedFrames = new HashSet<>();
+          potentialResults.forEachWithIndex(
+              (inlineFrames, index) -> {
+                // Check if we already reported position.
+                String topFrame = inlineFrames.get(0);
+                if (reportedFrames.add(topFrame)) {
+                  inlineFrames.forEach(
+                      inlineFrame -> {
+                        boolean isAmbiguous = index > 0 && topFrame.equals(inlineFrame);
+                        if (isAmbiguous) {
+                          result.add(insertOrIntoStackTraceLine(inlineFrame));
+                        } else {
+                          result.add(inlineFrame);
+                        }
+                      });
+                }
+              });
         });
+    return result;
+  }
+
+  private String insertOrIntoStackTraceLine(String stackTraceLine) {
+    // We are reporting an ambiguous frame. To support retracing tools that
+    // retrace line by line we have to emit <OR> at the point of the first 'at '
+    // if we can find it.
+    int indexToInsertOr = stackTraceLine.indexOf("at ");
+    if (indexToInsertOr < 0) {
+      indexToInsertOr = Math.max(StringUtils.firstNonWhitespaceCharacter(stackTraceLine), 0);
+    }
+    return stackTraceLine.substring(0, indexToInsertOr)
+        + "<OR> "
+        + stackTraceLine.substring(indexToInsertOr);
   }
 }
diff --git a/src/main/java/com/android/tools/r8/retrace/internal/PartitionMappingSupplierBase.java b/src/main/java/com/android/tools/r8/retrace/internal/PartitionMappingSupplierBase.java
index 5a9fc41..012fb94 100644
--- a/src/main/java/com/android/tools/r8/retrace/internal/PartitionMappingSupplierBase.java
+++ b/src/main/java/com/android/tools/r8/retrace/internal/PartitionMappingSupplierBase.java
@@ -7,6 +7,7 @@
 import static com.google.common.base.Predicates.alwaysTrue;
 
 import com.android.tools.r8.DiagnosticsHandler;
+import com.android.tools.r8.Finishable;
 import com.android.tools.r8.dex.CompatByteBuffer;
 import com.android.tools.r8.naming.ClassNameMapper;
 import com.android.tools.r8.naming.LineReader;
@@ -15,8 +16,10 @@
 import com.android.tools.r8.references.ClassReference;
 import com.android.tools.r8.references.FieldReference;
 import com.android.tools.r8.references.MethodReference;
+import com.android.tools.r8.retrace.FinishedPartitionMappingCallback;
 import com.android.tools.r8.retrace.InvalidMappingFileException;
 import com.android.tools.r8.retrace.MappingPartitionFromKeySupplier;
+import com.android.tools.r8.retrace.PartitionMappingSupplier;
 import com.android.tools.r8.retrace.PrepareMappingPartitionsCallback;
 import com.android.tools.r8.retrace.RegisterMappingPartitionCallback;
 import com.android.tools.r8.retrace.internal.ProguardMapReaderWithFiltering.ProguardMapReaderWithFilteringInputBuffer;
@@ -28,10 +31,12 @@
 import java.util.LinkedHashSet;
 import java.util.Set;
 
-public abstract class PartitionMappingSupplierBase<T extends PartitionMappingSupplierBase<T>> {
+public abstract class PartitionMappingSupplierBase<T extends PartitionMappingSupplierBase<T>>
+    implements Finishable {
 
   private final RegisterMappingPartitionCallback registerCallback;
   private final PrepareMappingPartitionsCallback prepareCallback;
+  private final FinishedPartitionMappingCallback finishedCallback;
   private final boolean allowExperimental;
   private final byte[] metadata;
   private final MapVersion fallbackMapVersion;
@@ -45,17 +50,19 @@
   protected PartitionMappingSupplierBase(
       RegisterMappingPartitionCallback registerCallback,
       PrepareMappingPartitionsCallback prepareCallback,
+      FinishedPartitionMappingCallback finishedCallback,
       boolean allowExperimental,
       byte[] metadata,
       MapVersion fallbackMapVersion) {
     this.registerCallback = registerCallback;
     this.prepareCallback = prepareCallback;
+    this.finishedCallback = finishedCallback;
     this.allowExperimental = allowExperimental;
     this.metadata = metadata;
     this.fallbackMapVersion = fallbackMapVersion;
   }
 
-  protected MappingPartitionMetadataInternal getMetadata(DiagnosticsHandler diagnosticsHandler) {
+  public MappingPartitionMetadataInternal getMetadata(DiagnosticsHandler diagnosticsHandler) {
     if (mappingPartitionMetadataCache != null) {
       return mappingPartitionMetadataCache;
     }
@@ -96,6 +103,10 @@
         getMetadata(diagnosticsHandler).getMapVersion().toMapVersionMappingInformation());
   }
 
+  public PartitionMappingSupplier getPartitionMappingSupplier() {
+    return null;
+  }
+
   protected RetracerImpl createRetracerFromPartitionSupplier(
       DiagnosticsHandler diagnosticsHandler, MappingPartitionFromKeySupplier partitionSupplier) {
     if (!pendingKeys.isEmpty()) {
diff --git a/src/test/java/com/android/tools/r8/retrace/partition/RetracePartitionAndJoinIdentityTest.java b/src/test/java/com/android/tools/r8/retrace/partition/RetracePartitionAndJoinIdentityTest.java
index 3836c08..f4a9316 100644
--- a/src/test/java/com/android/tools/r8/retrace/partition/RetracePartitionAndJoinIdentityTest.java
+++ b/src/test/java/com/android/tools/r8/retrace/partition/RetracePartitionAndJoinIdentityTest.java
@@ -12,6 +12,7 @@
 import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.ToolHelper;
 import com.android.tools.r8.retrace.MappingPartitionMetadata;
+import com.android.tools.r8.retrace.PartitionMappingSupplier;
 import com.android.tools.r8.retrace.PartitionedToProguardMappingConverter;
 import com.android.tools.r8.retrace.ProguardMapProducer;
 import com.android.tools.r8.retrace.internal.MappingPartitionKeyStrategy;
@@ -61,10 +62,13 @@
 
     StringBuilder builder = new StringBuilder();
     PartitionedToProguardMappingConverter.builder()
-        .setMetadata(metadataData.getBytes())
         .setDiagnosticsHandler(diagnosticsHandler)
+        .setPartitionMappingSupplier(
+            PartitionMappingSupplier.builder()
+                .setMetadata(metadataData.getBytes())
+                .setMappingPartitionFromKeySupplier(partitions::get)
+                .build())
         .setConsumer((string, handler) -> builder.append(string))
-        .setPartitionSupplier(partitions::get)
         .build()
         .run();
     List<String> joinedMapLines = StringUtils.splitLines(builder.toString());