Move instantiation information to ObjectAllocationInfoCollectionImpl builder.

Bug: 150277553, 139464956
Change-Id: Ib463841bee720869842de445145aa0c9544ca209
diff --git a/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollectionImpl.java b/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollectionImpl.java
index 53510bd..3869d23 100644
--- a/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollectionImpl.java
+++ b/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollectionImpl.java
@@ -6,15 +6,22 @@
 
 import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
 
+import com.android.tools.r8.ir.desugar.LambdaDescriptor;
 import com.android.tools.r8.shaking.GraphReporter;
 import com.android.tools.r8.shaking.InstantiationReason;
 import com.android.tools.r8.shaking.KeepReason;
 import com.android.tools.r8.utils.LensUtils;
+import com.android.tools.r8.utils.WorkList;
 import com.google.common.collect.Sets;
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.IdentityHashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+import java.util.function.Predicate;
 
 /** Stores the set of instantiated classes along with their allocation sites. */
 public class ObjectAllocationInfoCollectionImpl implements ObjectAllocationInfoCollection {
@@ -68,11 +75,33 @@
 
     private final boolean trackAllocationSites;
 
+    /** Instantiated classes with the contexts of the instantiations. */
     private final Map<DexProgramClass, Set<DexEncodedMethod>> classesWithAllocationSiteTracking =
         new IdentityHashMap<>();
+
+    /** Instantiated classes without contexts. */
     private final Set<DexProgramClass> classesWithoutAllocationSiteTracking =
         Sets.newIdentityHashSet();
 
+    /** Set of types directly implemented by a lambda. */
+    private final Map<DexType, List<LambdaDescriptor>> instantiatedLambdas =
+        new IdentityHashMap<>();
+
+    /**
+     * Hierarchy for instantiated types mapping a type to the set of immediate subtypes for which
+     * some subtype is either instantiated or is implemented by an instantiated lambda.
+     */
+    private final Map<DexType, Set<DexClass>> instantiatedHierarchy = new IdentityHashMap<>();
+
+    /**
+     * Set of interface types for which there may be instantiations, such as lambda expressions or
+     * explicit keep rules.
+     */
+    private final Set<DexProgramClass> instantiatedInterfaceTypes = Sets.newIdentityHashSet();
+
+    /** Subset of the above that are marked instantiated by usages that are not lambdas. */
+    public final Set<DexProgramClass> unknownInstantiatedInterfaceTypes = Sets.newIdentityHashSet();
+
     private GraphReporter reporter;
 
     private Builder(boolean trackAllocationSites, GraphReporter reporter) {
@@ -97,7 +126,15 @@
       return !clazz.instanceFields().isEmpty();
     }
 
+    public boolean isInstantiatedDirectlyOrIsInstantiationLeaf(DexProgramClass clazz) {
+      if (clazz.isInterface()) {
+        return instantiatedInterfaceTypes.contains(clazz);
+      }
+      return isInstantiatedDirectly(clazz);
+    }
+
     public boolean isInstantiatedDirectly(DexProgramClass clazz) {
+      assert !clazz.isInterface();
       if (classesWithAllocationSiteTracking.containsKey(clazz)) {
         assert !classesWithAllocationSiteTracking.get(clazz).isEmpty();
         return true;
@@ -105,6 +142,30 @@
       return classesWithoutAllocationSiteTracking.contains(clazz);
     }
 
+    public boolean isInstantiatedDirectlyOrHasInstantiatedSubtype(DexProgramClass clazz) {
+      return isInstantiatedDirectlyOrIsInstantiationLeaf(clazz)
+          || instantiatedHierarchy.containsKey(clazz.type);
+    }
+
+    public void forEachInstantiatedSubType(
+        DexType type,
+        Consumer<DexProgramClass> onClass,
+        Consumer<LambdaDescriptor> onLambda,
+        AppInfo appInfo) {
+      internalForEachInstantiatedSubType(
+          type,
+          onClass,
+          onLambda,
+          instantiatedHierarchy,
+          instantiatedLambdas,
+          this::isInstantiatedDirectlyOrIsInstantiationLeaf,
+          appInfo);
+    }
+
+    public Set<DexClass> getImmediateSubtypesInInstantiatedHierarchy(DexType type) {
+      return instantiatedHierarchy.get(type);
+    }
+
     /**
      * Records that {@param clazz} is instantiated in {@param context}.
      *
@@ -114,11 +175,13 @@
         DexProgramClass clazz,
         DexEncodedMethod context,
         InstantiationReason instantiationReason,
-        KeepReason keepReason) {
+        KeepReason keepReason,
+        AppInfo appInfo) {
       assert !clazz.isInterface();
       if (reporter != null) {
         reporter.registerClass(clazz, keepReason);
       }
+      populateInstantiatedHierarchy(appInfo, clazz);
       if (shouldTrackAllocationSitesForClass(clazz, instantiationReason)) {
         assert context != null;
         Set<DexEncodedMethod> allocationSitesForClass =
@@ -135,6 +198,51 @@
       return false;
     }
 
+    public boolean recordInstantiatedInterface(DexProgramClass iface) {
+      assert iface.isInterface();
+      assert !iface.isAnnotation();
+      unknownInstantiatedInterfaceTypes.add(iface);
+      return instantiatedInterfaceTypes.add(iface);
+    }
+
+    public void recordInstantiatedLambdaInterface(
+        DexType iface, LambdaDescriptor lambda, AppInfo appInfo) {
+      instantiatedLambdas.computeIfAbsent(iface, key -> new ArrayList<>()).add(lambda);
+      populateInstantiatedHierarchy(appInfo, iface);
+    }
+
+    private void populateInstantiatedHierarchy(AppInfo appInfo, DexType type) {
+      DexClass clazz = appInfo.definitionFor(type);
+      if (clazz != null) {
+        populateInstantiatedHierarchy(appInfo, clazz);
+      }
+    }
+
+    private void populateInstantiatedHierarchy(AppInfo appInfo, DexClass clazz) {
+      if (clazz.superType != null) {
+        populateInstantiatedHierarchy(appInfo, clazz.superType, clazz);
+      }
+      for (DexType iface : clazz.interfaces.values) {
+        populateInstantiatedHierarchy(appInfo, iface, clazz);
+      }
+    }
+
+    private void populateInstantiatedHierarchy(AppInfo appInfo, DexType type, DexClass subtype) {
+      if (type == appInfo.dexItemFactory().objectType) {
+        return;
+      }
+      Set<DexClass> subtypes = instantiatedHierarchy.get(type);
+      if (subtypes != null) {
+        subtypes.add(subtype);
+        return;
+      }
+      // This is the first time an instantiation appears below 'type', recursively populate.
+      subtypes = Sets.newIdentityHashSet();
+      subtypes.add(subtype);
+      instantiatedHierarchy.put(type, subtypes);
+      populateInstantiatedHierarchy(appInfo, type);
+    }
+
     Builder rewrittenWithLens(
         ObjectAllocationInfoCollectionImpl objectAllocationInfos,
         DexDefinitionSupplier definitions,
@@ -173,4 +281,48 @@
           classesWithAllocationSiteTracking, classesWithoutAllocationSiteTracking);
     }
   }
+
+  private static void internalForEachInstantiatedSubType(
+      DexType type,
+      Consumer<DexProgramClass> subTypeConsumer,
+      Consumer<LambdaDescriptor> lambdaConsumer,
+      Map<DexType, Set<DexClass>> instantiatedHierarchy,
+      Map<DexType, List<LambdaDescriptor>> instantiatedLambdas,
+      Predicate<DexProgramClass> isInstantiatedDirectly,
+      AppInfo appInfo) {
+    WorkList<DexClass> worklist = WorkList.newIdentityWorkList();
+    if (type == appInfo.dexItemFactory().objectType) {
+      // All types are below java.lang.Object, but we don't maintain an entry for it.
+      instantiatedHierarchy.forEach(
+          (key, subtypes) -> {
+            DexClass clazz = appInfo.definitionFor(key);
+            if (clazz != null) {
+              worklist.addIfNotSeen(clazz);
+            }
+            worklist.addIfNotSeen(subtypes);
+          });
+    } else {
+      DexClass initialClass = appInfo.definitionFor(type);
+      if (initialClass == null) {
+        // If no definition for the type is found, populate the worklist with any
+        // instantiated subtypes and callback with any lambda instance.
+        worklist.addIfNotSeen(instantiatedHierarchy.getOrDefault(type, Collections.emptySet()));
+        instantiatedLambdas.getOrDefault(type, Collections.emptyList()).forEach(lambdaConsumer);
+      } else {
+        worklist.addIfNotSeen(initialClass);
+      }
+    }
+
+    while (worklist.hasNext()) {
+      DexClass clazz = worklist.next();
+      if (clazz.isProgramClass()) {
+        DexProgramClass programClass = clazz.asProgramClass();
+        if (isInstantiatedDirectly.test(programClass)) {
+          subTypeConsumer.accept(programClass);
+        }
+      }
+      worklist.addIfNotSeen(instantiatedHierarchy.getOrDefault(clazz.type, Collections.emptySet()));
+      instantiatedLambdas.getOrDefault(clazz.type, Collections.emptyList()).forEach(lambdaConsumer);
+    }
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index b551b8e..818237c 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -124,7 +124,6 @@
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.function.BiConsumer;
-import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.Predicate;
 import org.objectweb.asm.Opcodes;
@@ -263,23 +262,6 @@
    */
   private final Set<DexType> instantiatedAppServices = Sets.newIdentityHashSet();
 
-  /** Set of types directly implemented by a lambda. */
-  private final Map<DexType, List<LambdaDescriptor>> instantiatedLambdas = new IdentityHashMap<>();
-
-  /**
-   * Hierarchy for instantiated types mapping a type to the set of direct subtypes for which some
-   * subtype is instantiated or implemented by a lambda.
-   */
-  private final Map<DexType, Set<DexClass>> instantiatedHierarchy = new IdentityHashMap<>();
-
-  /**
-   * Set of interface types for which there may be instantiations, such as lambda expressions or
-   * explicit keep rules.
-   */
-  private final Set<DexProgramClass> instantiatedInterfaceTypes;
-  /** Subset of the above that are marked instantiated by usages that are not desugared lambdas. */
-  private final SetWithReason<DexProgramClass> unknownInstantiatedInterfaceTypes;
-
   /** A queue of items that need processing. Different items trigger different actions. */
   private final EnqueuerWorklist workList;
 
@@ -372,8 +354,6 @@
     failedResolutionTargets = SetUtils.newIdentityHashSet(2);
     liveMethods = new LiveMethodsSet(graphReporter::registerMethod);
     liveFields = new SetWithReason<>(graphReporter::registerField);
-    unknownInstantiatedInterfaceTypes = new SetWithReason<>(graphReporter::registerInterface);
-    instantiatedInterfaceTypes = Sets.newIdentityHashSet();
     lambdaRewriter = options.desugarState == DesugarState.ON ? new LambdaRewriter(appView) : null;
 
     objectAllocationInfoCollection =
@@ -1664,42 +1644,36 @@
       InstantiationReason instantiationReason,
       KeepReason keepReason) {
     assert !clazz.isInterface();
-    if (!objectAllocationInfoCollection.recordDirectAllocationSite(
-        clazz, context, instantiationReason, keepReason)) {
-      return false;
-    }
-    populateInstantiatedHierarchy(clazz);
-    return true;
+    return objectAllocationInfoCollection.recordDirectAllocationSite(
+        clazz, context, instantiationReason, keepReason, appInfo);
   }
 
   void markInterfaceAsInstantiated(DexProgramClass clazz, KeepReasonWitness witness) {
     assert !clazz.isAnnotation();
     assert clazz.isInterface();
-    if (!instantiatedInterfaceTypes.add(clazz)) {
+    if (!objectAllocationInfoCollection.recordInstantiatedInterface(clazz)) {
       return;
     }
-    unknownInstantiatedInterfaceTypes.add(clazz, witness);
     markTypeAsLive(clazz, witness);
-    populateInstantiatedHierarchy(clazz);
     transitionDependentItemsForInstantiatedInterface(clazz);
   }
 
   private void markLambdaAsInstantiated(LambdaDescriptor descriptor, DexEncodedMethod context) {
     // Each descriptor is unique, so there is no check for already marking the lambda.
     for (DexType iface : descriptor.interfaces) {
-      instantiatedLambdas.computeIfAbsent(iface, key -> new ArrayList<>()).add(descriptor);
-      DexClass clazz = definitionForLambdaInterface(iface, context);
-      if (clazz != null) {
-        if (lambdaRewriter == null && clazz.isProgramClass()) {
-          unknownInstantiatedInterfaceTypes.add(
-              clazz.asProgramClass(), KeepReason.instantiatedIn(context));
+      checkLambdaInterface(iface, context);
+      objectAllocationInfoCollection.recordInstantiatedLambdaInterface(iface, descriptor, appInfo);
+      // TODO(b/150277553): Lambdas should be accurately traces and thus not be added here.
+      if (lambdaRewriter == null) {
+        DexProgramClass clazz = getProgramClassOrNull(iface);
+        if (clazz != null) {
+          objectAllocationInfoCollection.recordInstantiatedInterface(clazz);
         }
-        populateInstantiatedHierarchy(clazz);
       }
     }
   }
 
-  private DexClass definitionForLambdaInterface(DexType itf, DexEncodedMethod context) {
+  private void checkLambdaInterface(DexType itf, DexEncodedMethod context) {
     DexClass clazz = definitionFor(itf);
     if (clazz == null) {
       StringDiagnostic message =
@@ -1717,35 +1691,6 @@
               appInfo.originFor(context.method.holder));
       options.reporter.warning(message);
     }
-    return clazz;
-  }
-
-  private void populateInstantiatedHierarchy(DexClass clazz) {
-    if (clazz.superType != null) {
-      populateInstantiatedHierarchy(clazz.superType, clazz);
-    }
-    for (DexType iface : clazz.interfaces.values) {
-      populateInstantiatedHierarchy(iface, clazz);
-    }
-  }
-
-  private void populateInstantiatedHierarchy(DexType type, DexClass subtype) {
-    if (type == appInfo.dexItemFactory().objectType) {
-      return;
-    }
-    Set<DexClass> subtypes = instantiatedHierarchy.get(type);
-    if (subtypes != null) {
-      subtypes.add(subtype);
-      return;
-    }
-    // This is the first time an instantiation appears below 'type', recursively populate.
-    subtypes = Sets.newIdentityHashSet();
-    subtypes.add(subtype);
-    instantiatedHierarchy.put(type, subtypes);
-    DexClass clazz = definitionFor(type);
-    if (clazz != null) {
-      populateInstantiatedHierarchy(clazz);
-    }
   }
 
   private void transitionMethodsForInstantiatedLambda(LambdaDescriptor lambda) {
@@ -2112,16 +2057,6 @@
         : info.isWritten();
   }
 
-  private boolean isInstantiatedDirectly(DexProgramClass clazz) {
-    return clazz.isInterface()
-        ? instantiatedInterfaceTypes.contains(clazz)
-        : objectAllocationInfoCollection.isInstantiatedDirectly(clazz);
-  }
-
-  private boolean isInstantiatedDirectlyOrHasInstantiatedSubtype(DexProgramClass clazz) {
-    return isInstantiatedDirectly(clazz) || instantiatedHierarchy.containsKey(clazz.type);
-  }
-
   public boolean isMethodLive(DexEncodedMethod method) {
     return liveMethods.contains(method);
   }
@@ -2158,7 +2093,7 @@
     if (encodedField.accessFlags.isStatic()) {
       markStaticFieldAsLive(encodedField, reason);
     } else {
-      if (isInstantiatedDirectlyOrHasInstantiatedSubtype(clazz)) {
+      if (objectAllocationInfoCollection.isInstantiatedDirectlyOrHasInstantiatedSubtype(clazz)) {
         markInstanceFieldAsLive(clazz, encodedField, reason);
       } else {
         // Add the field to the reachable set if the type later becomes instantiated.
@@ -2258,7 +2193,12 @@
 
     resolution
         .lookupVirtualDispatchTargets(
-            context, appInfo, this::forEachInstantiatedSubType, pinnedItems::contains)
+            context,
+            appInfo,
+            (type, subTypeConsumer, lambdaConsumer) ->
+                objectAllocationInfoCollection.forEachInstantiatedSubType(
+                    type, subTypeConsumer, lambdaConsumer, appInfo),
+            pinnedItems::contains)
         .forEach(
             target ->
                 markVirtualDispatchTargetAsLive(
@@ -2298,46 +2238,6 @@
     }
   }
 
-  private void forEachInstantiatedSubType(
-      DexType type,
-      Consumer<DexProgramClass> subTypeConsumer,
-      Consumer<LambdaDescriptor> lambdaConsumer) {
-    WorkList<DexClass> worklist = WorkList.newIdentityWorkList();
-    if (type == appInfo.dexItemFactory().objectType) {
-      // All types are below java.lang.Object, but we don't maintain an entry for it.
-      instantiatedHierarchy.forEach(
-          (key, subtypes) -> {
-            DexClass clazz = definitionFor(key);
-            if (clazz != null) {
-              worklist.addIfNotSeen(clazz);
-            }
-            worklist.addIfNotSeen(subtypes);
-          });
-    } else {
-      DexClass initialClass = definitionFor(type);
-      if (initialClass == null) {
-        // If no definition for the type is found, populate the worklist with any
-        // instantiated subtypes and callback with any lambda instance.
-        worklist.addIfNotSeen(instantiatedHierarchy.getOrDefault(type, Collections.emptySet()));
-        instantiatedLambdas.getOrDefault(type, Collections.emptyList()).forEach(lambdaConsumer);
-      } else {
-        worklist.addIfNotSeen(initialClass);
-      }
-    }
-
-    while (worklist.hasNext()) {
-      DexClass clazz = worklist.next();
-      if (clazz.isProgramClass()) {
-        DexProgramClass programClass = clazz.asProgramClass();
-        if (isInstantiatedDirectly(programClass)) {
-          subTypeConsumer.accept(programClass);
-        }
-      }
-      worklist.addIfNotSeen(instantiatedHierarchy.getOrDefault(clazz.type, Collections.emptySet()));
-      instantiatedLambdas.getOrDefault(clazz.type, Collections.emptyList()).forEach(lambdaConsumer);
-    }
-  }
-
   private void markFailedResolutionTargets(
       DexMethod symbolicMethod, FailedResolutionResult failedResolution, KeepReason reason) {
     failedResolutionTargets.add(symbolicMethod);
@@ -2581,8 +2481,10 @@
             Collections.emptySet(),
             Collections.emptyMap(),
             EnumValueInfoMapCollection.empty(),
+            // TODO(b/150277553): Remove this once object allocation contains the information.
             SetUtils.mapIdentityHashSet(
-                unknownInstantiatedInterfaceTypes.getItems(), DexProgramClass::getType),
+                objectAllocationInfoCollection.unknownInstantiatedInterfaceTypes,
+                DexProgramClass::getType),
             constClassReferences);
     appInfo.markObsolete();
     return appInfoWithLiveness;
@@ -2650,7 +2552,8 @@
           wrapper,
           null,
           InstantiationReason.SYNTHESIZED_CLASS,
-          graphReporter.fakeReportShouldNotBeUsed());
+          graphReporter.fakeReportShouldNotBeUsed(),
+          appInfo);
       // Mark all methods on the wrapper as live and targeted.
       for (DexEncodedMethod method : wrapper.methods()) {
         targetedMethods.add(method, graphReporter.fakeReportShouldNotBeUsed());
@@ -2689,7 +2592,8 @@
           programClass,
           null,
           InstantiationReason.SYNTHESIZED_CLASS,
-          graphReporter.fakeReportShouldNotBeUsed());
+          graphReporter.fakeReportShouldNotBeUsed(),
+          appInfo);
 
       // Register all of the field writes in the lambda constructors.
       // This is needed to ensure that the initializers can be optimized.
@@ -3062,7 +2966,8 @@
   }
 
   private Set<DexProgramClass> getImmediateSubtypesInInstantiatedHierarchy(DexProgramClass clazz) {
-    Set<DexClass> subtypes = instantiatedHierarchy.get(clazz.type);
+    Set<DexClass> subtypes =
+        objectAllocationInfoCollection.getImmediateSubtypesInInstantiatedHierarchy(clazz.type);
     if (subtypes == null) {
       return Collections.emptySet();
     }