Parallelize cf application writer

Fixes: b/393534012
Change-Id: Ibc399222cf6089bf7bcdd67d2901c3d8e63cd1df
diff --git a/src/main/java/com/android/tools/r8/jar/CfApplicationClassWriter.java b/src/main/java/com/android/tools/r8/jar/CfApplicationClassWriter.java
index bb33502..0cb18fd 100644
--- a/src/main/java/com/android/tools/r8/jar/CfApplicationClassWriter.java
+++ b/src/main/java/com/android/tools/r8/jar/CfApplicationClassWriter.java
@@ -5,8 +5,6 @@
 
 import static com.android.tools.r8.utils.InternalOptions.ASM_VERSION;
 
-import com.android.tools.r8.ByteDataView;
-import com.android.tools.r8.ClassFileConsumer;
 import com.android.tools.r8.SourceFileEnvironment;
 import com.android.tools.r8.cf.CfVersion;
 import com.android.tools.r8.errors.CodeSizeOverflowDiagnostic;
@@ -45,7 +43,6 @@
 import com.android.tools.r8.synthesis.SyntheticNaming;
 import com.android.tools.r8.utils.AsmUtils;
 import com.android.tools.r8.utils.ComparatorUtils;
-import com.android.tools.r8.utils.ExceptionUtils;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.structural.Ordered;
 import com.google.common.collect.ImmutableMap;
@@ -103,14 +100,13 @@
     return !appView.appInfo().hasDefinitionForWithoutExistenceAssert(type);
   }
 
