Make allocation hierarchy available in AppInfoWithLiveness.

Bug: 139464956
Bug: 145344105
Change-Id: I35b0ba6ba0728f995221542cf758515695085af8
diff --git a/src/main/java/com/android/tools/r8/graph/AppInfoWithSubtyping.java b/src/main/java/com/android/tools/r8/graph/AppInfoWithSubtyping.java
index 58611d4..4edc070 100644
--- a/src/main/java/com/android/tools/r8/graph/AppInfoWithSubtyping.java
+++ b/src/main/java/com/android/tools/r8/graph/AppInfoWithSubtyping.java
@@ -357,7 +357,7 @@
     return true;
   }
 
-  public boolean hasAnyInstantiatedLambdas(DexProgramClass clazz) {
+  public boolean isInstantiatedInterface(DexProgramClass clazz) {
     assert checkIfObsolete();
     return true; // Don't know, there might be.
   }
diff --git a/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollection.java b/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollection.java
index 2ed505e..546bf90 100644
--- a/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollection.java
+++ b/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollection.java
@@ -20,6 +20,12 @@
 
   boolean isInstantiatedDirectly(DexProgramClass clazz);
 
+  boolean isInstantiatedDirectlyOrHasInstantiatedSubtype(DexProgramClass clazz);
+
+  boolean isInterfaceWithUnknownSubtypeHierarchy(DexProgramClass clazz);
+
+  boolean isImmediateInterfaceOfInstantiatedLambda(DexProgramClass clazz);
+
   ObjectAllocationInfoCollection rewrittenWithLens(
       DexDefinitionSupplier definitions, GraphLense lens);
 }
diff --git a/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollectionImpl.java b/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollectionImpl.java
index 3869d23..5999bbf 100644
--- a/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollectionImpl.java
+++ b/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollectionImpl.java
@@ -21,28 +21,101 @@
 import java.util.Set;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
-import java.util.function.Predicate;
 
