Class by number of startup methods multi dex distribution strategy

Bug: b/298617875
Change-Id: I321aaad191190380cdb08e7d34c1d4aec5b87c82
diff --git a/src/main/java/com/android/tools/r8/dex/VirtualFile.java b/src/main/java/com/android/tools/r8/dex/VirtualFile.java
index 75c4f9d..2d87bd2 100644
--- a/src/main/java/com/android/tools/r8/dex/VirtualFile.java
+++ b/src/main/java/com/android/tools/r8/dex/VirtualFile.java
@@ -1467,7 +1467,8 @@
         virtualFile.abortTransaction();
 
         // If the above failed, then apply the selected multi startup dex distribution strategy.
-        MultiStartupDexDistributor distributor = MultiStartupDexDistributor.get(options);
+        MultiStartupDexDistributor distributor =
+            MultiStartupDexDistributor.get(options, startupProfile);
         distributor.distribute(classPartioning.getStartupClasses(), this, virtualFile, cycler);
 
         options.reporter.warning(
diff --git a/src/main/java/com/android/tools/r8/profile/startup/distribution/MultiStartupDexDistributor.java b/src/main/java/com/android/tools/r8/profile/startup/distribution/MultiStartupDexDistributor.java
index 6a3132c..a775547 100644
--- a/src/main/java/com/android/tools/r8/profile/startup/distribution/MultiStartupDexDistributor.java
+++ b/src/main/java/com/android/tools/r8/profile/startup/distribution/MultiStartupDexDistributor.java
@@ -8,40 +8,99 @@
 import com.android.tools.r8.dex.VirtualFile.PackageSplitPopulator;
 import com.android.tools.r8.dex.VirtualFile.VirtualFileCycler;
 import com.android.tools.r8.errors.Unimplemented;
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.profile.startup.profile.StartupProfile;
 import com.android.tools.r8.utils.InternalOptions;
+import it.unimi.dsi.fastutil.objects.Reference2IntMap;
+import it.unimi.dsi.fastutil.objects.Reference2IntOpenHashMap;
+import java.util.ArrayList;
+import java.util.Comparator;
 import java.util.List;
+import java.util.function.ToIntFunction;
 
 public abstract class MultiStartupDexDistributor {
 
+  StartupProfile startupProfile;
+
+  MultiStartupDexDistributor(StartupProfile startupProfile) {
+    this.startupProfile = startupProfile;
+  }
+
   public abstract void distribute(
       List<DexProgramClass> classes,
       PackageSplitPopulator packageSplitPopulator,
       VirtualFile virtualFile,
       VirtualFileCycler virtualFileCycler);
 
+  void distributeInOrder(
+      List<DexProgramClass> classes, VirtualFile virtualFile, VirtualFileCycler virtualFileCycler) {
+    // Add the startup classes one by one.
+    for (DexProgramClass startupClass : classes) {
+      virtualFile.addClass(startupClass);
+      if (hasSpaceForTransaction(virtualFile)) {
+        virtualFile.commitTransaction();
+      } else {
+        virtualFile.abortTransaction();
+        virtualFile = virtualFileCycler.addFile();
+        virtualFile.addClass(startupClass);
+        assert hasSpaceForTransaction(virtualFile);
+        virtualFile.commitTransaction();
+      }
+    }
+  }
+
   boolean hasSpaceForTransaction(VirtualFile virtualFile) {
     return !virtualFile.isFull();
   }
 
-  public static MultiStartupDexDistributor getDefault() {
-    return new ClassByNameMultiStartupDexDistributor();
+  Reference2IntMap<DexProgramClass> computeClassMetrics(
+      List<DexProgramClass> classes, ToIntFunction<DexProgramClass> fn) {
+    Reference2IntMap<DexProgramClass> result = new Reference2IntOpenHashMap<>();
+    result.defaultReturnValue(0);
+    for (DexProgramClass clazz : classes) {
+      result.put(clazz, fn.applyAsInt(clazz));
+    }
+    return result;
   }
 
-  public static MultiStartupDexDistributor get(InternalOptions options) {
+  public static MultiStartupDexDistributor getDefault(StartupProfile startupProfile) {
+    return new ClassByNameMultiStartupDexDistributor(startupProfile);
+  }
+
+  public static MultiStartupDexDistributor get(
+      InternalOptions options, StartupProfile startupProfile) {
     String strategyName = options.getStartupOptions().getMultiStartupDexDistributionStrategyName();
     if (strategyName == null) {
-      return getDefault();
+      return getDefault(startupProfile);
     }
     switch (strategyName) {
       case "classByName":
-        return getDefault();
+        return getDefault(startupProfile);
       case "classByNumberOfStartupMethods":
-        throw new Unimplemented();
+        return new ClassByLowestMetricMultiStartupDexDistributor(startupProfile) {
+
+          @Override
+          int getMetric(DexEncodedMethod method) {
+            return startupProfile.containsMethodRule(method.getReference()) ? -1 : 0;
+          }
+        };
       case "classByNumberOfStartupMethodsMinusNumberOfNonStartupMethods":
-        throw new Unimplemented();
+        return new ClassByLowestMetricMultiStartupDexDistributor(startupProfile) {
+
+          @Override
+          boolean forceSpillClassesWithNoStartupMethods() {
+            return true;
+          }
+
+          @Override
+          int getMetric(DexEncodedMethod method) {
+            return startupProfile.containsMethodRule(method.getReference()) ? -1 : 1;
+          }
+        };
       case "packageByName":
-        return new PackageByNameMultiStartupDexDistributor();
+        return new PackageByNameMultiStartupDexDistributor(startupProfile);
       case "packageByNumberOfStartupMethods":
         throw new Unimplemented();
       default:
@@ -50,8 +109,54 @@
     }
   }
 
+  private abstract static class ClassByLowestMetricMultiStartupDexDistributor
+      extends MultiStartupDexDistributor {
+
+    ClassByLowestMetricMultiStartupDexDistributor(StartupProfile startupProfile) {
+      super(startupProfile);
+    }
+
+    @Override
+    public void distribute(
+        List<DexProgramClass> classes,
+        PackageSplitPopulator packageSplitPopulator,
+        VirtualFile virtualFile,
+        VirtualFileCycler virtualFileCycler) {
+      Reference2IntMap<DexProgramClass> classMetrics =
+          computeClassMetrics(classes, this::getMetric);
+      List<DexProgramClass> distribution = new ArrayList<>(classes);
+      distribution.sort(
+          Comparator.<DexProgramClass>comparingInt(classMetrics::getInt)
+              .thenComparing(DexClass::getType));
+      distributeInOrder(distribution, virtualFile, virtualFileCycler);
+    }
+
+    int getMetric(DexProgramClass clazz) {
+      int metric = 0;
+      boolean seenStartupMethod = false;
+      for (DexEncodedMethod method : clazz.methods()) {
+        metric += getMetric(method);
+        seenStartupMethod |= startupProfile.containsMethodRule(method.getReference());
+      }
+      if (forceSpillClassesWithNoStartupMethods() && !seenStartupMethod) {
+        metric = Integer.MAX_VALUE;
+      }
+      return metric;
+    }
+
+    boolean forceSpillClassesWithNoStartupMethods() {
+      return false;
+    }
+
+    abstract int getMetric(DexEncodedMethod method);
+  }
+
   private static class ClassByNameMultiStartupDexDistributor extends MultiStartupDexDistributor {
 
+    ClassByNameMultiStartupDexDistributor(StartupProfile startupProfile) {
+      super(startupProfile);
+    }
+
     @Override
     public void distribute(
         List<DexProgramClass> classes,
@@ -59,23 +164,16 @@
         VirtualFile virtualFile,
         VirtualFileCycler virtualFileCycler) {
       // Add the (already sorted) startup classes one by one.
-      for (DexProgramClass startupClass : classes) {
-        virtualFile.addClass(startupClass);
-        if (hasSpaceForTransaction(virtualFile)) {
-          virtualFile.commitTransaction();
-        } else {
-          virtualFile.abortTransaction();
-          virtualFile = virtualFileCycler.addFile();
-          virtualFile.addClass(startupClass);
-          assert hasSpaceForTransaction(virtualFile);
-          virtualFile.commitTransaction();
-        }
-      }
+      distributeInOrder(classes, virtualFile, virtualFileCycler);
     }
   }
 
   private static class PackageByNameMultiStartupDexDistributor extends MultiStartupDexDistributor {
 
+    PackageByNameMultiStartupDexDistributor(StartupProfile startupProfile) {
+      super(startupProfile);
+    }
+
     @Override
     public void distribute(
         List<DexProgramClass> classes,