-  void writeClassCatchingErrors(
-      ClassFileConsumer consumer,
+  Result writeClassCatchingErrors(
       LensCodeRewriterUtils rewriter,
       Optional<String> markerString,
       SourceFileEnvironment sourceFileEnvironment) {
     assert SyntheticNaming.verifyNotInternalSynthetic(clazz.getType());
     try {
-      writeClass(consumer, rewriter, markerString, sourceFileEnvironment);
+      return writeClass(rewriter, markerString, sourceFileEnvironment);
     } catch (ClassTooLargeException e) {
       throw options.reporter.fatalError(
           new ConstantPoolOverflowDiagnostic(
@@ -129,8 +125,7 @@
     }
   }
 
-  private void writeClass(
-      ClassFileConsumer consumer,
+  private Result writeClass(
       LensCodeRewriterUtils rewriter,
       Optional<String> markerString,
       SourceFileEnvironment sourceFileEnvironment) {
@@ -234,8 +229,7 @@
       // so don't assert that verifyCf() returns true.
       verifyCf(result);
     }
-    ExceptionUtils.withConsumeResourceHandler(
-        options.reporter, handler -> consumer.accept(ByteDataView.of(result), desc, handler));
+    return new Result(desc, result);
   }
 
   private String getSourceDebugExtension(DexAnnotationSet annotations) {
@@ -598,4 +592,23 @@
     PrintWriter pw = new PrintWriter(System.out);
     CheckClassAdapter.verify(reader, false, pw);
   }
+
+  public static class Result {
+
+    private final String descriptor;
+    private final byte[] classFileData;
+
+    Result(String descriptor, byte[] classFileData) {
+      this.descriptor = descriptor;
+      this.classFileData = classFileData;
+    }
+
+    public String getDescriptor() {
+      return descriptor;
+    }
+
+    public byte[] getClassFileData() {
+      return classFileData;
+    }
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/jar/CfApplicationWriter.java b/src/main/java/com/android/tools/r8/jar/CfApplicationWriter.java
index 70b8e56..6898d97 100644
--- a/src/main/java/com/android/tools/r8/jar/CfApplicationWriter.java
+++ b/src/main/java/com/android/tools/r8/jar/CfApplicationWriter.java
@@ -5,6 +5,7 @@
 
 import static com.android.tools.r8.utils.positions.LineNumberOptimizer.runAndWriteMap;
 
+import com.android.tools.r8.ByteDataView;
 import com.android.tools.r8.ClassFileConsumer;
 import com.android.tools.r8.SourceFileEnvironment;
 import com.android.tools.r8.debuginfo.DebugRepresentation;
@@ -15,7 +16,9 @@
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.ir.conversion.LensCodeRewriterUtils;
 import com.android.tools.r8.naming.ProguardMapSupplier.ProguardMapId;
+import com.android.tools.r8.threading.TaskCollection;
 import com.android.tools.r8.utils.AndroidApp;
+import com.android.tools.r8.utils.ExceptionUtils;
 import com.android.tools.r8.utils.InternalGlobalSyntheticsProgramConsumer.InternalGlobalSyntheticsCfConsumer;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.OriginalSourceFiles;
@@ -23,6 +26,7 @@
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Optional;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.function.Consumer;
 
@@ -40,13 +44,15 @@
     this.marker = Optional.ofNullable(marker);
   }
 
-  public void write(ClassFileConsumer consumer, ExecutorService executorService) {
+  public void write(ClassFileConsumer consumer, ExecutorService executorService)
+      throws ExecutionException {
     assert !options.hasMappingFileSupport();
     write(consumer, executorService, null);
   }
 
   public void write(
-      ClassFileConsumer consumer, ExecutorService executorService, AndroidApp inputApp) {
+      ClassFileConsumer consumer, ExecutorService executorService, AndroidApp inputApp)
+      throws ExecutionException {
     application.timing.begin("CfApplicationWriter.write");
     try {
       writeApplication(inputApp, consumer, executorService);
@@ -67,7 +73,8 @@
   }
 
   private void writeApplication(
-      AndroidApp inputApp, ClassFileConsumer consumer, ExecutorService executorService) {
+      AndroidApp inputApp, ClassFileConsumer consumer, ExecutorService executorService)
+      throws ExecutionException {
     ProguardMapId proguardMapId = null;
     if (options.hasMappingFileSupport()) {
       assert marker.isPresent();
@@ -104,20 +111,46 @@
         }
       }
     }
-    for (DexProgramClass clazz : classes) {
-      new CfApplicationClassWriter(appView, clazz)
-          .writeClassCatchingErrors(consumer, rewriter, markerString, sourceFileEnvironment);
-    }
+    supplyConsumer(
+        consumer, classes, markerString, rewriter, sourceFileEnvironment, executorService);
     if (!globalSyntheticClasses.isEmpty()) {
       InternalGlobalSyntheticsCfConsumer globalsConsumer =
           new InternalGlobalSyntheticsCfConsumer(options.getGlobalSyntheticsConsumer(), appView);
-      for (DexProgramClass clazz : globalSyntheticClasses) {
-        new CfApplicationClassWriter(appView, clazz)
-            .writeClassCatchingErrors(
-                globalsConsumer, rewriter, markerString, sourceFileEnvironment);
-      }
+      supplyConsumer(
+          globalsConsumer,
+          globalSyntheticClasses,
+          markerString,
+          rewriter,
+          sourceFileEnvironment,
+          executorService);
       globalsConsumer.finished(appView);
     }
     ApplicationWriter.supplyAdditionalConsumers(appView, executorService, Collections.emptyList());
   }
+
+  private void supplyConsumer(
+      ClassFileConsumer consumer,
+      Collection<DexProgramClass> classes,
+      Optional<String> markerString,
+      LensCodeRewriterUtils rewriter,
+      SourceFileEnvironment sourceFileEnvironment,
+      ExecutorService executorService)
+      throws ExecutionException {
+    TaskCollection<CfApplicationClassWriter.Result> taskCollection =
+        new TaskCollection<>(options.getThreadingModule(), executorService, classes.size());
+    taskCollection.stream(
+        classes,
+        clazz ->
+            new CfApplicationClassWriter(appView, clazz)
+                .writeClassCatchingErrors(rewriter, markerString, sourceFileEnvironment),
+        result -> supplyConsumer(consumer, result));
+  }
+
+  private void supplyConsumer(ClassFileConsumer consumer, CfApplicationClassWriter.Result result) {
+    ExceptionUtils.withConsumeResourceHandler(
+        options.reporter,
+        handler ->
+            consumer.accept(
+                ByteDataView.of(result.getClassFileData()), result.getDescriptor(), handler));
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/threading/TaskCollection.java b/src/main/java/com/android/tools/r8/threading/TaskCollection.java
index 2316ec9..3d058b5 100644
--- a/src/main/java/com/android/tools/r8/threading/TaskCollection.java
+++ b/src/main/java/com/android/tools/r8/threading/TaskCollection.java
@@ -9,12 +9,14 @@
 import com.android.tools.r8.utils.UncheckedExecutionException;
 import com.google.common.util.concurrent.Futures;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.List;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
 import java.util.function.Consumer;
+import java.util.function.Function;
 import java.util.function.Predicate;
 
 public class TaskCollection<T> {
@@ -43,6 +45,24 @@
     this(options.getThreadingModule(), executorService, initialCapacity);
   }
 
+  public <S> void stream(Collection<S> items, Function<S, T> fn, Consumer<T> consumer)
+      throws ExecutionException {
+    if (threadingModule.isSingleThreaded()) {
+      for (S item : items) {
+        try {
+          consumer.accept(fn.apply(item));
+        } catch (Exception e) {
+          throw new ExecutionException(e);
+        }
+      }
+    } else {
+      for (S item : items) {
+        submit(() -> fn.apply(item));
+      }
+      forEach(consumer);
+    }
+  }
+
   /**
    * Submit a task for execution.
    *
@@ -74,6 +94,11 @@
     futures.clear();
   }
 
+  public void forEach(Consumer<T> consumer) throws ExecutionException {
+    threadingModule.forEach(futures, consumer);
+    futures.clear();
+  }
+
   // Final helper methods for the collection.
 
   /** Number of current tasks in the collection. */
diff --git a/src/main/java/com/android/tools/r8/threading/ThreadingModule.java b/src/main/java/com/android/tools/r8/threading/ThreadingModule.java
index 2ec6382..cdf937b 100644
--- a/src/main/java/com/android/tools/r8/threading/ThreadingModule.java
+++ b/src/main/java/com/android/tools/r8/threading/ThreadingModule.java
@@ -17,6 +17,7 @@
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
+import java.util.function.Consumer;
 
 /**
  * Threading module interface to enable non-blocking usage of R8.
@@ -35,10 +36,14 @@
 
   ExecutorService createThreadedExecutorService(int threadCount);
 
+  boolean isSingleThreaded();
+
   <T> Future<T> submit(Callable<T> task, ExecutorService executorService) throws ExecutionException;
 
   <T> void awaitFutures(List<Future<T>> futures) throws ExecutionException;
 
+  <T> void forEach(List<Future<T>> futures, Consumer<T> consumer) throws ExecutionException;
+
   class Loader {
 
     @UsedByReflection(
diff --git a/src/main/java/com/android/tools/r8/threading/providers/blocking/ThreadingModuleBlocking.java b/src/main/java/com/android/tools/r8/threading/providers/blocking/ThreadingModuleBlocking.java
index 6445817..18424ba 100644
--- a/src/main/java/com/android/tools/r8/threading/providers/blocking/ThreadingModuleBlocking.java
+++ b/src/main/java/com/android/tools/r8/threading/providers/blocking/ThreadingModuleBlocking.java
@@ -12,6 +12,7 @@
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
+import java.util.function.Consumer;
 
 public class ThreadingModuleBlocking implements ThreadingModule {
 
@@ -28,16 +29,26 @@
   }
 
   @Override
+  public boolean isSingleThreaded() {
+    return false;
+  }
+
+  @Override
   public <T> Future<T> submit(Callable<T> task, ExecutorService executorService) {
     return executorService.submit(task);
   }
 
   @Override
   public <T> void awaitFutures(List<Future<T>> futures) throws ExecutionException {
-    Iterator<? extends Future<?>> it = futures.iterator();
+    forEach(futures, ignore -> {});
+  }
+
+  @Override
+  public <T> void forEach(List<Future<T>> futures, Consumer<T> consumer) throws ExecutionException {
+    Iterator<Future<T>> it = futures.iterator();
     try {
       while (it.hasNext()) {
-        it.next().get();
+        consumer.accept(it.next().get());
       }
     } catch (InterruptedException e) {
       throw new RuntimeException("Interrupted while waiting for future.", e);
diff --git a/src/main/java/com/android/tools/r8/threading/providers/singlethreaded/ThreadingModuleSingleThreaded.java b/src/main/java/com/android/tools/r8/threading/providers/singlethreaded/ThreadingModuleSingleThreaded.java
index 2afc6db..a351810 100644
--- a/src/main/java/com/android/tools/r8/threading/providers/singlethreaded/ThreadingModuleSingleThreaded.java
+++ b/src/main/java/com/android/tools/r8/threading/providers/singlethreaded/ThreadingModuleSingleThreaded.java
@@ -12,6 +12,7 @@
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
+import java.util.function.Consumer;
 
 public class ThreadingModuleSingleThreaded implements ThreadingModule {
 
@@ -28,6 +29,11 @@
   }
 
   @Override
+  public boolean isSingleThreaded() {
+    return true;
+  }
+
+  @Override
   public <T> Future<T> submit(Callable<T> task, ExecutorService executorService)
       throws ExecutionException {
     try {
@@ -43,6 +49,18 @@
     assert allDone(futures);
   }
 
+  @Override
+  public <T> void forEach(List<Future<T>> futures, Consumer<T> consumer) throws ExecutionException {
+    assert allDone(futures);
+    for (Future<T> future : futures) {
+      try {
+        consumer.accept(future.get());
+      } catch (InterruptedException e) {
+        throw new RuntimeException("Interrupted while waiting for future.", e);
+      }
+    }
+  }
+
   private <T> boolean allDone(List<Future<T>> futures) {
     for (Future<?> future : futures) {
       if (!future.isDone()) {