Write feature splits to container DEX when enabled

Bug: b/249922554
Change-Id: I6baa615c801e288e25b2e9dc758f9428661fb302
diff --git a/src/main/java/com/android/tools/r8/dex/ApplicationWriterExperimental.java b/src/main/java/com/android/tools/r8/dex/ApplicationWriterExperimental.java
index 146f858..84c1f5f 100644
--- a/src/main/java/com/android/tools/r8/dex/ApplicationWriterExperimental.java
+++ b/src/main/java/com/android/tools/r8/dex/ApplicationWriterExperimental.java
@@ -4,11 +4,13 @@
 package com.android.tools.r8.dex;
 
 import static com.android.tools.r8.utils.DexVersion.Layout.CONTAINER_DEX;
+import static com.android.tools.r8.utils.MapUtils.ignoreKey;
 
 import com.android.tools.r8.ByteBufferProvider;
 import com.android.tools.r8.ByteDataView;
 import com.android.tools.r8.DexFilePerClassFileConsumer;
 import com.android.tools.r8.DexIndexedConsumer;
+import com.android.tools.r8.FeatureSplit;
 import com.android.tools.r8.ProgramConsumer;
 import com.android.tools.r8.debuginfo.DebugRepresentation;
 import com.android.tools.r8.dex.FileWriter.ByteBufferResult;
@@ -26,7 +28,9 @@
 import com.google.common.collect.Sets;
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 
@@ -101,16 +105,18 @@
       Timing timing) {
     TimingMerger merger = timing.beginMerger("Write files", executorService);
     Collection<Timing> timings;
-    // Only write to container for OutputMode DexIndexed and only for base feature.
-    List<VirtualFile> virtualFilesForContainer = new ArrayList<>();
+    Map<FeatureSplit, List<VirtualFile>> virtualFilesForContainers = new HashMap<>();
     List<VirtualFile> virtualFilesOutsideContainer = new ArrayList<>();
     virtualFiles.forEach(
         virtualFile -> {
-          if (virtualFile.getPrimaryClassDescriptor() != null
-              || virtualFile.getFeatureSplit() != null) {
+          // Only use container format for OutputMode DexIndexed. Create one container per feature.
+          if (virtualFile.getPrimaryClassDescriptor() != null) {
             virtualFilesOutsideContainer.add(virtualFile);
           } else {
-            virtualFilesForContainer.add(virtualFile);
+            FeatureSplit split = virtualFile.getFeatureSplitOrBase();
+            virtualFilesForContainers
+                .computeIfAbsent(split, ignoreKey(ArrayList::new))
+                .add(virtualFile);
           }
         });
     virtualFiles = null;
@@ -124,84 +130,99 @@
       timings.add(fileTiming);
     }
     // Write container virtual files.
-    if (!virtualFilesForContainer.isEmpty()) {
-      ProgramConsumer consumer = options.getDexIndexedConsumer();
-      ByteBufferProvider byteBufferProvider = options.getDexIndexedConsumer();
-      DexOutputBuffer dexOutputBuffer = new DexOutputBuffer(byteBufferProvider);
-      byte[] tempForAssertions = null;
+    virtualFilesForContainers.forEach(
+        (split, virtualFilesForContainer) ->
+            writeContainer(forcedStrings, timings, virtualFilesForContainer));
 
-      int offset = 0;
-      List<DexContainerSection> sections = new ArrayList<>();
-
-      for (int i = 0; i < virtualFilesForContainer.size(); i++) {
-        VirtualFile virtualFile = virtualFilesForContainer.get(i);
-        Timing fileTiming = Timing.create("VirtualFile " + virtualFile.getId(), options);
-        if (virtualFile.isEmpty()) {
-          continue;
-        }
-        DexContainerSection section =
-            writeVirtualFileSection(
-                virtualFile,
-                fileTiming,
-                forcedStrings,
-                offset,
-                dexOutputBuffer,
-                i == virtualFilesForContainer.size() - 1);
-
-        if (InternalOptions.assertionsEnabled()) {
-          // Check that writing did not modify already written sections.
-          assert offset == 0 || tempForAssertions != null;
-          byte[] outputSoFar = dexOutputBuffer.asArray();
-          for (int j = 0; j < offset; j++) {
-            assert tempForAssertions[j] == outputSoFar[j];
-          }
-          // Copy written sections including the one just written
-          tempForAssertions = new byte[section.getLayout().getEndOfFile()];
-          for (int j = 0; j < section.getLayout().getEndOfFile(); j++) {
-            tempForAssertions[j] = outputSoFar[j];
-          }
-        }
-
-        offset = section.getLayout().getEndOfFile();
-        assert BitUtils.isAligned(4, offset);
-        sections.add(section);
-        fileTiming.end();
-        timings.add(fileTiming);
-      }
-
-      if (globalsSyntheticsConsumer != null) {
-        globalsSyntheticsConsumer.finished(appView);
-      } else if (options.hasGlobalSyntheticsConsumer()) {
-        // Make sure to also call finished even if no global output was generated.
-        options.getGlobalSyntheticsConsumer().finished(appView.reporter());
-      }
-
-      if (sections.isEmpty()) {
-        merger.add(timings);
-        merger.end();
-        return;
-      }
-
-      updateStringIdsSizeAndOffset(dexOutputBuffer, sections);
-
-      ByteBufferResult result =
-          new ByteBufferResult(
-              dexOutputBuffer.stealByteBuffer(),
-              sections.get(sections.size() - 1).getLayout().getEndOfFile());
-      ByteDataView data =
-          new ByteDataView(result.buffer.array(), result.buffer.arrayOffset(), result.length);
-      // TODO(b/249922554): Add timing of passing to consumer.
-      if (consumer instanceof DexFilePerClassFileConsumer) {
-        assert false;
-      } else {
-        ((DexIndexedConsumer) consumer)
-            .accept(0, data, Sets.newIdentityHashSet(), options.reporter);
-      }
-    }
     merger.add(timings);
     merger.end();
   }
 
