Update handling of data resource in partial shrinking

Fixes: b/387278926
Change-Id: I8f4fe19242c0b0cb9e5c4bb191d9776f6dee9182
diff --git a/src/main/java/com/android/tools/r8/R8Partial.java b/src/main/java/com/android/tools/r8/R8Partial.java
index 2fc67d4..dda3e6e 100644
--- a/src/main/java/com/android/tools/r8/R8Partial.java
+++ b/src/main/java/com/android/tools/r8/R8Partial.java
@@ -6,6 +6,8 @@
 import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
 import static java.nio.charset.StandardCharsets.UTF_8;
 
+import com.android.tools.r8.DexIndexedConsumer.ArchiveConsumer;
+import com.android.tools.r8.DexIndexedConsumer.ForwardingConsumer;
 import com.android.tools.r8.StringConsumer.FileConsumer;
 import com.android.tools.r8.dex.ApplicationReader;
 import com.android.tools.r8.dump.CompilerDump;
@@ -24,7 +26,6 @@
 import com.android.tools.r8.utils.AndroidAppConsumers;
 import com.android.tools.r8.utils.DumpInputFlags;
 import com.android.tools.r8.utils.ExceptionUtils;
-import com.android.tools.r8.utils.FileUtils;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.Timing;
@@ -34,7 +35,9 @@
 import java.io.IOException;
 import java.nio.file.Files;
 import java.nio.file.Path;
+import java.util.ArrayList;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Set;
 import java.util.concurrent.ExecutorService;
 import java.util.function.Consumer;
@@ -74,7 +77,12 @@
     Timing timing = Timing.create("R8 partial " + Version.LABEL, options);
 
     ProgramConsumer originalProgramConsumer = options.programConsumer;
+    DataResourceConsumer originalDataResourceConsumer = options.dataResourceConsumer;
     MapConsumer originalMapConsumer = options.mapConsumer;
+    if (!(originalProgramConsumer instanceof DexIndexedConsumer)) {
+      throw options.reporter.fatalError(
+          "Partial shrinking does not support generating class files");
+    }
 
     Path tmp = options.partialCompilationConfiguration.getTempDir();
     Path dumpFile = options.partialCompilationConfiguration.getDumpFile();
@@ -111,7 +119,7 @@
         || dump.getBuildProperties().hasArtProfileProviders()
         || dump.getBuildProperties().hasStartupProfileProviders()) {
       throw options.reporter.fatalError(
-          "Split compilation does not support legacy multi-dex, baseline or startup profiles");
+          "Partial shrinking does not support legacy multi-dex, baseline or startup profiles");
     }
 
     DexApplication dapp = applicationReader.read().toDirect();
@@ -210,8 +218,38 @@
     TraceReferencesCommand tr = TraceReferencesBridge.makeCommand(trBuilder);
     TraceReferencesBridge.runInternal(tr);
 
+    class R8DataResources implements DataResourceConsumer {
+      final List<DataResource> dataResources = new ArrayList<>();
+
+      @Override
+      public void accept(DataDirectoryResource directory, DiagnosticsHandler diagnosticsHandler) {
+        dataResources.add(directory);
+      }
+
+      @Override
+      public void accept(DataEntryResource file, DiagnosticsHandler diagnosticsHandler) {
+        dataResources.add(file);
+      }
+
+      @Override
+      public void finished(DiagnosticsHandler handler) {}
+
+      public void feed(DataResourceConsumer consumer, DiagnosticsHandler handler) {
+        dataResources.forEach(
+            dataResource -> {
+              if (dataResource instanceof DataDirectoryResource) {
+                consumer.accept((DataDirectoryResource) dataResource, handler);
+              } else {
+                assert dataResource instanceof DataEntryResource;
+                consumer.accept((DataEntryResource) dataResource, handler);
+              }
+            });
+      }
+    }
+
     // Compile R8 input with R8 using the keep rules from trace references.
     Path r8Output = tmp.resolve("r8-output.zip");
+    R8DataResources r8DataResources = new R8DataResources();
     R8Command.Builder r8Builder =
         R8Command.builder()
             .setMinApiLevel(dump.getBuildProperties().getMinApi())
@@ -222,7 +260,13 @@
             .addProguardConfigurationFiles(dump.getProguardConfigFile(), traceReferencesRules)
             .enableLegacyFullModeForKeepRules(true)
             .setMode(dump.getBuildProperties().getCompilationMode())
-            .setOutput(r8Output, OutputMode.DexIndexed);
+            .setProgramConsumer(
+                new ForwardingConsumer(new ArchiveConsumer(r8Output)) {
+                  @Override
+                  public DataResourceConsumer getDataResourceConsumer() {
+                    return r8DataResources;
+                  }
+                });
     if (dump.hasDesugaredLibrary()) {
       r8Builder.addDesugaredLibraryConfiguration(
           Files.readString(dump.getDesugaredLibraryFile(), UTF_8));
@@ -245,20 +289,6 @@
       r8OutputAppConsumer.accept(r8OutputAppSink.build());
     }
 
-    // Emit resources and merged DEX to the output consumer.
-    // TODO(b/309743298): Consider passing the DataResourceConsumer to the R8 invocation above.
-    DataResourceConsumer dataResourceConsumer = originalProgramConsumer.getDataResourceConsumer();
-    if (dataResourceConsumer != null) {
-      ZipUtils.iterWithZipFile(
-          r8Output,
-          (zip, entry) -> {
-            if (entry.getName().endsWith(FileUtils.DEX_EXTENSION)) {
-              return;
-            }
-            dataResourceConsumer.accept(
-                DataEntryResource.fromZip(zip, entry), new DiagnosticsHandler() {});
-          });
-    }
     // TODO(b/309743298): Handle jumbo string rewriting with PCs in mapping file.
     D8Command.Builder mergerBuilder =
         D8Command.builder()