-/** Stores the set of instantiated classes along with their allocation sites. */
-public class ObjectAllocationInfoCollectionImpl implements ObjectAllocationInfoCollection {
+/**
+ * Provides information about all possibly instantiated classes and lambdas, their allocation sites,
+ * if known, as well as the full subtyping hierarchy of types above them.
+ */
+public abstract class ObjectAllocationInfoCollectionImpl implements ObjectAllocationInfoCollection {
 
-  private final Map<DexProgramClass, Set<DexEncodedMethod>> classesWithAllocationSiteTracking;
-  private final Set<DexProgramClass> classesWithoutAllocationSiteTracking;
+  /** Instantiated classes with the contexts of the instantiations. */
+  final Map<DexProgramClass, Set<DexEncodedMethod>> classesWithAllocationSiteTracking =
+      new IdentityHashMap<>();
 
-  private ObjectAllocationInfoCollectionImpl(
-      Map<DexProgramClass, Set<DexEncodedMethod>> classesWithAllocationSiteTracking,
-      Set<DexProgramClass> classesWithoutAllocationSiteTracking) {
-    this.classesWithAllocationSiteTracking = classesWithAllocationSiteTracking;
-    this.classesWithoutAllocationSiteTracking = classesWithoutAllocationSiteTracking;
+  /** Instantiated classes without contexts. */
+  final Set<DexProgramClass> classesWithoutAllocationSiteTracking = Sets.newIdentityHashSet();
+
+  /**
+   * Set of interface types for which the subtype hierarchy is unknown from that type.
+   *
+   * <p>E.g., the type is kept thus there could be instantiations of subtypes.
+   *
+   * <p>TODO(b/145344105): Generalize this to typesWithUnknownSubtypeHierarchy.
+   */
+  final Set<DexProgramClass> interfacesWithUnknownSubtypeHierarchy = Sets.newIdentityHashSet();
+
+  /** Map of types directly implemented by lambdas to those lambdas. */
+  final Map<DexType, List<LambdaDescriptor>> instantiatedLambdas = new IdentityHashMap<>();
+
+  /**
+   * Hierarchy for instantiated types mapping a type to the set of immediate subtypes for which some
+   * subtype is either an instantiated class, kept interface or is implemented by an instantiated
+   * lambda.
+   */
+  Map<DexType, Set<DexClass>> instantiatedHierarchy = new IdentityHashMap<>();
+
+  private ObjectAllocationInfoCollectionImpl() {
+    // Only builder can allocate an instance.
   }
 
   public static Builder builder(boolean trackAllocationSites, GraphReporter reporter) {
     return new Builder(trackAllocationSites, reporter);
   }
 
-  public void markNoLongerInstantiated(DexProgramClass clazz) {
-    classesWithAllocationSiteTracking.remove(clazz);
-    classesWithoutAllocationSiteTracking.remove(clazz);
+  public abstract void mutate(Consumer<Builder> mutator, AppInfo appInfo);
+
+  /**
+   * True if a class type might be instantiated directly at the given type.
+   *
+   * <p>Should not be called on interface types.
+   *
+   * <p>TODO(b/145344105): Extend this to not be called on any abstract types.
+   */
+  @Override
+  public boolean isInstantiatedDirectly(DexProgramClass clazz) {
+    if (clazz.isInterface()) {
+      return false;
+    }
+    if (classesWithAllocationSiteTracking.containsKey(clazz)) {
+      assert !classesWithAllocationSiteTracking.get(clazz).isEmpty();
+      return true;
+    }
+    return classesWithoutAllocationSiteTracking.contains(clazz);
+  }
+
+  /** True if the type or subtype of it might be instantiated. */
+  @Override
+  public boolean isInstantiatedDirectlyOrHasInstantiatedSubtype(DexProgramClass clazz) {
+    return (!clazz.isInterface() && isInstantiatedDirectly(clazz))
+        || hasInstantiatedStrictSubtype(clazz);
+  }
+
+  /** True if there might exist an instantiated (strict) subtype of the given type. */
+  public boolean hasInstantiatedStrictSubtype(DexProgramClass clazz) {
+    if (instantiatedHierarchy.get(clazz.type) != null) {
+      return true;
+    }
+    if (!clazz.isInterface()) {
+      return false;
+    }
+    return interfacesWithUnknownSubtypeHierarchy.contains(clazz)
+        || isImmediateInterfaceOfInstantiatedLambda(clazz);
+  }
+
+  /** True if the type is an interface that has unknown instantiations, eg, by being kept. */
+  @Override
+  public boolean isInterfaceWithUnknownSubtypeHierarchy(DexProgramClass clazz) {
+    return clazz.isInterface() && interfacesWithUnknownSubtypeHierarchy.contains(clazz);
+  }
+
+  /** Returns true if the type is an immediate interface of an instantiated lambda. */
+  @Override
+  public boolean isImmediateInterfaceOfInstantiatedLambda(DexProgramClass iface) {
+    return iface.isInterface() && instantiatedLambdas.get(iface.type) != null;
+  }
+
+  public Set<DexClass> getImmediateSubtypesInInstantiatedHierarchy(DexType type) {
+    return instantiatedHierarchy.get(type);
   }
 
   @Override
@@ -57,238 +130,15 @@
   }
 
   @Override
-  public boolean isInstantiatedDirectly(DexProgramClass clazz) {
-    if (classesWithAllocationSiteTracking.containsKey(clazz)) {
-      assert !classesWithAllocationSiteTracking.get(clazz).isEmpty();
-      return true;
-    }
-    return classesWithoutAllocationSiteTracking.contains(clazz);
-  }
-
-  @Override
   public ObjectAllocationInfoCollectionImpl rewrittenWithLens(
       DexDefinitionSupplier definitions, GraphLense lens) {
-    return builder(true, null).rewrittenWithLens(this, definitions, lens).build();
+    return builder(true, null).rewrittenWithLens(this, definitions, lens).build(definitions);
   }
 
-  public static class Builder {
-
-    private final boolean trackAllocationSites;
-
-    /** Instantiated classes with the contexts of the instantiations. */
-    private final Map<DexProgramClass, Set<DexEncodedMethod>> classesWithAllocationSiteTracking =
-        new IdentityHashMap<>();
-
-    /** Instantiated classes without contexts. */
-    private final Set<DexProgramClass> classesWithoutAllocationSiteTracking =
-        Sets.newIdentityHashSet();
-
-    /** Set of types directly implemented by a lambda. */
-    private final Map<DexType, List<LambdaDescriptor>> instantiatedLambdas =
-        new IdentityHashMap<>();
-
-    /**
-     * Hierarchy for instantiated types mapping a type to the set of immediate subtypes for which
-     * some subtype is either instantiated or is implemented by an instantiated lambda.
-     */
-    private final Map<DexType, Set<DexClass>> instantiatedHierarchy = new IdentityHashMap<>();
-
-    /**
-     * Set of interface types for which there may be instantiations, such as lambda expressions or
-     * explicit keep rules.
-     */
-    private final Set<DexProgramClass> instantiatedInterfaceTypes = Sets.newIdentityHashSet();
-
-    /** Subset of the above that are marked instantiated by usages that are not lambdas. */
-    public final Set<DexProgramClass> unknownInstantiatedInterfaceTypes = Sets.newIdentityHashSet();
-
-    private GraphReporter reporter;
-
-    private Builder(boolean trackAllocationSites, GraphReporter reporter) {
-      this.trackAllocationSites = trackAllocationSites;
-      this.reporter = reporter;
-    }
-
-    private boolean shouldTrackAllocationSitesForClass(
-        DexProgramClass clazz, InstantiationReason instantiationReason) {
-      if (!trackAllocationSites) {
-        return false;
-      }
-      if (instantiationReason != InstantiationReason.NEW_INSTANCE_INSTRUCTION) {
-        // There is an allocation site which is not a new-instance instruction.
-        return false;
-      }
-      if (classesWithoutAllocationSiteTracking.contains(clazz)) {
-        // We already gave up on tracking the allocation sites for `clazz` previously.
-        return false;
-      }
-      // We currently only use allocation site information for instance field value propagation.
-      return !clazz.instanceFields().isEmpty();
-    }
-
-    public boolean isInstantiatedDirectlyOrIsInstantiationLeaf(DexProgramClass clazz) {
-      if (clazz.isInterface()) {
-        return instantiatedInterfaceTypes.contains(clazz);
-      }
-      return isInstantiatedDirectly(clazz);
-    }
-
-    public boolean isInstantiatedDirectly(DexProgramClass clazz) {
-      assert !clazz.isInterface();
-      if (classesWithAllocationSiteTracking.containsKey(clazz)) {
-        assert !classesWithAllocationSiteTracking.get(clazz).isEmpty();
-        return true;
-      }
-      return classesWithoutAllocationSiteTracking.contains(clazz);
-    }
-
-    public boolean isInstantiatedDirectlyOrHasInstantiatedSubtype(DexProgramClass clazz) {
-      return isInstantiatedDirectlyOrIsInstantiationLeaf(clazz)
-          || instantiatedHierarchy.containsKey(clazz.type);
-    }
-
-    public void forEachInstantiatedSubType(
-        DexType type,
-        Consumer<DexProgramClass> onClass,
-        Consumer<LambdaDescriptor> onLambda,
-        AppInfo appInfo) {
-      internalForEachInstantiatedSubType(
-          type,
-          onClass,
-          onLambda,
-          instantiatedHierarchy,
-          instantiatedLambdas,
-          this::isInstantiatedDirectlyOrIsInstantiationLeaf,
-          appInfo);
-    }
-
-    public Set<DexClass> getImmediateSubtypesInInstantiatedHierarchy(DexType type) {
-      return instantiatedHierarchy.get(type);
-    }
-
-    /**
-     * Records that {@param clazz} is instantiated in {@param context}.
-     *
-     * @return true if {@param clazz} was not instantiated before.
-     */
-    public boolean recordDirectAllocationSite(
-        DexProgramClass clazz,
-        DexEncodedMethod context,
-        InstantiationReason instantiationReason,
-        KeepReason keepReason,
-        AppInfo appInfo) {
-      assert !clazz.isInterface();
-      if (reporter != null) {
-        reporter.registerClass(clazz, keepReason);
-      }
-      populateInstantiatedHierarchy(appInfo, clazz);
-      if (shouldTrackAllocationSitesForClass(clazz, instantiationReason)) {
-        assert context != null;
-        Set<DexEncodedMethod> allocationSitesForClass =
-            classesWithAllocationSiteTracking.computeIfAbsent(
-                clazz, ignore -> Sets.newIdentityHashSet());
-        allocationSitesForClass.add(context);
-        return allocationSitesForClass.size() == 1;
-      }
-      if (classesWithoutAllocationSiteTracking.add(clazz)) {
-        Set<DexEncodedMethod> allocationSitesForClass =
-            classesWithAllocationSiteTracking.remove(clazz);
-        return allocationSitesForClass == null;
-      }
-      return false;
-    }
-
-    public boolean recordInstantiatedInterface(DexProgramClass iface) {
-      assert iface.isInterface();
-      assert !iface.isAnnotation();
-      unknownInstantiatedInterfaceTypes.add(iface);
-      return instantiatedInterfaceTypes.add(iface);
-    }
-
-    public void recordInstantiatedLambdaInterface(
-        DexType iface, LambdaDescriptor lambda, AppInfo appInfo) {
-      instantiatedLambdas.computeIfAbsent(iface, key -> new ArrayList<>()).add(lambda);
-      populateInstantiatedHierarchy(appInfo, iface);
-    }
-
-    private void populateInstantiatedHierarchy(AppInfo appInfo, DexType type) {
-      DexClass clazz = appInfo.definitionFor(type);
-      if (clazz != null) {
-        populateInstantiatedHierarchy(appInfo, clazz);
-      }
-    }
-
-    private void populateInstantiatedHierarchy(AppInfo appInfo, DexClass clazz) {
-      if (clazz.superType != null) {
-        populateInstantiatedHierarchy(appInfo, clazz.superType, clazz);
-      }
-      for (DexType iface : clazz.interfaces.values) {
-        populateInstantiatedHierarchy(appInfo, iface, clazz);
-      }
-    }
-
-    private void populateInstantiatedHierarchy(AppInfo appInfo, DexType type, DexClass subtype) {
-      if (type == appInfo.dexItemFactory().objectType) {
-        return;
-      }
-      Set<DexClass> subtypes = instantiatedHierarchy.get(type);
-      if (subtypes != null) {
-        subtypes.add(subtype);
-        return;
-      }
-      // This is the first time an instantiation appears below 'type', recursively populate.
-      subtypes = Sets.newIdentityHashSet();
-      subtypes.add(subtype);
-      instantiatedHierarchy.put(type, subtypes);
-      populateInstantiatedHierarchy(appInfo, type);
-    }
-
-    Builder rewrittenWithLens(
-        ObjectAllocationInfoCollectionImpl objectAllocationInfos,
-        DexDefinitionSupplier definitions,
-        GraphLense lens) {
-      objectAllocationInfos.classesWithAllocationSiteTracking.forEach(
-          (clazz, allocationSitesForClass) -> {
-            DexType type = lens.lookupType(clazz.type);
-            if (type.isPrimitiveType()) {
-              return;
-            }
-            DexProgramClass rewrittenClass = asProgramClassOrNull(definitions.definitionFor(type));
-            assert rewrittenClass != null;
-            assert !classesWithAllocationSiteTracking.containsKey(rewrittenClass);
-            classesWithAllocationSiteTracking.put(
-                rewrittenClass,
-                LensUtils.rewrittenWithRenamedSignature(
-                    allocationSitesForClass, definitions, lens));
-          });
-      objectAllocationInfos.classesWithoutAllocationSiteTracking.forEach(
-          clazz -> {
-            DexType type = lens.lookupType(clazz.type);
-            if (type.isPrimitiveType()) {
-              return;
-            }
-            DexProgramClass rewrittenClass = asProgramClassOrNull(definitions.definitionFor(type));
-            assert rewrittenClass != null;
-            assert !classesWithAllocationSiteTracking.containsKey(rewrittenClass);
-            assert !classesWithoutAllocationSiteTracking.contains(rewrittenClass);
-            classesWithoutAllocationSiteTracking.add(rewrittenClass);
-          });
-      return this;
-    }
-
-    public ObjectAllocationInfoCollectionImpl build() {
-      return new ObjectAllocationInfoCollectionImpl(
-          classesWithAllocationSiteTracking, classesWithoutAllocationSiteTracking);
-    }
-  }
-
-  private static void internalForEachInstantiatedSubType(
+  public void forEachInstantiatedSubType(
       DexType type,
-      Consumer<DexProgramClass> subTypeConsumer,
-      Consumer<LambdaDescriptor> lambdaConsumer,
-      Map<DexType, Set<DexClass>> instantiatedHierarchy,
-      Map<DexType, List<LambdaDescriptor>> instantiatedLambdas,
-      Predicate<DexProgramClass> isInstantiatedDirectly,
+      Consumer<DexProgramClass> onClass,
+      Consumer<LambdaDescriptor> onLambda,
       AppInfo appInfo) {
     WorkList<DexClass> worklist = WorkList.newIdentityWorkList();
     if (type == appInfo.dexItemFactory().objectType) {
@@ -307,7 +157,7 @@
         // If no definition for the type is found, populate the worklist with any
         // instantiated subtypes and callback with any lambda instance.
         worklist.addIfNotSeen(instantiatedHierarchy.getOrDefault(type, Collections.emptySet()));
-        instantiatedLambdas.getOrDefault(type, Collections.emptyList()).forEach(lambdaConsumer);
+        instantiatedLambdas.getOrDefault(type, Collections.emptyList()).forEach(onLambda);
       } else {
         worklist.addIfNotSeen(initialClass);
       }
@@ -317,12 +167,287 @@
       DexClass clazz = worklist.next();
       if (clazz.isProgramClass()) {
         DexProgramClass programClass = clazz.asProgramClass();
-        if (isInstantiatedDirectly.test(programClass)) {
-          subTypeConsumer.accept(programClass);
+        if (isInstantiatedDirectly(programClass)
+            || isInterfaceWithUnknownSubtypeHierarchy(programClass)) {
+          onClass.accept(programClass);
         }
       }
       worklist.addIfNotSeen(instantiatedHierarchy.getOrDefault(clazz.type, Collections.emptySet()));
-      instantiatedLambdas.getOrDefault(clazz.type, Collections.emptyList()).forEach(lambdaConsumer);
+      instantiatedLambdas.getOrDefault(clazz.type, Collections.emptyList()).forEach(onLambda);
+    }
+  }
+
+  public static class Builder extends ObjectAllocationInfoCollectionImpl {
+
+    private static class Data {
+
+      private final boolean trackAllocationSites;
+      private final GraphReporter reporter;
+
+      private Data(boolean trackAllocationSites, GraphReporter reporter) {
+        this.trackAllocationSites = trackAllocationSites;
+        this.reporter = reporter;
+      }
+    }
+
+    // Pointer to data valid during the duration of the builder.
+    private Data data;
+
+    private Builder(boolean trackAllocationSites, GraphReporter reporter) {
+      data = new Data(trackAllocationSites, reporter);
+    }
+
+    public ObjectAllocationInfoCollectionImpl build(DexDefinitionSupplier definitions) {
+      assert data != null;
+      if (instantiatedHierarchy == null) {
+        repopulateInstantiatedHierarchy(definitions);
+      }
+      assert validate(definitions);
+      data = null;
+      return this;
+    }
+
+    // Consider a mutation interface that has just the mutation methods.
+    @Override
+    public void mutate(Consumer<Builder> mutator, AppInfo appInfo) {
+      mutator.accept(this);
+      repopulateInstantiatedHierarchy(appInfo);
+    }
+
+    private boolean shouldTrackAllocationSitesForClass(
+        DexProgramClass clazz, InstantiationReason instantiationReason) {
+      if (!data.trackAllocationSites) {
+        return false;
+      }
+      if (instantiationReason != InstantiationReason.NEW_INSTANCE_INSTRUCTION) {
+        // There is an allocation site which is not a new-instance instruction.
+        return false;
+      }
+      if (classesWithoutAllocationSiteTracking.contains(clazz)) {
+        // We already gave up on tracking the allocation sites for `clazz` previously.
+        return false;
+      }
+      // We currently only use allocation site information for instance field value propagation.
+      return !clazz.instanceFields().isEmpty();
+    }
+
+    /**
+     * Records that {@param clazz} is instantiated in {@param context}.
+     *
+     * @return true if {@param clazz} was not instantiated before.
+     */
+    public boolean recordDirectAllocationSite(
+        DexProgramClass clazz,
+        DexEncodedMethod context,
+        InstantiationReason instantiationReason,
+        KeepReason keepReason,
+        AppInfo appInfo) {
+      assert !clazz.isInterface();
+      if (data.reporter != null) {
+        data.reporter.registerClass(clazz, keepReason);
+      }
+      populateInstantiatedHierarchy(appInfo, clazz);
+      if (shouldTrackAllocationSitesForClass(clazz, instantiationReason)) {
+        assert context != null;
+        Set<DexEncodedMethod> allocationSitesForClass =
+            classesWithAllocationSiteTracking.computeIfAbsent(
+                clazz, ignore -> Sets.newIdentityHashSet());
+        allocationSitesForClass.add(context);
+        return allocationSitesForClass.size() == 1;
+      }
+      if (classesWithoutAllocationSiteTracking.add(clazz)) {
+        Set<DexEncodedMethod> allocationSitesForClass =
+            classesWithAllocationSiteTracking.remove(clazz);
+        return allocationSitesForClass == null;
+      }
+      return false;
+    }
+
+    public boolean recordInstantiatedInterface(DexProgramClass iface, AppInfo appInfo) {
+      assert iface.isInterface();
+      assert !iface.isAnnotation();
+      if (interfacesWithUnknownSubtypeHierarchy.add(iface)) {
+        populateInstantiatedHierarchy(appInfo, iface);
+        return true;
+      }
+      return false;
+    }
+
+    public void recordInstantiatedLambdaInterface(
+        DexType iface, LambdaDescriptor lambda, AppInfo appInfo) {
+      instantiatedLambdas.computeIfAbsent(iface, key -> new ArrayList<>()).add(lambda);
+      populateInstantiatedHierarchy(appInfo, iface);
+    }
+
+    private void repopulateInstantiatedHierarchy(DexDefinitionSupplier definitions) {
+      instantiatedHierarchy = new IdentityHashMap<>();
+      classesWithAllocationSiteTracking
+          .keySet()
+          .forEach(clazz -> populateInstantiatedHierarchy(definitions, clazz));
+      classesWithoutAllocationSiteTracking.forEach(
+          clazz -> populateInstantiatedHierarchy(definitions, clazz));
+      interfacesWithUnknownSubtypeHierarchy.forEach(
+          clazz -> populateInstantiatedHierarchy(definitions, clazz));
+      instantiatedLambdas
+          .keySet()
+          .forEach(type -> populateInstantiatedHierarchy(definitions, type));
+    }
+
+    private void populateInstantiatedHierarchy(DexDefinitionSupplier definitions, DexType type) {
+      DexClass clazz = definitions.definitionFor(type);
+      if (clazz != null) {
+        populateInstantiatedHierarchy(definitions, clazz);
+      }
+    }
+
+    private void populateInstantiatedHierarchy(DexDefinitionSupplier definitions, DexClass clazz) {
+      if (clazz.superType != null) {
+        populateInstantiatedHierarchy(definitions, clazz.superType, clazz);
+      }
+      for (DexType iface : clazz.interfaces.values) {
+        populateInstantiatedHierarchy(definitions, iface, clazz);
+      }
+    }
+
+    private void populateInstantiatedHierarchy(
+        DexDefinitionSupplier definitions, DexType type, DexClass subtype) {
+      if (type == definitions.dexItemFactory().objectType) {
+        return;
+      }
+      Set<DexClass> subtypes = instantiatedHierarchy.get(type);
+      if (subtypes != null) {
+        subtypes.add(subtype);
+        return;
+      }
+      // This is the first time an instantiation appears below 'type', recursively populate.
+      subtypes = Sets.newIdentityHashSet();
+      subtypes.add(subtype);
+      instantiatedHierarchy.put(type, subtypes);
+      populateInstantiatedHierarchy(definitions, type);
+    }
+
+    public void markNoLongerInstantiated(DexProgramClass clazz) {
+      classesWithAllocationSiteTracking.remove(clazz);
+      classesWithoutAllocationSiteTracking.remove(clazz);
+      instantiatedHierarchy = null;
+    }
+
+    Builder rewrittenWithLens(
+        ObjectAllocationInfoCollectionImpl objectAllocationInfos,
+        DexDefinitionSupplier definitions,
+        GraphLense lens) {
+      instantiatedHierarchy = null;
+      objectAllocationInfos.classesWithAllocationSiteTracking.forEach(
+          (clazz, allocationSitesForClass) -> {
+            DexType type = lens.lookupType(clazz.type);
+            if (type.isPrimitiveType()) {
+              assert !objectAllocationInfos.hasInstantiatedStrictSubtype(clazz);
+              return;
+            }
+            DexProgramClass rewrittenClass = asProgramClassOrNull(definitions.definitionFor(type));
+            assert rewrittenClass != null;
+            assert !classesWithAllocationSiteTracking.containsKey(rewrittenClass);
+            classesWithAllocationSiteTracking.put(
+                rewrittenClass,
+                LensUtils.rewrittenWithRenamedSignature(
+                    allocationSitesForClass, definitions, lens));
+          });
+      objectAllocationInfos.classesWithoutAllocationSiteTracking.forEach(
+          clazz -> {
+            DexType type = lens.lookupType(clazz.type);
+            if (type.isPrimitiveType()) {
+              assert !objectAllocationInfos.hasInstantiatedStrictSubtype(clazz);
+              return;
+            }
+            DexProgramClass rewrittenClass = asProgramClassOrNull(definitions.definitionFor(type));
+            assert rewrittenClass != null;
+            assert !classesWithAllocationSiteTracking.containsKey(rewrittenClass);
+            assert !classesWithoutAllocationSiteTracking.contains(rewrittenClass);
+            classesWithoutAllocationSiteTracking.add(rewrittenClass);
+          });
+      for (DexProgramClass abstractType :
+          objectAllocationInfos.interfacesWithUnknownSubtypeHierarchy) {
+        DexType type = lens.lookupType(abstractType.type);
+        if (type.isPrimitiveType()) {
+          assert false;
+          continue;
+        }
+        DexProgramClass rewrittenClass = asProgramClassOrNull(definitions.definitionFor(type));
+        assert rewrittenClass != null;
+        assert !interfacesWithUnknownSubtypeHierarchy.contains(rewrittenClass);
+        interfacesWithUnknownSubtypeHierarchy.add(rewrittenClass);
+      }
+      objectAllocationInfos.instantiatedLambdas.forEach(
+          (iface, lambdas) -> {
+            DexType type = lens.lookupType(iface);
+            if (type.isPrimitiveType()) {
+              assert false;
+              return;
+            }
+            assert !instantiatedLambdas.containsKey(type);
+            // TODO(b/150277553): Rewrite lambda descriptor.
+            instantiatedLambdas.put(type, lambdas);
+          });
+      return this;
+    }
+
+    // Validation that all types are linked in the instantiated hierarchy map.
+    boolean validate(DexDefinitionSupplier definitions) {
+      classesWithAllocationSiteTracking.forEach(
+          (clazz, contexts) -> {
+            assert !clazz.isInterface();
+            assert !classesWithoutAllocationSiteTracking.contains(clazz);
+            assert verifyAllSuperTypesAreInHierarchy(definitions, clazz.allImmediateSupertypes());
+          });
+      classesWithoutAllocationSiteTracking.forEach(
+          clazz -> {
+            assert !clazz.isInterface();
+            assert !classesWithAllocationSiteTracking.containsKey(clazz);
+            assert verifyAllSuperTypesAreInHierarchy(definitions, clazz.allImmediateSupertypes());
+          });
+      instantiatedLambdas.forEach(
+          (iface, lambdas) -> {
+            assert !lambdas.isEmpty();
+            DexClass definition = definitions.definitionFor(iface);
+            if (definition != null) {
+              assert definition.isInterface();
+              assert verifyAllSuperTypesAreInHierarchy(
+                  definitions, definition.allImmediateSupertypes());
+            }
+          });
+      for (DexProgramClass iface : interfacesWithUnknownSubtypeHierarchy) {
+        verifyAllSuperTypesAreInHierarchy(definitions, iface.allImmediateSupertypes());
+      }
+      instantiatedHierarchy.forEach(
+          (type, subtypes) -> {
+            assert !subtypes.isEmpty();
+            for (DexClass subtype : subtypes) {
+              assert isImmediateSuperType(type, subtype);
+            }
+          });
+      return true;
+    }
+
+    private boolean verifyAllSuperTypesAreInHierarchy(
+        DexDefinitionSupplier definitions, Iterable<DexType> dexTypes) {
+      for (DexType supertype : dexTypes) {
+        assert typeIsInHierarchy(definitions, supertype);
+      }
+      return true;
+    }
+
+    private boolean typeIsInHierarchy(DexDefinitionSupplier definitions, DexType type) {
+      return type == definitions.dexItemFactory().objectType
+          || instantiatedHierarchy.containsKey(type);
+    }
+
+    private boolean isImmediateSuperType(DexType type, DexClass subtype) {
+      for (DexType supertype : subtype.allImmediateSupertypes()) {
+        if (type == supertype) {
+          return true;
+        }
+      }
+      return false;
     }
   }
 }
diff --git a/src/main/java/com/android/tools/r8/graph/ResolutionResult.java b/src/main/java/com/android/tools/r8/graph/ResolutionResult.java
index 3de6525..283777f 100644
--- a/src/main/java/com/android/tools/r8/graph/ResolutionResult.java
+++ b/src/main/java/com/android/tools/r8/graph/ResolutionResult.java
@@ -444,7 +444,7 @@
         Consumer<DexProgramClass> lambdaInstantiatedConsumer =
             subType -> {
               subTypeConsumer.accept(subType);
-              if (appInfo.hasAnyInstantiatedLambdas(subType)) {
+              if (appInfo.isInstantiatedInterface(subType)) {
                 hasInstantiatedLambdas.set(true);
               }
             };
diff --git a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
index ea314bf..d5672a9 100644
--- a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
+++ b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
@@ -35,7 +35,6 @@
 import com.android.tools.r8.graph.ObjectAllocationInfoCollection;
 import com.android.tools.r8.graph.ObjectAllocationInfoCollectionImpl;
 import com.android.tools.r8.graph.PresortedComparable;
-import com.android.tools.r8.graph.ResolutionResult;
 import com.android.tools.r8.graph.ResolutionResult.SingleResolutionResult;
 import com.android.tools.r8.ir.analysis.type.ClassTypeElement;
 import com.android.tools.r8.ir.code.Invoke.Type;
@@ -46,7 +45,6 @@
 import com.android.tools.r8.utils.CollectionUtils;
 import com.android.tools.r8.utils.ListUtils;
 import com.android.tools.r8.utils.PredicateSet;
-import com.android.tools.r8.utils.SetUtils;
 import com.android.tools.r8.utils.Visibility;
 import com.android.tools.r8.utils.WorkList;
 import com.google.common.collect.ImmutableSet;
@@ -192,8 +190,6 @@
   /** A map from enum types to their value types and ordinals. */
   final EnumValueInfoMapCollection enumValueInfoMaps;
 
-  final Set<DexType> instantiatedLambdas;
-
   /* A cache to improve the lookup performance of lookupSingleVirtualTarget */
   private final SingleTargetLookupCache singleTargetLookupCache = new SingleTargetLookupCache();
 
@@ -238,7 +234,6 @@
       Set<DexType> prunedTypes,
       Map<DexField, Int2ReferenceMap<DexField>> switchMaps,
       EnumValueInfoMapCollection enumValueInfoMaps,
-      Set<DexType> instantiatedLambdas,
       Set<DexType> constClassReferences,
       Map<DexType, Visibility> initClassReferences) {
     super(application);
@@ -280,7 +275,6 @@
     this.prunedTypes = prunedTypes;
     this.switchMaps = switchMaps;
     this.enumValueInfoMaps = enumValueInfoMaps;
-    this.instantiatedLambdas = instantiatedLambdas;
     this.constClassReferences = constClassReferences;
     this.initClassReferences = initClassReferences;
   }
@@ -325,7 +319,6 @@
       Set<DexType> prunedTypes,
       Map<DexField, Int2ReferenceMap<DexField>> switchMaps,
       EnumValueInfoMapCollection enumValueInfoMaps,
-      Set<DexType> instantiatedLambdas,
       Set<DexType> constClassReferences,
       Map<DexType, Visibility> initClassReferences) {
     super(appInfoWithSubtyping);
@@ -367,7 +360,6 @@
     this.prunedTypes = prunedTypes;
     this.switchMaps = switchMaps;
     this.enumValueInfoMaps = enumValueInfoMaps;
-    this.instantiatedLambdas = instantiatedLambdas;
     this.constClassReferences = constClassReferences;
     this.initClassReferences = initClassReferences;
   }
@@ -413,7 +405,6 @@
         previous.prunedTypes,
         previous.switchMaps,
         previous.enumValueInfoMaps,
-        previous.instantiatedLambdas,
         previous.constClassReferences,
         previous.initClassReferences);
     copyMetadataFromPrevious(previous);
@@ -468,7 +459,6 @@
             : CollectionUtils.mergeSets(previous.prunedTypes, removedClasses),
         previous.switchMaps,
         previous.enumValueInfoMaps,
-        previous.instantiatedLambdas,
         previous.constClassReferences,
         previous.initClassReferences);
     copyMetadataFromPrevious(previous);
