Layout class defs in startup profile order

Bug: b/324268048
Change-Id: I961e86f4581fe6c1f63ab8a994c0cd3393f541ea
diff --git a/src/main/java/com/android/tools/r8/dex/VirtualFile.java b/src/main/java/com/android/tools/r8/dex/VirtualFile.java
index 9f83a90..dc35e67 100644
--- a/src/main/java/com/android/tools/r8/dex/VirtualFile.java
+++ b/src/main/java/com/android/tools/r8/dex/VirtualFile.java
@@ -247,6 +247,8 @@
             indexedItems.callSites,
             indexedItems.methodHandles,
             lazyDexStringsCount,
+            startupProfile,
+            this,
             timing);
   }
 
diff --git a/src/main/java/com/android/tools/r8/graph/ObjectToOffsetMapping.java b/src/main/java/com/android/tools/r8/graph/ObjectToOffsetMapping.java
index b52eaf3..9d190ed 100644
--- a/src/main/java/com/android/tools/r8/graph/ObjectToOffsetMapping.java
+++ b/src/main/java/com/android/tools/r8/graph/ObjectToOffsetMapping.java
@@ -3,12 +3,16 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.graph;
 
+import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
+
 import com.android.tools.r8.dex.Constants;
+import com.android.tools.r8.dex.VirtualFile;
 import com.android.tools.r8.errors.CompilationError;
 import com.android.tools.r8.graph.lens.GraphLens;
 import com.android.tools.r8.graph.lens.InitClassLens;
 import com.android.tools.r8.ir.conversion.LensCodeRewriterUtils;
 import com.android.tools.r8.naming.NamingLens;
+import com.android.tools.r8.profile.startup.profile.StartupProfile;
 import com.android.tools.r8.utils.Box;
 import com.android.tools.r8.utils.Timing;
 import com.android.tools.r8.utils.structural.CompareToVisitor;
@@ -24,8 +28,10 @@
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
+import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.function.Consumer;
 
 public class ObjectToOffsetMapping {
@@ -60,7 +66,7 @@
       AppView<?> appView,
       ObjectToOffsetMapping sharedMapping,
       LensCodeRewriterUtils lensCodeRewriter,
-      Collection<DexProgramClass> classes,
+      Set<DexProgramClass> classes,
       Map<DexProto, DexString> protos,
       Collection<DexType> types,
       Collection<DexMethod> methods,
@@ -69,6 +75,8 @@
       Collection<DexCallSite> callSites,
       Collection<DexMethodHandle> methodHandles,
       int lazyDexStringsCount,
+      StartupProfile startupProfile,
+      VirtualFile virtualFile,
       Timing timing) {
     assert appView != null;
     assert classes != null;
@@ -104,7 +112,10 @@
         new CompareToVisitorWithTypeTable(namingLens, this.strings::getInt, this.types::getInt);
     timing.end();
     timing.begin("Sort classes");
-    this.classes = sortClasses(classes, visitor);
+    this.classes =
+        appView.testing().enableLegacyClassDefOrdering
+            ? sortClassesLegacy(classes, visitor)
+            : sortClasses(classes, startupProfile, virtualFile, visitor);
     timing.end();
     timing.begin("Sort protos");
     this.protos = createSortedMap(protos.keySet(), compare(visitor), this::failOnOverflow);
@@ -251,8 +262,8 @@
     }
   }
 
