Towards tracking the allocation sites for instantiated classes

This CL replaces the set of instantiated types by a new ObjectAllocationInfoCollection.

In addition to maintaing the set of instantiated types, the new ObjectAllocationInfoCollection also supports getting the allocation sites for each instantiated type, if all allocation sites are known.

Bug: 147799448
Change-Id: I833145285436532328f0988b78a2877ab779cf17
diff --git a/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollection.java b/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollection.java
new file mode 100644
index 0000000..2ed505e
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollection.java
@@ -0,0 +1,25 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.graph;
+
+import java.util.Set;
+import java.util.function.BiConsumer;
+
+/**
+ * Provides immutable access to {@link ObjectAllocationInfoCollectionImpl}, which stores the set of
+ * instantiated classes along with their allocation sites.
+ */
+public interface ObjectAllocationInfoCollection {
+
+  void forEachClassWithKnownAllocationSites(
+      BiConsumer<DexProgramClass, Set<DexEncodedMethod>> consumer);
+
+  boolean isAllocationSitesKnown(DexProgramClass clazz);
+
+  boolean isInstantiatedDirectly(DexProgramClass clazz);
+
+  ObjectAllocationInfoCollection rewrittenWithLens(
+      DexDefinitionSupplier definitions, GraphLense lens);
+}
diff --git a/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollectionImpl.java b/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollectionImpl.java
new file mode 100644
index 0000000..fa13c74
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollectionImpl.java
@@ -0,0 +1,165 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.graph;
+
+import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
+
+import com.android.tools.r8.shaking.GraphReporter;
+import com.android.tools.r8.shaking.InstantiationReason;
+import com.android.tools.r8.shaking.KeepReason;
+import com.android.tools.r8.utils.LensUtils;
+import com.google.common.collect.Sets;
+import java.util.IdentityHashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.BiConsumer;
+
+/** Stores the set of instantiated classes along with their allocation sites. */
+public class ObjectAllocationInfoCollectionImpl implements ObjectAllocationInfoCollection {
+
+  private final Map<DexProgramClass, Set<DexEncodedMethod>> classesWithAllocationSiteTracking;
+  private final Set<DexProgramClass> classesWithoutAllocationSiteTracking;
+
+  private ObjectAllocationInfoCollectionImpl(
+      Map<DexProgramClass, Set<DexEncodedMethod>> classesWithAllocationSiteTracking,
+      Set<DexProgramClass> classesWithoutAllocationSiteTracking) {
+    this.classesWithAllocationSiteTracking = classesWithAllocationSiteTracking;
+    this.classesWithoutAllocationSiteTracking = classesWithoutAllocationSiteTracking;
+  }
+
+  public static Builder builder(boolean trackAllocationSites, GraphReporter reporter) {
+    return new Builder(trackAllocationSites, reporter);
+  }
+
+  @Override
+  public void forEachClassWithKnownAllocationSites(
+      BiConsumer<DexProgramClass, Set<DexEncodedMethod>> consumer) {
+    classesWithAllocationSiteTracking.forEach(consumer);
+  }
+
+  @Override
+  public boolean isAllocationSitesKnown(DexProgramClass clazz) {
+    return classesWithAllocationSiteTracking.containsKey(clazz);
+  }
+
+  @Override
+  public boolean isInstantiatedDirectly(DexProgramClass clazz) {
+    if (classesWithAllocationSiteTracking.containsKey(clazz)) {
+      assert !classesWithAllocationSiteTracking.get(clazz).isEmpty();
+      return true;
+    }
+    return classesWithoutAllocationSiteTracking.contains(clazz);
+  }
+
+  @Override
+  public ObjectAllocationInfoCollectionImpl rewrittenWithLens(
+      DexDefinitionSupplier definitions, GraphLense lens) {
+    return builder(true, null).rewrittenWithLens(this, definitions, lens).build();
+  }
+
+  public static class Builder {
+
+    private final boolean trackAllocationSites;
+
+    private final Map<DexProgramClass, Set<DexEncodedMethod>> classesWithAllocationSiteTracking =
+        new IdentityHashMap<>();
+    private final Set<DexProgramClass> classesWithoutAllocationSiteTracking =
+        Sets.newIdentityHashSet();
+
+    private GraphReporter reporter;
+
+    private Builder(boolean trackAllocationSites, GraphReporter reporter) {
+      this.trackAllocationSites = trackAllocationSites;
+      this.reporter = reporter;
+    }
+
+    private boolean shouldTrackAllocationSitesForClass(
+        DexProgramClass clazz, InstantiationReason instantiationReason) {
+      if (!trackAllocationSites) {
+        return false;
+      }
+      if (instantiationReason != InstantiationReason.NEW_INSTANCE_INSTRUCTION) {
+        // There is an allocation site which is not a new-instance instruction.
+        return false;
+      }
+      if (classesWithoutAllocationSiteTracking.contains(clazz)) {
+        // We already gave up on tracking the allocation sites for `clazz` previously.
+        return false;
+      }
+      // We currently only use allocation site information for instance field value propagation.
+      return !clazz.instanceFields().isEmpty();
+    }
+
+    public boolean isInstantiatedDirectly(DexProgramClass clazz) {
+      if (classesWithAllocationSiteTracking.containsKey(clazz)) {
+        assert !classesWithAllocationSiteTracking.get(clazz).isEmpty();
+        return true;
+      }
+      return classesWithoutAllocationSiteTracking.contains(clazz);
+    }
+
+    /**
+     * Records that {@param clazz} is instantiated in {@param context}.
+     *
+     * @return true if {@param clazz} was not instantiated before.
+     */
+    public boolean recordDirectAllocationSite(
+        DexProgramClass clazz,
+        DexEncodedMethod context,
+        InstantiationReason instantiationReason,
+        KeepReason keepReason) {
+      assert !clazz.isInterface();
+      if (reporter != null) {
+        reporter.registerClass(clazz, keepReason);
+      }
+      if (shouldTrackAllocationSitesForClass(clazz, instantiationReason)) {
+        assert context != null;
+        Set<DexEncodedMethod> allocationSitesForClass =
+            classesWithAllocationSiteTracking.computeIfAbsent(
+                clazz, ignore -> Sets.newIdentityHashSet());
+        allocationSitesForClass.add(context);
+        return allocationSitesForClass.size() == 1;
+      }
+      if (classesWithoutAllocationSiteTracking.add(clazz)) {
+        Set<DexEncodedMethod> allocationSitesForClass =
+            classesWithAllocationSiteTracking.remove(clazz);
+        return allocationSitesForClass == null;
+      }
+      return false;
+    }
+
+    Builder rewrittenWithLens(
+        ObjectAllocationInfoCollectionImpl objectAllocationInfos,
+        DexDefinitionSupplier definitions,
+        GraphLense lens) {
+      objectAllocationInfos.classesWithAllocationSiteTracking.forEach(
+          (clazz, allocationSitesForClass) -> {
+            DexProgramClass rewrittenClass =
+                asProgramClassOrNull(definitions.definitionFor(lens.lookupType(clazz.type)));
+            assert rewrittenClass != null;
+            assert !classesWithAllocationSiteTracking.containsKey(rewrittenClass);
+            classesWithAllocationSiteTracking.put(
+                rewrittenClass,
+                LensUtils.rewrittenWithRenamedSignature(
+                    allocationSitesForClass, definitions, lens));
+          });
+      objectAllocationInfos.classesWithoutAllocationSiteTracking.forEach(
+          clazz -> {
+            DexProgramClass rewrittenClass =
+                asProgramClassOrNull(definitions.definitionFor(lens.lookupType(clazz.type)));
+            assert rewrittenClass != null;
+            assert !classesWithAllocationSiteTracking.containsKey(rewrittenClass);
+            assert !classesWithoutAllocationSiteTracking.contains(rewrittenClass);
+            classesWithoutAllocationSiteTracking.add(rewrittenClass);
+          });
+      return this;
+    }
+
+    public ObjectAllocationInfoCollectionImpl build() {
+      return new ObjectAllocationInfoCollectionImpl(
+          classesWithAllocationSiteTracking, classesWithoutAllocationSiteTracking);
+    }
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
index 5cc45e2..69ed3b0 100644
--- a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
+++ b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
@@ -26,6 +26,8 @@
 import com.android.tools.r8.graph.FieldAccessInfoImpl;
 import com.android.tools.r8.graph.GraphLense;
 import com.android.tools.r8.graph.GraphLense.NestedGraphLense;
+import com.android.tools.r8.graph.ObjectAllocationInfoCollection;
+import com.android.tools.r8.graph.ObjectAllocationInfoCollectionImpl;
 import com.android.tools.r8.graph.PresortedComparable;
 import com.android.tools.r8.graph.ResolutionResult;
 import com.android.tools.r8.ir.analysis.type.ClassTypeLatticeElement;
@@ -74,8 +76,6 @@
    * ServiceLoader.load() or ServiceLoader.loadInstalled().
    */
   public final Set<DexType> instantiatedAppServices;
-  /** Set of types that are actually instantiated. These cannot be abstract. */
-  final Set<DexType> instantiatedTypes;
   /** Cache for {@link #isInstantiatedDirectlyOrIndirectly(DexProgramClass)}. */
   private final IdentityHashMap<DexType, Boolean> indirectlyInstantiatedTypes =
       new IdentityHashMap<>();
@@ -109,6 +109,8 @@
    * each field. The latter is used, for example, during member rebinding.
    */
   private final FieldAccessInfoCollectionImpl fieldAccessInfoCollection;
+  /** Information about instantiated classes and their allocation sites. */
+  private final ObjectAllocationInfoCollectionImpl objectAllocationInfoCollection;
   /** Set of all methods referenced in virtual invokes, along with calling context. */
   public final SortedMap<DexMethod, Set<DexEncodedMethod>> virtualInvokes;
   /** Set of all methods referenced in interface invokes, along with calling context. */
@@ -193,7 +195,6 @@
       Set<DexType> missingTypes,
       Set<DexType> liveTypes,
       Set<DexType> instantiatedAppServices,
-      Set<DexType> instantiatedTypes,
       SortedSet<DexMethod> targetedMethods,
       Set<DexMethod> failedResolutionTargets,
       SortedSet<DexMethod> bootstrapMethods,
@@ -201,6 +202,7 @@
       SortedSet<DexMethod> virtualMethodsTargetedByInvokeDirect,
       SortedSet<DexMethod> liveMethods,
       FieldAccessInfoCollectionImpl fieldAccessInfoCollection,
+      ObjectAllocationInfoCollectionImpl objectAllocationInfoCollection,
       SortedMap<DexMethod, Set<DexEncodedMethod>> virtualInvokes,
       SortedMap<DexMethod, Set<DexEncodedMethod>> interfaceInvokes,
       SortedMap<DexMethod, Set<DexEncodedMethod>> superInvokes,
@@ -233,7 +235,6 @@
     this.missingTypes = missingTypes;
     this.liveTypes = liveTypes;
     this.instantiatedAppServices = instantiatedAppServices;
-    this.instantiatedTypes = instantiatedTypes;
     this.targetedMethods = targetedMethods;
     this.failedResolutionTargets = failedResolutionTargets;
     this.bootstrapMethods = bootstrapMethods;
@@ -241,6 +242,7 @@
     this.virtualMethodsTargetedByInvokeDirect = virtualMethodsTargetedByInvokeDirect;
     this.liveMethods = liveMethods;
     this.fieldAccessInfoCollection = fieldAccessInfoCollection;
+    this.objectAllocationInfoCollection = objectAllocationInfoCollection;
     this.pinnedItems = pinnedItems;
     this.mayHaveSideEffects = mayHaveSideEffects;
     this.noSideEffects = noSideEffects;
@@ -276,7 +278,6 @@
       Set<DexType> missingTypes,
       Set<DexType> liveTypes,
       Set<DexType> instantiatedAppServices,
-      Set<DexType> instantiatedTypes,
       SortedSet<DexMethod> targetedMethods,
       Set<DexMethod> failedResolutionTargets,
       SortedSet<DexMethod> bootstrapMethods,
@@ -284,6 +285,7 @@
       SortedSet<DexMethod> virtualMethodsTargetedByInvokeDirect,
       SortedSet<DexMethod> liveMethods,
       FieldAccessInfoCollectionImpl fieldAccessInfoCollection,
+      ObjectAllocationInfoCollectionImpl objectAllocationInfoCollection,
       SortedMap<DexMethod, Set<DexEncodedMethod>> virtualInvokes,
       SortedMap<DexMethod, Set<DexEncodedMethod>> interfaceInvokes,
       SortedMap<DexMethod, Set<DexEncodedMethod>> superInvokes,
@@ -316,7 +318,6 @@
     this.missingTypes = missingTypes;
     this.liveTypes = liveTypes;
     this.instantiatedAppServices = instantiatedAppServices;
-    this.instantiatedTypes = instantiatedTypes;
     this.targetedMethods = targetedMethods;
     this.failedResolutionTargets = failedResolutionTargets;
     this.bootstrapMethods = bootstrapMethods;
@@ -324,6 +325,7 @@
     this.virtualMethodsTargetedByInvokeDirect = virtualMethodsTargetedByInvokeDirect;
     this.liveMethods = liveMethods;
     this.fieldAccessInfoCollection = fieldAccessInfoCollection;
+    this.objectAllocationInfoCollection = objectAllocationInfoCollection;
     this.pinnedItems = pinnedItems;
     this.mayHaveSideEffects = mayHaveSideEffects;
     this.noSideEffects = noSideEffects;
@@ -360,7 +362,6 @@
         previous.missingTypes,
         previous.liveTypes,
         previous.instantiatedAppServices,
-        previous.instantiatedTypes,
         previous.targetedMethods,
         previous.failedResolutionTargets,
         previous.bootstrapMethods,
@@ -368,6 +369,7 @@
         previous.virtualMethodsTargetedByInvokeDirect,
         previous.liveMethods,
         previous.fieldAccessInfoCollection,
+        previous.objectAllocationInfoCollection,
         previous.virtualInvokes,
         previous.interfaceInvokes,
         previous.superInvokes,
@@ -409,7 +411,6 @@
         previous.missingTypes,
         previous.liveTypes,
         previous.instantiatedAppServices,
-        previous.instantiatedTypes,
         previous.targetedMethods,
         previous.failedResolutionTargets,
         previous.bootstrapMethods,
@@ -417,6 +418,7 @@
         previous.virtualMethodsTargetedByInvokeDirect,
         previous.liveMethods,
         previous.fieldAccessInfoCollection,
+        previous.objectAllocationInfoCollection,
         previous.virtualInvokes,
         previous.interfaceInvokes,
         previous.superInvokes,
@@ -461,7 +463,6 @@
     this.missingTypes = previous.missingTypes;
     this.liveTypes = previous.liveTypes;
     this.instantiatedAppServices = previous.instantiatedAppServices;
-    this.instantiatedTypes = previous.instantiatedTypes;
     this.instantiatedLambdas = previous.instantiatedLambdas;
     this.targetedMethods = previous.targetedMethods;
     this.failedResolutionTargets = previous.failedResolutionTargets;
@@ -470,6 +471,7 @@
     this.virtualMethodsTargetedByInvokeDirect = previous.virtualMethodsTargetedByInvokeDirect;
     this.liveMethods = previous.liveMethods;
     this.fieldAccessInfoCollection = previous.fieldAccessInfoCollection;
+    this.objectAllocationInfoCollection = previous.objectAllocationInfoCollection;
     this.pinnedItems = previous.pinnedItems;
     this.mayHaveSideEffects = previous.mayHaveSideEffects;
     this.noSideEffects = previous.noSideEffects;
@@ -713,6 +715,11 @@
     return fieldAccessInfoCollection;
   }
 
+  /** This method provides immutable access to `objectAllocationInfoCollection`. */
+  public ObjectAllocationInfoCollection getObjectAllocationInfoCollection() {
+    return objectAllocationInfoCollection;
+  }
+
   private boolean assertNoItemRemoved(Collection<DexReference> items, Collection<DexType> types) {
     Set<DexType> typeSet = ImmutableSet.copyOf(types);
     for (DexReference item : items) {
@@ -734,7 +741,7 @@
     assert checkIfObsolete();
     DexType type = clazz.type;
     return type.isD8R8SynthesizedClassType()
-        || instantiatedTypes.contains(type)
+        || objectAllocationInfoCollection.isInstantiatedDirectly(clazz)
         || (clazz.isAnnotation() && liveTypes.contains(type));
   }
 
@@ -993,7 +1000,6 @@
         missingTypes,
         rewriteItems(liveTypes, lens::lookupType),
         rewriteItems(instantiatedAppServices, lens::lookupType),
-        rewriteItems(instantiatedTypes, lens::lookupType),
         lens.rewriteMethodsConservatively(targetedMethods),
         lens.rewriteMethodsConservatively(failedResolutionTargets),
         lens.rewriteMethodsConservatively(bootstrapMethods),
@@ -1001,6 +1007,7 @@
         lens.rewriteMethodsConservatively(virtualMethodsTargetedByInvokeDirect),
         lens.rewriteMethodsConservatively(liveMethods),
         fieldAccessInfoCollection.rewrittenWithLens(application, lens),
+        objectAllocationInfoCollection.rewrittenWithLens(application, lens),
         rewriteKeysConservativelyWhileMergingValues(
             virtualInvokes, lens::lookupMethodInAllContexts),
         rewriteKeysConservativelyWhileMergingValues(
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index 849d05f..d0e971e 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -54,6 +54,7 @@
 import com.android.tools.r8.graph.FieldAccessInfoImpl;
 import com.android.tools.r8.graph.InnerClassAttribute;
 import com.android.tools.r8.graph.LookupResult;
+import com.android.tools.r8.graph.ObjectAllocationInfoCollectionImpl;
 import com.android.tools.r8.graph.PresortedComparable;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.graph.ResolutionResult;
@@ -177,6 +178,7 @@
   private final Map<DexMethod, Set<DexEncodedMethod>> staticInvokes = new IdentityHashMap<>();
   private final FieldAccessInfoCollectionImpl fieldAccessInfoCollection =
       new FieldAccessInfoCollectionImpl();
+  private final ObjectAllocationInfoCollectionImpl.Builder objectAllocationInfoCollection;
   private final Set<DexCallSite> callSites = Sets.newIdentityHashSet();
 
   private final Set<DexReference> identifierNameStrings = Sets.newIdentityHashSet();
@@ -222,9 +224,6 @@
   private final Map<DexProgramClass, Set<DexProgramClass>> unusedInterfaceTypes =
       new IdentityHashMap<>();
 
-  /** Set of types that are actually instantiated. These cannot be abstract. */
-  private final SetWithReason<DexProgramClass> instantiatedTypes;
-
   /** Set of all types that are instantiated, directly or indirectly, thus may be abstract. */
   private final Set<DexProgramClass> directAndIndirectlyInstantiatedTypes =
       Sets.newIdentityHashSet();
@@ -355,7 +354,6 @@
 
     liveTypes = new SetWithReportedReason<>();
     initializedTypes = new SetWithReportedReason<>();
-    instantiatedTypes = new SetWithReason<>(graphReporter::registerClass);
     targetedMethods = new SetWithReason<>(graphReporter::registerMethod);
     // This set is only populated in edge cases due to multiple default interface methods.
     // The set is generally expected to be empty and in the unlikely chance it is not, it will
@@ -367,6 +365,10 @@
     instantiatedInterfaceTypes = Sets.newIdentityHashSet();
     lambdaRewriter = options.desugarState == DesugarState.ON ? new LambdaRewriter(appView) : null;
 
+    // TODO(b/147799448): Enable allocation site tracking during the initial round of tree shaking.
+    objectAllocationInfoCollection =
+        ObjectAllocationInfoCollectionImpl.builder(false, graphReporter);
+
     if (appView.rewritePrefix.isRewriting() && mode.isInitialTreeShaking()) {
       desugaredLibraryWrapperAnalysis = new DesugaredLibraryConversionWrapperAnalysis(appView);
       registerAnalysis(desugaredLibraryWrapperAnalysis);
@@ -469,7 +471,7 @@
       } else if (clazz.isInterface()) {
         workList.enqueueMarkInterfaceInstantiatedAction(clazz, witness);
       } else {
-        workList.enqueueMarkInstantiatedAction(clazz, null, witness);
+        workList.enqueueMarkInstantiatedAction(clazz, null, InstantiationReason.KEEP_RULE, witness);
         if (clazz.hasDefaultInitializer()) {
           DexEncodedMethod defaultInitializer = clazz.getDefaultInitializer();
           if (forceProguardCompatibility) {
@@ -779,7 +781,8 @@
         } else if (clazz.isInterface()) {
           markInterfaceAsInstantiated(clazz, graphReporter.registerInterface(clazz, reason));
         } else {
-          markInstantiated(clazz, null, reason);
+          workList.enqueueMarkInstantiatedAction(
+              clazz, null, InstantiationReason.REFERENCED_IN_METHOD_HANDLE, reason);
         }
       }
     }
@@ -966,21 +969,34 @@
       return false;
     }
 
-    return traceNewInstance(type, context, KeepReason.instantiatedIn(currentMethod));
+    return traceNewInstance(
+        type,
+        context,
+        InstantiationReason.NEW_INSTANCE_INSTRUCTION,
+        KeepReason.instantiatedIn(currentMethod));
   }
 
   boolean traceNewInstanceFromLambda(DexType type, ProgramMethod context) {
-    return traceNewInstance(type, context, KeepReason.invokedFromLambdaCreatedIn(context.method));
+    return traceNewInstance(
+        type,
+        context,
+        InstantiationReason.LAMBDA,
+        KeepReason.invokedFromLambdaCreatedIn(context.method));
   }
 
-  private boolean traceNewInstance(DexType type, ProgramMethod context, KeepReason keepReason) {
+  private boolean traceNewInstance(
+      DexType type,
+      ProgramMethod context,
+      InstantiationReason instantiationReason,
+      KeepReason keepReason) {
     DexEncodedMethod currentMethod = context.method;
     DexProgramClass clazz = getProgramClassOrNull(type);
     if (clazz != null) {
       if (clazz.isAnnotation() || clazz.isInterface()) {
         markTypeAsLive(clazz, graphReporter.registerClass(clazz, keepReason));
       } else {
-        markInstantiated(clazz, currentMethod, keepReason);
+        workList.enqueueMarkInstantiatedAction(
+            clazz, currentMethod, instantiationReason, keepReason);
       }
     }
     return true;
@@ -1597,7 +1613,10 @@
    */
   // Package protected due to entry point from worklist.
   void processNewlyInstantiatedClass(
-      DexProgramClass clazz, DexEncodedMethod context, KeepReason reason) {
+      DexProgramClass clazz,
+      DexEncodedMethod context,
+      InstantiationReason instantiationReason,
+      KeepReason keepReason) {
     assert !clazz.isAnnotation();
     assert !clazz.isInterface();
 
@@ -1607,7 +1626,8 @@
     analyses.forEach(
         analysis -> analysis.processNewlyInstantiatedClass(clazz.asProgramClass(), context));
 
-    if (!instantiatedTypes.add(clazz, reason)) {
+    if (!objectAllocationInfoCollection.recordDirectAllocationSite(
+        clazz, context, instantiationReason, keepReason)) {
       return;
     }
 
@@ -1617,7 +1637,7 @@
       Log.verbose(getClass(), "Class `%s` is instantiated, processing...", clazz);
     }
     // This class becomes live, so it and all its supertypes become live types.
-    markTypeAsLive(clazz, graphReporter.registerClass(clazz, reason));
+    markTypeAsLive(clazz, graphReporter.registerClass(clazz, keepReason));
     // Instantiation triggers class initialization.
     markDirectAndIndirectClassInitializersAsLive(clazz);
     // For all methods of the class, if we have seen a call, mark the method live.
@@ -1673,7 +1693,7 @@
       transitionReachableVirtualMethods(current, seen);
       Collections.addAll(interfaces, current.interfaces.values);
       current = getProgramClassOrNull(current.superType);
-    } while (current != null && !instantiatedTypes.contains(current));
+    } while (current != null && !objectAllocationInfoCollection.isInstantiatedDirectly(current));
 
     // The set now contains all virtual methods on the type and its supertype that are reachable.
     // In a second step, we now look at interfaces. We have to do this in this order due to JVM
@@ -1835,7 +1855,7 @@
         }
       }
       clazz = getProgramClassOrNull(clazz.superType);