@@ -484,7 +474,6 @@
     this.missingTypes = previous.missingTypes;
     this.liveTypes = previous.liveTypes;
     this.instantiatedAppServices = previous.instantiatedAppServices;
-    this.instantiatedLambdas = previous.instantiatedLambdas;
     this.targetedMethods = previous.targetedMethods;
     this.failedResolutionTargets = previous.failedResolutionTargets;
     this.bootstrapMethods = previous.bootstrapMethods;
@@ -764,8 +753,9 @@
     return objectAllocationInfoCollection;
   }
 
-  ObjectAllocationInfoCollectionImpl getMutableObjectAllocationInfoCollection() {
-    return objectAllocationInfoCollection;
+  void mutateObjectAllocationInfoCollection(
+      Consumer<ObjectAllocationInfoCollectionImpl.Builder> mutator) {
+    objectAllocationInfoCollection.mutate(mutator, this);
   }
 
   void removeFromSingleTargetLookupCache(DexClass clazz) {
@@ -793,13 +783,15 @@
     assert checkIfObsolete();
     DexType type = clazz.type;
     return type.isD8R8SynthesizedClassType()
-        || objectAllocationInfoCollection.isInstantiatedDirectly(clazz)
+        || (!clazz.isInterface() && objectAllocationInfoCollection.isInstantiatedDirectly(clazz))
+        // TODO(b/145344105): Model annotations in the object allocation info.
         || (clazz.isAnnotation() && liveTypes.contains(type));
   }
 
+  // TODO(b/145344105): Model incomplete hierarchies in the object allocation info.
   public boolean isInstantiatedIndirectly(DexProgramClass clazz) {
     assert checkIfObsolete();
-    if (hasAnyInstantiatedLambdas(clazz)) {
+    if (isInstantiatedInterface(clazz)) {
       return true;
     }
     DexType type = clazz.type;
@@ -934,9 +926,9 @@
   }
 
   @Override
-  public boolean hasAnyInstantiatedLambdas(DexProgramClass clazz) {
+  public boolean isInstantiatedInterface(DexProgramClass clazz) {
     assert checkIfObsolete();
-    return instantiatedLambdas.contains(clazz.type);
+    return objectAllocationInfoCollection.isInterfaceWithUnknownSubtypeHierarchy(clazz);
   }
 
   @Override
@@ -969,53 +961,6 @@
     return false;
   }
 
-  private boolean canVirtualMethodBeImplementedInExtraSubclass(
-      DexProgramClass clazz, DexMethod method) {
-    // For functional interfaces that are instantiated by lambdas, we may not have synthesized all
-    // the lambda classes yet, and therefore the set of subtypes for the holder may still be
-    // incomplete.
-    if (hasAnyInstantiatedLambdas(clazz)) {
-      return true;
-    }
-    // If `clazz` is kept and `method` is a library method or a library method override, then it is
-    // possible to create a class that inherits from `clazz` and overrides the library method.
-    // Similarly, if `clazz` is kept and `method` is kept directly on `clazz` or indirectly on one
-    // of its supertypes, then it is possible to create a class that inherits from `clazz` and
-    // overrides the kept method.
-    if (isPinned(clazz.type)) {
-      ResolutionResult resolutionResult = resolveMethod(clazz, method);
-      if (resolutionResult.isSingleResolution()) {
-        DexEncodedMethod resolutionTarget = resolutionResult.getSingleTarget();
-        return !resolutionTarget.isProgramMethod(this)
-            || resolutionTarget.isLibraryMethodOverride().isPossiblyTrue()
-            || isVirtualMethodPinnedDirectlyOrInAncestor(clazz, method);
-      }
-    }
-    return false;
-  }
-
-  private boolean isVirtualMethodPinnedDirectlyOrInAncestor(
-      DexProgramClass currentClass, DexMethod method) {
-    // Look in all ancestor types, including `currentClass` itself.
-    Set<DexProgramClass> visited = SetUtils.newIdentityHashSet(currentClass);
-    Deque<DexProgramClass> worklist = new ArrayDeque<>(visited);
-    while (!worklist.isEmpty()) {
-      DexClass clazz = worklist.removeFirst();
-      assert visited.contains(clazz);
-      DexEncodedMethod methodInClass = clazz.lookupVirtualMethod(method);
-      if (methodInClass != null && isPinned(methodInClass.method)) {
-        return true;
-      }
-      for (DexType superType : clazz.allImmediateSupertypes()) {
-        DexProgramClass superClass = asProgramClassOrNull(definitionFor(superType));
-        if (superClass != null && visited.add(superClass)) {
-          worklist.addLast(superClass);
-        }
-      }
-    }
-    return false;
-  }
-
   public Set<DexReference> getPinnedItems() {
     assert checkIfObsolete();
     return pinnedItems;
@@ -1030,6 +975,10 @@
       Collection<DexType> removedClasses,
       Collection<DexReference> additionalPinnedItems) {
     assert checkIfObsolete();
+    if (!removedClasses.isEmpty()) {
+      // Rebuild the hierarchy.
+      objectAllocationInfoCollection.mutate(mutator -> {}, this);
+    }
     return new AppInfoWithLiveness(this, application, removedClasses, additionalPinnedItems);
   }
 
@@ -1109,7 +1058,6 @@
         prunedTypes,
         rewriteReferenceKeys(switchMaps, lens::lookupField),
         enumValueInfoMaps.rewrittenWithLens(lens),
-        rewriteItems(instantiatedLambdas, lens::lookupType),
         rewriteItems(constClassReferences, lens::lookupType),
         rewriteReferenceKeys(initClassReferences, lens::lookupType));
   }