@@ -273,6 +303,12 @@
     AndroidApp mergeApp = mergeCommand.getInputApp();
     InternalOptions mergeOptions = mergeCommand.getInternalOptions();
     D8.runInternal(mergeApp, mergeOptions, executor);
+    // Feed the data resource output by R8 to the output consumer. Keeping this at the end after the
+    // merge keeps the order of calls to the output consumer closer to full R8.
+    if (originalDataResourceConsumer != null) {
+      r8DataResources.feed(originalDataResourceConsumer, options.reporter);
+      originalDataResourceConsumer.finished(options.reporter);
+    }
     timing.end();
   }
 }
diff --git a/src/test/java/com/android/tools/r8/partial/PartialCompilationDataResourcesTest.java b/src/test/java/com/android/tools/r8/partial/PartialCompilationDataResourcesTest.java
new file mode 100644
index 0000000..1cd7ceb
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/partial/PartialCompilationDataResourcesTest.java
@@ -0,0 +1,147 @@
+// Copyright (c) 2025, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.partial;
+
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.ByteDataView;
+import com.android.tools.r8.DataDirectoryResource;
+import com.android.tools.r8.DataEntryResource;
+import com.android.tools.r8.DataEntryResource.ByteDataEntryResource;
+import com.android.tools.r8.DataResourceConsumer;
+import com.android.tools.r8.DexIndexedConsumer;
+import com.android.tools.r8.DiagnosticsHandler;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.ToolHelper.DexVm;
+import com.android.tools.r8.origin.Origin;
+import com.android.tools.r8.utils.AndroidApiLevel;
+import java.io.InputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.Set;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class PartialCompilationDataResourcesTest extends TestBase {
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  // Test with min API level 24.
+  public static TestParametersCollection data() {
+    return getTestParameters()
+        .withDexRuntime(DexVm.Version.V7_0_0)
+        .withApiLevel(AndroidApiLevel.N)
+        .build();
+  }
+
+  static class DataResources implements DataResourceConsumer {
+    int dataEntryResources = 0;
+    int dataDirectoryResources = 0;
+    public int finished = 0;
+
+    @Override
+    public void accept(DataDirectoryResource directory, DiagnosticsHandler diagnosticsHandler) {
+      dataDirectoryResources++;
+    }
+
+    @Override
+    public void accept(DataEntryResource file, DiagnosticsHandler diagnosticsHandler) {
+      dataEntryResources++;
+    }
+
+    @Override
+    public void finished(DiagnosticsHandler handler) {
+      finished++;
+    }
+  }
+
+  static class ProgramConsumer implements DexIndexedConsumer {
+    public int dexFiles = 0;
+    public int finished = 0;
+    public DataResources dataResources = new DataResources();
+
+    @Override
+    public DataResourceConsumer getDataResourceConsumer() {
+      return dataResources;
+    }
+
+    @Override
+    public void finished(DiagnosticsHandler handler) {
+      finished++;
+    }
+
+    @Override
+    public void accept(
+        int fileIndex, ByteDataView data, Set<String> descriptors, DiagnosticsHandler handler) {
+      dexFiles++;
+    }
+  }
+
+  @Test
+  public void testProgramConsumer() throws Exception {
+    ProgramConsumer programConsumer = new ProgramConsumer();
+    testForR8Partial(parameters.getBackend())
+        .setMinApi(parameters)
+        .addProgramClasses(A.class, B.class, Main.class)
+        .addDataResources(
+            DataEntryResource.fromBytes(new byte[] {0}, "1", Origin.unknown()),
+            DataEntryResource.fromBytes(new byte[] {1}, "2", Origin.unknown()),
+            DataEntryResource.fromBytes(new byte[] {2}, "3", Origin.unknown()),
+            DataDirectoryResource.fromName("A/", Origin.unknown()),
+            DataDirectoryResource.fromName("B/", Origin.unknown()))
+        .addKeepMainRule(Main.class)
+        .addKeepRules("-keepdirectories")
+        .setR8PartialConfiguration(builder -> builder.includeAll().excludeClasses(A.class))
+        .setProgramConsumer(programConsumer)
+        .compile();
+    assertEquals(1, programConsumer.dexFiles);
+    assertEquals(1, programConsumer.finished);
+    assertEquals(2, programConsumer.dataResources.dataDirectoryResources);
+    assertEquals(3, programConsumer.dataResources.dataEntryResources);
+    assertEquals(1, programConsumer.dataResources.finished);
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForR8Partial(parameters.getBackend())
+        .setMinApi(parameters)
+        .addProgramClasses(A.class, B.class, Main.class)
+        .addDataResources(
+            new ByteDataEntryResource(
+                "Hello, world!".getBytes(StandardCharsets.UTF_8),
+                "data_resource.txt",
+                Origin.unknown()))
+        .addKeepMainRule(Main.class)
+        .setR8PartialConfiguration(builder -> builder.includeAll().excludeClasses(A.class))
+        .compile()
+        .run(parameters.getRuntime(), Main.class, getClass().getTypeName())
+        .assertSuccessWithOutputLines("Hello, world!");
+  }
+
+  public static class A {}
+
+  public static class B {}
+
+  public static class Main {
+
+    public static void main(String[] args) throws Exception {
+      byte[] buffer = new byte[1024];
+      int offset = 0;
+      int read;
+      try (InputStream is = Main.class.getClassLoader().getResourceAsStream("data_resource.txt")) {
+        while ((read = is.read(buffer, offset, 1024 - offset)) != -1) {
+          offset += read;
+        }
+      }
+      System.out.println(new String(buffer, 0, offset, StandardCharsets.UTF_8));
+    }
+  }
+}