+  private void writeContainer(
+      List<DexString> forcedStrings, Collection<Timing> timings, List<VirtualFile> virtualFiles) {
+    ProgramConsumer consumer;
+    ByteBufferProvider byteBufferProvider;
+    if (programConsumer != null) {
+      consumer = programConsumer;
+      byteBufferProvider = programConsumer;
+    } else if (virtualFiles.get(0).getFeatureSplit() != null) {
+      ProgramConsumer featureConsumer = virtualFiles.get(0).getFeatureSplit().getProgramConsumer();
+      assert featureConsumer instanceof DexIndexedConsumer;
+      consumer = featureConsumer;
+      byteBufferProvider = (DexIndexedConsumer) featureConsumer;
+    } else {
+      consumer = options.getDexIndexedConsumer();
+      byteBufferProvider = options.getDexIndexedConsumer();
+    }
+
+    DexOutputBuffer dexOutputBuffer = new DexOutputBuffer(byteBufferProvider);
+    byte[] tempForAssertions = new byte[] {};
+
+    int offset = 0;
+    List<DexContainerSection> sections = new ArrayList<>();
+
+    for (int i = 0; i < virtualFiles.size(); i++) {
+      VirtualFile virtualFile = virtualFiles.get(i);
+      Timing fileTiming = Timing.create("VirtualFile " + virtualFile.getId(), options);
+      if (virtualFile.isEmpty()) {
+        continue;
+      }
+      DexContainerSection section =
+          writeVirtualFileSection(
+              virtualFile,
+              fileTiming,
+              forcedStrings,
+              offset,
+              dexOutputBuffer,
+              i == virtualFiles.size() - 1);
+
+      if (InternalOptions.assertionsEnabled()) {
+        // Check that writing did not modify already written sections.
+        byte[] outputSoFar = dexOutputBuffer.asArray();
+        for (int j = 0; j < offset; j++) {
+          assert tempForAssertions[j] == outputSoFar[j];
+        }
+        // Copy written sections including the one just written
+        tempForAssertions = new byte[section.getLayout().getEndOfFile()];
+        for (int j = 0; j < section.getLayout().getEndOfFile(); j++) {
+          tempForAssertions[j] = outputSoFar[j];
+        }
+      }
+
+      offset = section.getLayout().getEndOfFile();
+      assert BitUtils.isAligned(4, offset);
+      sections.add(section);
+      fileTiming.end();
+      timings.add(fileTiming);
+    }
+
+    if (globalsSyntheticsConsumer != null) {
+      globalsSyntheticsConsumer.finished(appView);
+    } else if (options.hasGlobalSyntheticsConsumer()) {
+      // Make sure to also call finished even if no global output was generated.
+      options.getGlobalSyntheticsConsumer().finished(appView.reporter());
+    }
+
+    if (sections.isEmpty()) {
+      return;
+    }
+
+    updateStringIdsSizeAndOffset(dexOutputBuffer, sections);
+
+    ByteBufferResult result =
+        new ByteBufferResult(
+            dexOutputBuffer.stealByteBuffer(),
+            sections.get(sections.size() - 1).getLayout().getEndOfFile());
+    ByteDataView data =
+        new ByteDataView(result.buffer.array(), result.buffer.arrayOffset(), result.length);
+    // TODO(b/249922554): Add timing of passing to consumer.
+    if (consumer instanceof DexFilePerClassFileConsumer) {
+      assert false;
+    } else {
+      ((DexIndexedConsumer) consumer).accept(0, data, Sets.newIdentityHashSet(), options.reporter);
+    }
+  }
+
   private void updateStringIdsSizeAndOffset(
       DexOutputBuffer dexOutputBuffer, List<DexContainerSection> sections) {
     // The last section has the shared string_ids table. Now it is written the final size and
diff --git a/src/test/java/com/android/tools/r8/dex/container/DexContainerFormatFeatureSplitTest.java b/src/test/java/com/android/tools/r8/dex/container/DexContainerFormatFeatureSplitTest.java
index 00490db..c7cef0f 100644
--- a/src/test/java/com/android/tools/r8/dex/container/DexContainerFormatFeatureSplitTest.java
+++ b/src/test/java/com/android/tools/r8/dex/container/DexContainerFormatFeatureSplitTest.java
@@ -5,9 +5,7 @@
 
 import com.android.tools.r8.R8TestCompileResult;
 import com.android.tools.r8.TestParameters;
-import com.android.tools.r8.utils.AndroidApiLevel;
 import com.android.tools.r8.utils.BooleanUtils;
-import com.android.tools.r8.utils.DexVersion;
 import com.android.tools.r8.utils.InternalOptions;
 import java.nio.file.Path;
 import java.util.List;
@@ -69,19 +67,7 @@
     Path feature2Path = result.getFeature(1);
 
     validateDex(basePath, 1, InternalOptions.containerDexApiLevel().getDexVersion());
-    validateDex(
-        feature1Path,
-        2,
-        // For container DEX API levels non container output use the highest non container format.
-        useContainerDexApiLevel
-            ? DexVersion.getDexVersion(AndroidApiLevel.V)
-            : DexVersion.getDexVersion(AndroidApiLevel.L));
-    validateDex(
-        feature2Path,
-        2,
-        // For container DEX API levels non container output use the highest non container format.
-        useContainerDexApiLevel
-            ? DexVersion.getDexVersion(AndroidApiLevel.V)
-            : DexVersion.getDexVersion(AndroidApiLevel.L));
+    validateDex(feature1Path, 1, InternalOptions.containerDexApiLevel().getDexVersion());
+    validateDex(feature2Path, 1, InternalOptions.containerDexApiLevel().getDexVersion());
   }
 }