@@ -1238,7 +1186,7 @@
     // TODO(b/148769279): Disable lookup single target on lambda's for now.
     if (resolvedHolder.isInterface()
         && resolvedHolder.isProgramClass()
-        && hasAnyInstantiatedLambdas(resolvedHolder.asProgramClass())) {
+        && isInstantiatedInterface(resolvedHolder.asProgramClass())) {
       singleTargetLookupCache.addToCache(refinedReceiverType, method, null);
       return null;
     }
@@ -1417,9 +1365,7 @@
   }
 
   private boolean isInstantiatedOrPinned(DexProgramClass clazz) {
-    return isInstantiatedDirectly(clazz)
-        || isPinned(clazz.type)
-        || hasAnyInstantiatedLambdas(clazz);
+    return isInstantiatedDirectly(clazz) || isPinned(clazz.type) || isInstantiatedInterface(clazz);
   }
 
   public boolean isPinnedNotProgramOrLibraryOverride(DexReference reference) {
@@ -1436,7 +1382,7 @@
       DexClass clazz = definitionFor(reference.asDexType());
       return clazz == null
           || clazz.isNotProgramClass()
-          || hasAnyInstantiatedLambdas(clazz.asProgramClass());
+          || isInstantiatedInterface(clazz.asProgramClass());
     }
   }
 }
