Compute abstraction of instance field values when all allocation sites are seen

Change-Id: Ia8a0fcff713ccd1b3a0501c09671054e4d6cb691
Bug: 147652121
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/fieldaccess/FieldAccessAnalysis.java b/src/main/java/com/android/tools/r8/ir/analysis/fieldaccess/FieldAccessAnalysis.java
index b463c75..1e8bb52 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/fieldaccess/FieldAccessAnalysis.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/fieldaccess/FieldAccessAnalysis.java
@@ -4,11 +4,15 @@
 
 package com.android.tools.r8.ir.analysis.fieldaccess;
 
+import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
+
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexEncodedField;
+import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.ir.code.FieldInstruction;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.Instruction;
+import com.android.tools.r8.ir.code.NewInstance;
 import com.android.tools.r8.ir.conversion.MethodProcessor;
 import com.android.tools.r8.ir.optimize.ClassInitializerDefaultsOptimization.ClassInitializerDefaultsResult;
 import com.android.tools.r8.ir.optimize.info.OptimizationFeedback;
@@ -56,20 +60,33 @@
 
   public void recordFieldAccesses(
       IRCode code, OptimizationFeedback feedback, MethodProcessor methodProcessor) {
-    if (!code.metadata().mayHaveFieldInstruction() || !methodProcessor.isPrimary()) {
+    if (!methodProcessor.isPrimary()) {
       return;
     }
 
-    Iterable<FieldInstruction> fieldInstructions =
-        code.instructions(Instruction::isFieldInstruction);
-    for (FieldInstruction fieldInstruction : fieldInstructions) {
-      DexEncodedField encodedField = appView.appInfo().resolveField(fieldInstruction.getField());
-      if (encodedField != null && encodedField.isProgramField(appView)) {
-        if (fieldAssignmentTracker != null) {
-          fieldAssignmentTracker.recordFieldAccess(fieldInstruction, encodedField, code.method);
+    if (!code.metadata().mayHaveFieldInstruction() && !code.metadata().mayHaveNewInstance()) {
+      return;
+    }
+
+    for (Instruction instruction : code.instructions()) {
+      if (instruction.isFieldInstruction()) {
+        FieldInstruction fieldInstruction = instruction.asFieldInstruction();
+        DexEncodedField encodedField = appView.appInfo().resolveField(fieldInstruction.getField());
+        if (encodedField != null && encodedField.isProgramField(appView)) {
+          if (fieldAssignmentTracker != null) {
+            fieldAssignmentTracker.recordFieldAccess(fieldInstruction, encodedField, code.method);
+          }
+          if (fieldBitAccessAnalysis != null) {
+            fieldBitAccessAnalysis.recordFieldAccess(fieldInstruction, encodedField, feedback);
+          }
         }
-        if (fieldBitAccessAnalysis != null) {
-          fieldBitAccessAnalysis.recordFieldAccess(fieldInstruction, encodedField, feedback);
+      } else if (instruction.isNewInstance()) {
+        NewInstance newInstance = instruction.asNewInstance();
+        DexProgramClass clazz = asProgramClassOrNull(appView.definitionFor(newInstance.clazz));
+        if (clazz != null) {
+          if (fieldAssignmentTracker != null) {
+            fieldAssignmentTracker.recordAllocationSite(newInstance, clazz, code.method);
+          }
         }
       }
     }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/fieldaccess/FieldAssignmentTracker.java b/src/main/java/com/android/tools/r8/ir/analysis/fieldaccess/FieldAssignmentTracker.java
index 0b0b542..6d91163 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/fieldaccess/FieldAssignmentTracker.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/fieldaccess/FieldAssignmentTracker.java
@@ -7,13 +7,22 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexEncodedField;
 import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.FieldAccessInfoCollection;
+import com.android.tools.r8.graph.ObjectAllocationInfoCollection;
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
+import com.android.tools.r8.ir.analysis.value.BottomValue;
+import com.android.tools.r8.ir.analysis.value.UnknownValue;
 import com.android.tools.r8.ir.code.FieldInstruction;
+import com.android.tools.r8.ir.code.InvokeDirect;
+import com.android.tools.r8.ir.code.NewInstance;
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.optimize.ClassInitializerDefaultsOptimization.ClassInitializerDefaultsResult;
 import com.android.tools.r8.ir.optimize.info.FieldOptimizationInfo;
-import com.android.tools.r8.ir.optimize.info.OptimizationFeedback;
+import com.android.tools.r8.ir.optimize.info.OptimizationFeedbackDelayed;
+import com.android.tools.r8.ir.optimize.info.field.InstanceFieldArgumentInitializationInfo;
+import com.android.tools.r8.ir.optimize.info.field.InstanceFieldInitializationInfo;
+import com.android.tools.r8.ir.optimize.info.field.InstanceFieldInitializationInfoCollection;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.google.common.collect.Sets;
 import it.unimi.dsi.fastutil.objects.Reference2IntMap;
@@ -21,9 +30,11 @@
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.IdentityHashMap;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Consumer;
 
 public class FieldAssignmentTracker {
@@ -35,12 +46,52 @@
   // processed when a field no longer has any incoming edges.
   private final FieldAccessGraph fieldAccessGraph;
 
+  // An object allocation graph with edges from methods to the classes they instantiate. Edges are
+  // removed from the graph as we process methods, such that we can conclude that all allocation
+  // sites have been seen when a class no longer has any incoming edges.
+  private final ObjectAllocationGraph objectAllocationGraph;
+
   // The set of fields that may store a non-zero value.
   private final Set<DexEncodedField> nonZeroFields = Sets.newConcurrentHashSet();
 
+  private final Map<DexProgramClass, Map<DexEncodedField, AbstractValue>>
+      abstractInstanceFieldValues = new ConcurrentHashMap<>();
+
   FieldAssignmentTracker(AppView<AppInfoWithLiveness> appView) {
     this.appView = appView;
     this.fieldAccessGraph = new FieldAccessGraph(appView);
+    this.objectAllocationGraph = new ObjectAllocationGraph(appView);
+    initializeAbstractInstanceFieldValues();
+  }
+
+  /**
+   * For each class with known allocation sites, adds a mapping from clazz -> instance field ->
+   * bottom.
+   *
+   * <p>If an entry (clazz, instance field) is missing in {@link #abstractInstanceFieldValues}, it
+   * is interpreted as if we known nothing about the value of the field.
+   */
+  private void initializeAbstractInstanceFieldValues() {
+    ObjectAllocationInfoCollection objectAllocationInfos =
+        appView.appInfo().getObjectAllocationInfoCollection();
+    objectAllocationInfos.forEachClassWithKnownAllocationSites(
+        (clazz, allocationSites) -> {
+          if (appView.appInfo().isInstantiatedIndirectly(clazz)) {
+            // TODO(b/147652121): Handle classes that are instantiated indirectly.
+            return;
+          }
+          List<DexEncodedField> instanceFields = clazz.instanceFields();
+          if (instanceFields.isEmpty()) {
+            // No instance fields to track.
+            return;
+          }
+          Map<DexEncodedField, AbstractValue> abstractInstanceFieldValuesForClass =
+              new IdentityHashMap<>();
+          for (DexEncodedField field : clazz.instanceFields()) {
+            abstractInstanceFieldValuesForClass.put(field, BottomValue.getInstance());
+          }
+          abstractInstanceFieldValues.put(clazz, abstractInstanceFieldValuesForClass);
+        });
   }
 
   private boolean isAlwaysZero(DexEncodedField field) {
@@ -72,16 +123,98 @@
     }
   }
 
-  private void recordAllFieldPutsProcessed(DexEncodedField field, OptimizationFeedback feedback) {
+  void recordAllocationSite(
+      NewInstance instruction, DexProgramClass clazz, DexEncodedMethod context) {
+    Map<DexEncodedField, AbstractValue> abstractInstanceFieldValuesForClass =
+        abstractInstanceFieldValues.get(clazz);
+    if (abstractInstanceFieldValuesForClass == null) {
+      // We are not tracking the value of any of clazz' instance fields.
+      return;
+    }
+
+    InvokeDirect invoke = instruction.getUniqueConstructorInvoke(appView.dexItemFactory());
+    if (invoke == null) {
+      // We just lost track.
+      abstractInstanceFieldValues.remove(clazz);
+      return;
+    }
+
+    DexEncodedMethod singleTarget = invoke.lookupSingleTarget(appView, context.method.holder);
+    if (singleTarget == null) {
+      // We just lost track.
+      abstractInstanceFieldValues.remove(clazz);
+      return;
+    }
+
+    InstanceFieldInitializationInfoCollection initializationInfoCollection =
+        singleTarget.getOptimizationInfo().getInstanceInitializerInfo().fieldInitializationInfos();
+
+    // Synchronize on the lattice element (abstractInstanceFieldValuesForClass) in case we process
+    // another allocation site of `clazz` concurrently.
+    synchronized (abstractInstanceFieldValuesForClass) {
+      Iterator<Map.Entry<DexEncodedField, AbstractValue>> iterator =
+          abstractInstanceFieldValuesForClass.entrySet().iterator();
+      while (iterator.hasNext()) {
+        Map.Entry<DexEncodedField, AbstractValue> entry = iterator.next();
+        DexEncodedField field = entry.getKey();
+        InstanceFieldInitializationInfo initializationInfo =
+            initializationInfoCollection.get(field);
+        if (initializationInfo.isArgumentInitializationInfo()) {
+          InstanceFieldArgumentInitializationInfo argumentInitializationInfo =
+              initializationInfo.asArgumentInitializationInfo();
+          Value argument = invoke.arguments().get(argumentInitializationInfo.getArgumentIndex());
+          AbstractValue abstractValue =
+              argument.getAbstractValue(appView, context.method.holder).join(entry.getValue());
+          assert !abstractValue.isBottom();
+          if (!abstractValue.isUnknown()) {
+            entry.setValue(abstractValue);
+            continue;
+          }
+        } else {
+          assert initializationInfo.isUnknown();
+        }
+
+        // We just lost track for this field.
+        iterator.remove();
+      }
+    }
+  }
+
+  private void recordAllFieldPutsProcessed(
+      DexEncodedField field, OptimizationFeedbackDelayed feedback) {
     if (isAlwaysZero(field)) {
       feedback.recordFieldHasAbstractValue(
           field, appView, appView.abstractValueFactory().createSingleNumberValue(0));
     }
   }
 
-  public void waveDone(Collection<DexEncodedMethod> wave, OptimizationFeedback feedback) {
+  private void recordAllAllocationsSitesProcessed(
+      DexProgramClass clazz, OptimizationFeedbackDelayed feedback) {
+    Map<DexEncodedField, AbstractValue> abstractInstanceFieldValuesForClass =
+        abstractInstanceFieldValues.get(clazz);
+    if (abstractInstanceFieldValuesForClass == null) {
+      return;
+    }
+
+    for (DexEncodedField field : clazz.instanceFields()) {
+      AbstractValue abstractValue =
+          abstractInstanceFieldValuesForClass.getOrDefault(field, UnknownValue.getInstance());
+      if (abstractValue.isBottom()) {
+        // TODO(b/149454532): Record that the type is not instantiated.
+        break;
+      }
+      if (abstractValue.isUnknown()) {
+        continue;
+      }
+      feedback.recordFieldHasAbstractValue(field, appView, abstractValue);
+    }
+  }
+
+  public void waveDone(Collection<DexEncodedMethod> wave, OptimizationFeedbackDelayed feedback) {
     for (DexEncodedMethod method : wave) {
       fieldAccessGraph.markProcessed(method, field -> recordAllFieldPutsProcessed(field, feedback));
+      objectAllocationGraph.markProcessed(
+          method, clazz -> recordAllAllocationsSitesProcessed(clazz, feedback));
     }
   }
 
@@ -139,4 +272,42 @@
       }
     }
   }
+
+  static class ObjectAllocationGraph {
+
+    // The classes instantiated by each method.
+    private final Map<DexEncodedMethod, List<DexProgramClass>> objectAllocations =
+        new IdentityHashMap<>();
+
+    // The number of allocation sites that have not yet been processed per class.
+    private final Reference2IntMap<DexProgramClass> pendingObjectAllocations =
+        new Reference2IntOpenHashMap<>();
+
+    ObjectAllocationGraph(AppView<AppInfoWithLiveness> appView) {
+      ObjectAllocationInfoCollection objectAllocationInfos =
+          appView.appInfo().getObjectAllocationInfoCollection();
+      objectAllocationInfos.forEachClassWithKnownAllocationSites(
+          (clazz, contexts) -> {
+            for (DexEncodedMethod context : contexts) {
+              objectAllocations.computeIfAbsent(context, ignore -> new ArrayList<>()).add(clazz);
+            }
+            pendingObjectAllocations.put(clazz, contexts.size());
+          });
+    }
+
+    void markProcessed(
+        DexEncodedMethod method, Consumer<DexProgramClass> allAllocationsSitesSeenConsumer) {
+      List<DexProgramClass> allocationSitesInMethod = objectAllocations.get(method);
+      if (allocationSitesInMethod != null) {
+        for (DexProgramClass type : allocationSitesInMethod) {
+          int numberOfPendingAllocationSites = pendingObjectAllocations.removeInt(type) - 1;
+          if (numberOfPendingAllocationSites > 0) {
+            pendingObjectAllocations.put(type, numberOfPendingAllocationSites);
+          } else {
+            allAllocationsSitesSeenConsumer.accept(type);
+          }
+        }
+      }
+    }
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/code/IRCode.java b/src/main/java/com/android/tools/r8/ir/code/IRCode.java
index 39c3d39..e620cb1 100644
--- a/src/main/java/com/android/tools/r8/ir/code/IRCode.java
+++ b/src/main/java/com/android/tools/r8/ir/code/IRCode.java
@@ -826,6 +826,9 @@
       } else if (instruction.isMul()) {
         assert metadata.mayHaveMul() && metadata.mayHaveArithmeticOrLogicalBinop()
             : "IR metadata should indicate that code has a mul";
+      } else if (instruction.isNewInstance()) {
+        assert metadata.mayHaveNewInstance()
+            : "IR metadata should indicate that code has a new-instance";
       } else if (instruction.isRem()) {
         assert metadata.mayHaveRem() && metadata.mayHaveArithmeticOrLogicalBinop()
             : "IR metadata should indicate that code has a rem";
diff --git a/src/main/java/com/android/tools/r8/ir/code/IRMetadata.java b/src/main/java/com/android/tools/r8/ir/code/IRMetadata.java
index c4d55d8..8db09e9 100644
--- a/src/main/java/com/android/tools/r8/ir/code/IRMetadata.java
+++ b/src/main/java/com/android/tools/r8/ir/code/IRMetadata.java
@@ -206,6 +206,10 @@
     return get(Opcodes.MUL);
   }
 
+  public boolean mayHaveNewInstance() {
+    return get(Opcodes.NEW_INSTANCE);
+  }
+
   public boolean mayHaveOr() {
     return get(Opcodes.OR);
   }
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index c0788fc..ba98418 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -369,9 +369,8 @@
     instantiatedInterfaceTypes = Sets.newIdentityHashSet();
     lambdaRewriter = options.desugarState == DesugarState.ON ? new LambdaRewriter(appView) : null;
 
-    // TODO(b/147799448): Enable allocation site tracking during the initial round of tree shaking.
     objectAllocationInfoCollection =
-        ObjectAllocationInfoCollectionImpl.builder(false, graphReporter);
+        ObjectAllocationInfoCollectionImpl.builder(mode.isInitialTreeShaking(), graphReporter);
 
     if (appView.rewritePrefix.isRewriting() && mode.isInitialTreeShaking()) {
       desugaredLibraryWrapperAnalysis = new DesugaredLibraryConversionWrapperAnalysis(appView);
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/membervaluepropagation/fields/FieldInitializedByConstantArgumentTest.java b/src/test/java/com/android/tools/r8/ir/optimize/membervaluepropagation/fields/FieldInitializedByConstantArgumentTest.java
index beecf65..ca34efd 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/membervaluepropagation/fields/FieldInitializedByConstantArgumentTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/membervaluepropagation/fields/FieldInitializedByConstantArgumentTest.java
@@ -5,6 +5,7 @@
 package com.android.tools.r8.ir.optimize.membervaluepropagation.fields;
 
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.CoreMatchers.not;
 import static org.hamcrest.MatcherAssert.assertThat;
 
 import com.android.tools.r8.NeverClassInline;
@@ -50,8 +51,7 @@
     ClassSubject testClassSubject = inspector.clazz(TestClass.class);
     assertThat(testClassSubject, isPresent());
     assertThat(testClassSubject.uniqueMethodWithName("live"), isPresent());
-    // TODO(b/147652121): Should be absent.
-    assertThat(testClassSubject.uniqueMethodWithName("dead"), isPresent());
+    assertThat(testClassSubject.uniqueMethodWithName("dead"), not(isPresent()));
   }
 
   static class TestClass {