Run dex distribution refinement on startup dex files

Bug: b/473427453
Change-Id: I6d1b09d9a46e7e737d17463d6c8b7ed99d09b430
diff --git a/src/main/java/com/android/tools/r8/dex/distribution/DexDistributionRefinement.java b/src/main/java/com/android/tools/r8/dex/distribution/DexDistributionRefinement.java
index 26e8f1a..8dc85cb 100644
--- a/src/main/java/com/android/tools/r8/dex/distribution/DexDistributionRefinement.java
+++ b/src/main/java/com/android/tools/r8/dex/distribution/DexDistributionRefinement.java
@@ -4,6 +4,7 @@
 package com.android.tools.r8.dex.distribution;
 
 import static com.google.common.base.Predicates.alwaysTrue;
+import static com.google.common.base.Predicates.not;
 
 import com.android.tools.r8.dex.IndexedItemCollection;
 import com.android.tools.r8.dex.VirtualFile;
@@ -24,6 +25,7 @@
 import com.android.tools.r8.utils.SetUtils;
 import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.timing.Timing;
+import com.google.common.base.Predicate;
 import it.unimi.dsi.fastutil.objects.Reference2IntMap;
 import java.util.Comparator;
 import java.util.IdentityHashMap;
@@ -43,7 +45,6 @@
   private final VirtualFileCycler cycler;
   private final boolean enableContainerDex;
   private final LinkedHashSet<VirtualFile> files;
-  private final int numPasses;
   private final LensCodeRewriterUtils rewriter;
 
   // Must be concurrent since we collect concurrently.
@@ -56,15 +57,11 @@
       fileToClassesWithDeterministicOrder = new IdentityHashMap<>();
 
   private DexDistributionRefinement(
-      AppView<?> appView,
-      VirtualFileCycler cycler,
-      List<VirtualFile> filesSubjectToRefinement,
-      int numPasses) {
+      AppView<?> appView, VirtualFileCycler cycler, List<VirtualFile> filesSubjectToRefinement) {
     this.appView = appView;
     this.cycler = cycler;
     this.enableContainerDex = appView.options().enableContainerDex();
     this.files = new LinkedHashSet<>(filesSubjectToRefinement);
-    this.numPasses = numPasses;
     this.rewriter = new LensCodeRewriterUtils(appView, true);
     initialize();
   }
@@ -74,14 +71,25 @@
       throws ExecutionException {
     int numPasses = appView.testing().classToDexDistributionRefinementPasses;
     if (numPasses > 0) {
-      List<VirtualFile> filesSubjectToRefinement =
-          ListUtils.filter(cycler.getFilesForDistribution(), f -> !f.isEmpty() && !f.isStartup());
-      if (filesSubjectToRefinement.size() > 1) {
-        timing.begin("Dex distribution refinement");
-        new DexDistributionRefinement(appView, cycler, filesSubjectToRefinement, numPasses)
-            .internalRun(executorService, timing);
-        timing.end();
-      }
+      runOnPartition(appView, cycler, VirtualFile::isStartup, executorService, timing);
+      runOnPartition(appView, cycler, not(VirtualFile::isStartup), executorService, timing);
+    }
+  }
+
+  private static void runOnPartition(
+      AppView<?> appView,
+      VirtualFileCycler cycler,
+      Predicate<VirtualFile> predicate,
+      ExecutorService executorService,
+      Timing timing)
+      throws ExecutionException {
+    List<VirtualFile> filesSubjectToRefinement =
+        ListUtils.filter(cycler.getFilesForDistribution(), f -> !f.isEmpty() && predicate.test(f));
+    if (filesSubjectToRefinement.size() > 1) {
+      timing.begin("Dex distribution refinement");
+      new DexDistributionRefinement(appView, cycler, filesSubjectToRefinement)
+          .internalRun(executorService, timing);
+      timing.end();
     }
   }
 
@@ -98,7 +106,7 @@
       throws ExecutionException {
     // Run refinement.
     boolean hasEmptyFiles = false;
-    for (int i = 0; i < numPasses; i++) {
+    for (int i = 0; i < appView.testing().classToDexDistributionRefinementPasses; i++) {
       boolean changed = false;
       timing.begin("Pass " + i);
       Iterator<VirtualFile> iterator = files.iterator();