diff --git a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLivenessModifier.java b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLivenessModifier.java
index 99f2d84..0f1b348 100644
--- a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLivenessModifier.java
+++ b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLivenessModifier.java
@@ -8,7 +8,6 @@
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.FieldAccessInfoCollectionImpl;
 import com.android.tools.r8.graph.FieldAccessInfoImpl;
-import com.android.tools.r8.graph.ObjectAllocationInfoCollectionImpl;
 import com.google.common.collect.Sets;
 import java.util.Set;
 
@@ -34,12 +33,13 @@
 
   public void modify(AppInfoWithLiveness appInfo) {
     // Instantiated classes.
-    ObjectAllocationInfoCollectionImpl objectAllocationInfoCollection =
-        appInfo.getMutableObjectAllocationInfoCollection();
-    noLongerInstantiatedClasses.forEach(
-        clazz -> {
-          objectAllocationInfoCollection.markNoLongerInstantiated(clazz);
-          appInfo.removeFromSingleTargetLookupCache(clazz);
+    appInfo.mutateObjectAllocationInfoCollection(
+        mutator -> {
+          noLongerInstantiatedClasses.forEach(
+              clazz -> {
+                mutator.markNoLongerInstantiated(clazz);
+                appInfo.removeFromSingleTargetLookupCache(clazz);
+              });
         });
     // Written fields.
     FieldAccessInfoCollectionImpl fieldAccessInfoCollection =
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index 05cac5f..f636d05 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -749,8 +749,6 @@
     }
 
     DexEncodedMethod contextMethod = context.getMethod();
-    markLambdaAsInstantiated(descriptor, contextMethod);
-    transitionMethodsForInstantiatedLambda(descriptor);
     if (lambdaRewriter != null) {
       assert contextMethod.getCode().isCfCode() : "Unexpected input type with lambdas";
       CfCode code = contextMethod.getCode().asCfCode();
@@ -769,6 +767,8 @@
         desugaredLambdaImplementationMethods.add(descriptor.implHandle.asMethod());
       }
     } else {
+      markLambdaAsInstantiated(descriptor, contextMethod);
+      transitionMethodsForInstantiatedLambda(descriptor);
       callSites.add(callSite);
     }
 
@@ -1888,7 +1888,7 @@
   void markInterfaceAsInstantiated(DexProgramClass clazz, KeepReasonWitness witness) {
     assert !clazz.isAnnotation();
     assert clazz.isInterface();
-    if (!objectAllocationInfoCollection.recordInstantiatedInterface(clazz)) {
+    if (!objectAllocationInfoCollection.recordInstantiatedInterface(clazz, appInfo)) {
       return;
     }
     markTypeAsLive(clazz, witness);
@@ -1900,12 +1900,11 @@
     for (DexType iface : descriptor.interfaces) {
       checkLambdaInterface(iface, context);
       objectAllocationInfoCollection.recordInstantiatedLambdaInterface(iface, descriptor, appInfo);
-      // TODO(b/150277553): Lambdas should be accurately traces and thus not be added here.
-      if (lambdaRewriter == null) {
-        DexProgramClass clazz = getProgramClassOrNull(iface);
-        if (clazz != null) {
-          objectAllocationInfoCollection.recordInstantiatedInterface(clazz);
-        }
+      // TODO(b/150277553): Lambdas should be accurately traced and thus not be added here.
+      assert lambdaRewriter == null;
+      DexProgramClass clazz = getProgramClassOrNull(iface);
+      if (clazz != null) {
+        objectAllocationInfoCollection.recordInstantiatedInterface(clazz, appInfo);
       }
     }
   }
@@ -2833,7 +2832,7 @@
             toSortedDescriptorSet(liveMethods.getItems()),
             // Filter out library fields and pinned fields, because these are read by default.
             fieldAccessInfoCollection,
-            objectAllocationInfoCollection.build(),
+            objectAllocationInfoCollection.build(appInfo),
             // TODO(b/132593519): Do we require these sets to be sorted for determinism?
             toImmutableSortedMap(virtualInvokes, PresortedComparable::slowCompare),
             toImmutableSortedMap(interfaceInvokes, PresortedComparable::slowCompare),
@@ -2861,10 +2860,6 @@
             Collections.emptySet(),
             Collections.emptyMap(),
             EnumValueInfoMapCollection.empty(),
-            // TODO(b/150277553): Remove this once object allocation contains the information.
-            SetUtils.mapIdentityHashSet(
-                objectAllocationInfoCollection.unknownInstantiatedInterfaceTypes,
-                DexProgramClass::getType),
             constClassReferences,
             initClassReferences);
     appInfo.markObsolete();
diff --git a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
index 8ec0dac..2fb7435 100644
--- a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
@@ -29,6 +29,7 @@
 import com.android.tools.r8.graph.GraphLense.GraphLenseLookupResult;
 import com.android.tools.r8.graph.LookupResult.LookupResultSuccess;
 import com.android.tools.r8.graph.MethodAccessFlags;
+import com.android.tools.r8.graph.ObjectAllocationInfoCollection;
 import com.android.tools.r8.graph.ParameterAnnotationsList;
 import com.android.tools.r8.graph.ResolutionResult;
 import com.android.tools.r8.graph.RewrittenPrototypeDescription;
@@ -332,9 +333,9 @@
   private boolean isMergeCandidate(
       DexProgramClass sourceClass, DexProgramClass targetClass, Set<DexType> pinnedTypes) {
     assert targetClass != null;
-
-    if (appInfo.getObjectAllocationInfoCollection().isInstantiatedDirectly(sourceClass)
-        || appInfo.instantiatedLambdas.contains(sourceClass.type)
+    ObjectAllocationInfoCollection allocationInfo = appInfo.getObjectAllocationInfoCollection();
+    if (allocationInfo.isInstantiatedDirectly(sourceClass)
+        || allocationInfo.isInterfaceWithUnknownSubtypeHierarchy(sourceClass)
         || appInfo.isPinned(sourceClass.type)
         || pinnedTypes.contains(sourceClass.type)
         || appInfo.neverMerge.contains(sourceClass.type)) {