-  private DexProgramClass[] sortClasses(
-      Collection<DexProgramClass> classes, CompareToVisitor visitor) {
+  private DexProgramClass[] sortClassesLegacy(
+      Set<DexProgramClass> classes, CompareToVisitor visitor) {
     // Collect classes in subtyping order, based on a sorted list of classes to start with.
     ProgramClassDepthsMemoized classDepths = new ProgramClassDepthsMemoized(appView.appInfo());
     DexProgramClass[] sortedClasses = classes.toArray(DexProgramClass.EMPTY_ARRAY);
@@ -266,6 +277,74 @@
     return sortedClasses;
   }
 
+  private DexProgramClass[] sortClasses(
+      Set<DexProgramClass> classes,
+      StartupProfile startupProfile,
+      VirtualFile virtualFile,
+      CompareToVisitor visitor) {
+    assert startupProfile.isEmpty() || virtualFile.getId() == 0;
+    // First sort the classes using the startup profile.
+    LinkedHashSet<DexProgramClass> sortedClasses = new LinkedHashSet<>(classes.size());
+    addClassesFromStartupProfile(classes, sortedClasses, startupProfile);
+    addRemainingClassesInSortedOrder(classes, sortedClasses, visitor);
+    assert classes.size() == sortedClasses.size();
+    return sortedClasses.toArray(DexProgramClass.EMPTY_ARRAY);
+  }
+
+  private void addClassesFromStartupProfile(
+      Set<DexProgramClass> classes,
+      LinkedHashSet<DexProgramClass> sortedClasses,
+      StartupProfile startupProfile) {
+    startupProfile.forEachRule(
+        rule -> {
+          DexType type = rule.getReference().getContextType().toBaseType(dexItemFactory());
+          if (type.isPrimitiveType()) {
+            assert false;
+            return;
+          }
+          assert type.isClassType();
+          DexProgramClass clazz = asProgramClassOrNull(appView.definitionFor(type));
+          if (clazz == null || !classes.contains(clazz)) {
+            return;
+          }
+          addClassAfterParentClasses(classes, sortedClasses, clazz);
+        });
+  }
+
+  private void addRemainingClassesInSortedOrder(
+      Set<DexProgramClass> classes,
+      LinkedHashSet<DexProgramClass> sortedClasses,
+      CompareToVisitor visitor) {
+    List<DexProgramClass> remainingClasses = new ArrayList<>(classes.size() - sortedClasses.size());
+    for (DexProgramClass clazz : classes) {
+      if (!sortedClasses.contains(clazz)) {
+        remainingClasses.add(clazz);
+      }
+    }
+    remainingClasses.sort((x, y) -> visitor.visitDexType(x.getType(), y.getType()));
+    remainingClasses.forEach(clazz -> addClassAfterParentClasses(classes, sortedClasses, clazz));
+  }
+
+  private void addClassAfterParentClasses(
+      Set<DexProgramClass> classes,
+      LinkedHashSet<DexProgramClass> sortedClasses,
+      DexProgramClass clazz) {
+    if (sortedClasses.contains(clazz)) {
+      return;
+    }
+    // Add the superclass and all implemented interfaces first, as this is required by the dex
+    // format.
+    clazz.forEachImmediateSuperClassMatching(
+        appView,
+        (supertype, superclass) ->
+            superclass != null
+                && superclass.isProgramClass()
+                && classes.contains(superclass.asProgramClass()),
+        (supertype, superclass) ->
+            addClassAfterParentClasses(classes, sortedClasses, superclass.asProgramClass()));
+    sortedClasses.add(clazz);
+  }
+
   private static <T> Collection<T> keysOrEmpty(Reference2IntLinkedOpenHashMap<T> map) {
     // The key-set is deterministic (linked) and inserted in sorted order.
     return map == null ? Collections.emptyList() : map.keySet();
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index ab04762..1ec83c8 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -2354,6 +2354,8 @@
     public boolean enableCheckCastAndInstanceOfRemoval = true;
     public boolean enableDeadSwitchCaseElimination = true;
     public boolean enableInvokeSuperToInvokeVirtualRewriting = true;
+    public boolean enableLegacyClassDefOrdering =
+        System.getProperty("com.android.tools.r8.enableLegacyClassDefOrdering") != null;
     public boolean enableMultiANewArrayDesugaringForClassFiles = false;
     public boolean enableStrictFrameVerification = false;
     public boolean enableSyntheticSharing = true;
diff --git a/src/test/java/com/android/tools/r8/dex/DebugByteCodeWriterTest.java b/src/test/java/com/android/tools/r8/dex/DebugByteCodeWriterTest.java
index 6d9774b..38365d8 100644
--- a/src/test/java/com/android/tools/r8/dex/DebugByteCodeWriterTest.java
+++ b/src/test/java/com/android/tools/r8/dex/DebugByteCodeWriterTest.java
@@ -18,6 +18,7 @@
 import com.android.tools.r8.graph.ObjectToOffsetMapping;
 import com.android.tools.r8.graph.lens.GraphLens;
 import com.android.tools.r8.ir.conversion.LensCodeRewriterUtils;
+import com.android.tools.r8.profile.startup.profile.StartupProfile;
 import com.android.tools.r8.synthesis.SyntheticItems.GlobalSyntheticsStrategy;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.Reporter;
@@ -53,7 +54,7 @@
         appView,
         null,
         new LensCodeRewriterUtils(appView),
-        Collections.emptyList(),
+        Collections.emptySet(),
         Collections.emptyMap(),
         Collections.emptyList(),
         Collections.emptyList(),
@@ -62,6 +63,8 @@
         Collections.emptyList(),
         Collections.emptyList(),
         0,
+        StartupProfile.empty(),
+        null,
         Timing.empty());
   }