-    } while (clazz != null && !instantiatedTypes.contains(clazz));
+    } while (clazz != null && !objectAllocationInfoCollection.isInstantiatedDirectly(clazz));
   }
 
   private void transitionDependentItemsForInstantiatedClass(DexProgramClass clazz) {
@@ -1850,7 +1870,7 @@
           clazz.superType != null
               ? asProgramClassOrNull(appView.definitionFor(clazz.superType))
               : null;
-    } while (clazz != null && !instantiatedTypes.contains(clazz));
+    } while (clazz != null && !objectAllocationInfoCollection.isInstantiatedDirectly(clazz));
   }
 
   private void transitionUnusedInterfaceToLive(DexProgramClass clazz) {
@@ -1928,14 +1948,6 @@
     analyses.forEach(analysis -> analysis.processNewlyLiveField(field));
   }
 
-  private void markInstantiated(
-      DexProgramClass clazz, DexEncodedMethod context, KeepReason reason) {
-    if (Log.ENABLED) {
-      Log.verbose(getClass(), "Register new instantiation of `%s`.", clazz);
-    }
-    workList.enqueueMarkInstantiatedAction(clazz, context, reason);
-  }
-
   private void markLambdaInstantiated(DexType itf, DexEncodedMethod method) {
     DexClass clazz = appView.definitionFor(itf);
     if (clazz == null) {
@@ -2183,7 +2195,8 @@
       return;
     }
 
-    if (instantiatedTypes.contains(clazz) || instantiatedInterfaceTypes.contains(clazz)) {
+    if (objectAllocationInfoCollection.isInstantiatedDirectly(clazz)
+        || instantiatedInterfaceTypes.contains(clazz)) {
       markVirtualMethodAsLive(
           clazz,
           encodedPossibleTarget,
@@ -2199,7 +2212,7 @@
         if (currentClass == null || currentClass.lookupVirtualMethod(possibleTarget) != null) {
           continue;
         }
-        if (instantiatedTypes.contains(currentClass)
+        if (objectAllocationInfoCollection.isInstantiatedDirectly(currentClass)
             || instantiatedInterfaceTypes.contains(currentClass)) {
           markVirtualMethodAsLive(
               clazz,
@@ -2448,7 +2461,6 @@
             missingTypes,
             SetUtils.mapIdentityHashSet(liveTypes.getItems(), DexProgramClass::getType),
             Collections.unmodifiableSet(instantiatedAppServices),
-            SetUtils.mapIdentityHashSet(instantiatedTypes.getItems(), DexProgramClass::getType),
             Enqueuer.toSortedDescriptorSet(targetedMethods.getItems()),
             Collections.unmodifiableSet(failedResolutionTargets),
             ImmutableSortedSet.copyOf(DexMethod::slowCompareTo, bootstrapMethods),
@@ -2458,6 +2470,7 @@
             toSortedDescriptorSet(liveMethods.getItems()),
             // Filter out library fields and pinned fields, because these are read by default.
             fieldAccessInfoCollection,
+            objectAllocationInfoCollection.build(),
             // TODO(b/132593519): Do we require these sets to be sorted for determinism?
             toImmutableSortedMap(virtualInvokes, PresortedComparable::slowCompare),
             toImmutableSortedMap(interfaceInvokes, PresortedComparable::slowCompare),
@@ -2550,7 +2563,11 @@
     for (DexProgramClass wrapper : wrappers) {
       appBuilder.addProgramClass(wrapper);
       liveTypes.add(wrapper, graphReporter.fakeReportShouldNotBeUsed());
-      instantiatedTypes.add(wrapper, graphReporter.fakeReportShouldNotBeUsed());
+      objectAllocationInfoCollection.recordDirectAllocationSite(
+          wrapper,
+          null,
+          InstantiationReason.SYNTHESIZED_CLASS,
+          graphReporter.fakeReportShouldNotBeUsed());
       // Mark all methods on the wrapper as live and targeted.
       for (DexEncodedMethod method : wrapper.methods()) {
         targetedMethods.add(method, graphReporter.fakeReportShouldNotBeUsed());
@@ -2585,7 +2602,11 @@
         appBuilder.addToMainDexList(Collections.singletonList(programClass.type));
       }
       liveTypes.add(programClass, graphReporter.fakeReportShouldNotBeUsed());
-      instantiatedTypes.add(programClass, graphReporter.fakeReportShouldNotBeUsed());
+      objectAllocationInfoCollection.recordDirectAllocationSite(
+          programClass,
+          null,
+          InstantiationReason.SYNTHESIZED_CLASS,
+          graphReporter.fakeReportShouldNotBeUsed());
 
       // Register all of the field writes in the lambda constructors.
       // This is needed to ensure that the initializers can be optimized.
@@ -2772,11 +2793,6 @@
         Set<DexEncodedMethod> reachableNotLive = Sets.difference(allLive, liveMethods.getItems());
         Log.debug(getClass(), "%s methods are reachable but not live", reachableNotLive.size());
         Log.info(getClass(), "Only reachable: %s", reachableNotLive);
-        Set<DexProgramClass> liveButNotInstantiated =
-            Sets.difference(liveTypes.getItems(), instantiatedTypes.getItems());
-        Log.debug(getClass(), "%s classes are live but not instantiated",
-            liveButNotInstantiated.size());
-        Log.info(getClass(), "Live but not instantiated: %s", liveButNotInstantiated);
         SetView<DexEncodedMethod> targetedButNotLive = Sets
             .difference(targetedMethods.getItems(), liveMethods.getItems());
         Log.debug(getClass(), "%s methods are targeted but not live", targetedButNotLive.size());
@@ -3013,7 +3029,7 @@
   }
 
   private void markClassAsInstantiatedWithReason(DexProgramClass clazz, KeepReason reason) {
-    workList.enqueueMarkInstantiatedAction(clazz, null, reason);
+    workList.enqueueMarkInstantiatedAction(clazz, null, InstantiationReason.REFLECTION, reason);
     if (clazz.hasDefaultInitializer()) {
       workList.enqueueMarkReachableDirectAction(clazz.getDefaultInitializer().method, reason);
     }
@@ -3023,18 +3039,16 @@
       DexProgramClass clazz, KeepReasonWitness witness) {
     if (clazz.isAnnotation()) {
       markTypeAsLive(clazz, witness);
-      return;
-    }
-    if (clazz.isInterface()) {
+    } else if (clazz.isInterface()) {
       markInterfaceAsInstantiated(clazz, witness);
-      return;
-    }
-    workList.enqueueMarkInstantiatedAction(clazz, null, witness);
-    if (clazz.hasDefaultInitializer()) {
-      DexEncodedMethod defaultInitializer = clazz.getDefaultInitializer();
-      workList.enqueueMarkReachableDirectAction(
-          defaultInitializer.method,
-          graphReporter.reportCompatKeepDefaultInitializer(clazz, defaultInitializer));
+    } else {
+      workList.enqueueMarkInstantiatedAction(clazz, null, InstantiationReason.KEEP_RULE, witness);
+      if (clazz.hasDefaultInitializer()) {
+        DexEncodedMethod defaultInitializer = clazz.getDefaultInitializer();
+        workList.enqueueMarkReachableDirectAction(
+            defaultInitializer.method,
+            graphReporter.reportCompatKeepDefaultInitializer(clazz, defaultInitializer));
+      }
     }
   }
 
@@ -3095,7 +3109,8 @@
       if (clazz.isAnnotation() || clazz.isInterface()) {
         markTypeAsLive(clazz.type, KeepReason.reflectiveUseIn(method));
       } else {
-        markInstantiated(clazz, null, KeepReason.reflectiveUseIn(method));
+        workList.enqueueMarkInstantiatedAction(
+            clazz, null, InstantiationReason.REFLECTION, KeepReason.reflectiveUseIn(method));
         if (clazz.hasDefaultInitializer()) {
           DexEncodedMethod initializer = clazz.getDefaultInitializer();
           KeepReason reason = KeepReason.reflectiveUseIn(method);
@@ -3122,7 +3137,8 @@
           !encodedField.accessFlags.isStatic()
               && dexItemFactory.atomicFieldUpdaterMethods.isFieldUpdater(invokedMethod);
       if (keepClass) {
-        markInstantiated(clazz, null, KeepReason.reflectiveUseIn(method));
+        workList.enqueueMarkInstantiatedAction(
+            clazz, null, InstantiationReason.REFLECTION, KeepReason.reflectiveUseIn(method));
       }
       if (pinnedItems.add(encodedField.field)) {
         markFieldAsKept(clazz, encodedField, KeepReason.reflectiveUseIn(method));
diff --git a/src/main/java/com/android/tools/r8/shaking/EnqueuerWorklist.java b/src/main/java/com/android/tools/r8/shaking/EnqueuerWorklist.java
index 32f931f..45e61a4 100644
--- a/src/main/java/com/android/tools/r8/shaking/EnqueuerWorklist.java
+++ b/src/main/java/com/android/tools/r8/shaking/EnqueuerWorklist.java
@@ -68,20 +68,26 @@
   }
 
   static class MarkInstantiatedAction extends EnqueuerAction {
+
     final DexProgramClass target;
     final DexEncodedMethod context;
-    final KeepReason reason;
+    final InstantiationReason instantiationReason;
+    final KeepReason keepReason;
 
     public MarkInstantiatedAction(
-        DexProgramClass target, DexEncodedMethod context, KeepReason reason) {
+        DexProgramClass target,
+        DexEncodedMethod context,
+        InstantiationReason instantiationReason,
+        KeepReason keepReason) {
       this.target = target;
       this.context = context;
-      this.reason = reason;
+      this.instantiationReason = instantiationReason;
+      this.keepReason = keepReason;
     }
 
     @Override
     public void run(Enqueuer enqueuer) {
-      enqueuer.processNewlyInstantiatedClass(target, context, reason);
+      enqueuer.processNewlyInstantiatedClass(target, context, instantiationReason, keepReason);
     }
   }
 
@@ -265,10 +271,13 @@
   // TODO(b/142378367): Context is the containing method that is cause of the instantiation.
   // Consider updating call sites with the context information to increase precision where possible.
   void enqueueMarkInstantiatedAction(
-      DexProgramClass clazz, DexEncodedMethod context, KeepReason reason) {
+      DexProgramClass clazz,
+      DexEncodedMethod context,
+      InstantiationReason instantiationReason,
+      KeepReason keepReason) {
     assert !clazz.isAnnotation();
     assert !clazz.isInterface();
-    queue.add(new MarkInstantiatedAction(clazz, context, reason));
+    queue.add(new MarkInstantiatedAction(clazz, context, instantiationReason, keepReason));
   }
 
   void enqueueMarkAnnotationInstantiatedAction(DexProgramClass clazz, KeepReasonWitness reason) {
diff --git a/src/main/java/com/android/tools/r8/shaking/InstantiationReason.java b/src/main/java/com/android/tools/r8/shaking/InstantiationReason.java
new file mode 100644
index 0000000..c6af09d
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/shaking/InstantiationReason.java
@@ -0,0 +1,14 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.shaking;
+
+public enum InstantiationReason {
+  KEEP_RULE,
+  LAMBDA,
+  NEW_INSTANCE_INSTRUCTION,
+  REFERENCED_IN_METHOD_HANDLE,
+  REFLECTION,
+  SYNTHESIZED_CLASS
+}
diff --git a/src/main/java/com/android/tools/r8/shaking/TreePruner.java b/src/main/java/com/android/tools/r8/shaking/TreePruner.java
index cea0720..3ded35c 100644
--- a/src/main/java/com/android/tools/r8/shaking/TreePruner.java
+++ b/src/main/java/com/android/tools/r8/shaking/TreePruner.java
@@ -85,7 +85,7 @@
       }
       if (appInfo.isLiveProgramClass(clazz)) {
         newClasses.add(clazz);
-        if (!appInfo.instantiatedTypes.contains(clazz.type)
+        if (!appInfo.getObjectAllocationInfoCollection().isInstantiatedDirectly(clazz)
             && !options.forceProguardCompatibility) {
           // The class is only needed as a type but never instantiated. Make it abstract to reflect
           // this.
diff --git a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
index 1706c8c..76db9fe 100644
--- a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
@@ -326,7 +326,7 @@
   // Returns true if [clazz] is a merge candidate. Note that the result of the checks in this
   // method do not change in response to any class merges.
   private boolean isMergeCandidate(DexProgramClass clazz, Set<DexType> pinnedTypes) {
-    if (appInfo.instantiatedTypes.contains(clazz.type)
+    if (appInfo.getObjectAllocationInfoCollection().isInstantiatedDirectly(clazz)
         || appInfo.instantiatedLambdas.contains(clazz.type)
         || appInfo.isPinned(clazz.type)
         || pinnedTypes.contains(clazz.type)
diff --git a/src/main/java/com/android/tools/r8/utils/LensUtils.java b/src/main/java/com/android/tools/r8/utils/LensUtils.java
new file mode 100644
index 0000000..325d6fa
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/utils/LensUtils.java
@@ -0,0 +1,23 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.utils;
+
+import com.android.tools.r8.graph.DexDefinitionSupplier;
+import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.GraphLense;
+import com.google.common.collect.Sets;
+import java.util.Set;
+
+public class LensUtils {
+
+  public static Set<DexEncodedMethod> rewrittenWithRenamedSignature(
+      Set<DexEncodedMethod> methods, DexDefinitionSupplier definitions, GraphLense lens) {
+    Set<DexEncodedMethod> result = Sets.newIdentityHashSet();
+    for (DexEncodedMethod method : methods) {
+      result.add(lens.mapDexEncodedMethod(method, definitions));
+    }
+    return result;
+  }
+}