Clear write contexts in field access collection after tree shaking

Change-Id: I6297260aef4ba92ac1b7c0ab6c66fcf7a8a5b9ca
diff --git a/src/main/java/com/android/tools/r8/R8.java b/src/main/java/com/android/tools/r8/R8.java
index c9be529..4e7eb60 100644
--- a/src/main/java/com/android/tools/r8/R8.java
+++ b/src/main/java/com/android/tools/r8/R8.java
@@ -507,7 +507,6 @@
       if (options.getTestingOptions().enableMemberRebindingAnalysis) {
         new MemberRebindingAnalysis(appViewWithLiveness).run(executorService);
       }
-      appViewWithLiveness.appInfo().getMutableFieldAccessInfoCollection().flattenAccessContexts();
       appViewWithLiveness
           .appInfo()
           .getMutableFieldAccessInfoCollection()
diff --git a/src/main/java/com/android/tools/r8/graph/AbstractAccessContexts.java b/src/main/java/com/android/tools/r8/graph/AbstractAccessContexts.java
index c3a7531..883ad01 100644
--- a/src/main/java/com/android/tools/r8/graph/AbstractAccessContexts.java
+++ b/src/main/java/com/android/tools/r8/graph/AbstractAccessContexts.java
@@ -5,13 +5,9 @@
 package com.android.tools.r8.graph;
 
 import com.android.tools.r8.errors.Unreachable;
-import com.android.tools.r8.graph.lens.GraphLens;
-import com.android.tools.r8.utils.MapUtils;
 import com.android.tools.r8.utils.collections.ProgramMethodSet;
 import java.util.IdentityHashMap;
-import java.util.Iterator;
 import java.util.Map;
-import java.util.Map.Entry;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.Predicate;
@@ -40,8 +36,6 @@
  */
 public abstract class AbstractAccessContexts {
 
-  abstract void flattenAccessContexts(DexField field);
-
   abstract void forEachAccessContext(Consumer<ProgramMethod> consumer);
 
   /**
@@ -50,18 +44,10 @@
   abstract boolean isAccessedInMethodSatisfying(Predicate<ProgramMethod> predicate);
 
   /**
-   * Returns true if this field is only written by methods for which {@param predicate} returns
-   * true.
-   */
-  abstract boolean isAccessedOnlyInMethodSatisfying(Predicate<ProgramMethod> predicate);
-
-  /**
    * Returns true if this field is written by a method in the program other than {@param method}.
    */
   abstract boolean isAccessedOutside(DexEncodedMethod method);
 
-  abstract int getNumberOfAccessContexts();
-
   public final boolean hasAccesses() {
     return !isEmpty();
   }
@@ -84,11 +70,6 @@
     return false;
   }
 
-  abstract AbstractAccessContexts rewrittenWithLens(
-      DexDefinitionSupplier definitions, GraphLens lens);
-
-  abstract AbstractAccessContexts withoutPrunedItems(PrunedItems prunedItems);
-
   public static EmptyAccessContexts empty() {
     return EmptyAccessContexts.getInstance();
   }
@@ -110,11 +91,6 @@
     }
 
     @Override
-    void flattenAccessContexts(DexField field) {
-      // Intentionally empty.
-    }
-
-    @Override
     void forEachAccessContext(Consumer<ProgramMethod> consumer) {
       // Intentionally empty.
     }
@@ -125,21 +101,11 @@
     }
 
     @Override
-    boolean isAccessedOnlyInMethodSatisfying(Predicate<ProgramMethod> predicate) {
-      return true;
-    }
-
-    @Override
     boolean isAccessedOutside(DexEncodedMethod method) {
       return false;
     }
 
     @Override
-    int getNumberOfAccessContexts() {
-      return 0;
-    }
-
-    @Override
     public boolean isBottom() {
       return true;
     }
@@ -150,19 +116,9 @@
     }
 
     @Override
-    AbstractAccessContexts rewrittenWithLens(DexDefinitionSupplier definitions, GraphLens lens) {
-      return this;
-    }
-
-    @Override
     public AbstractAccessContexts join(AbstractAccessContexts contexts) {
       return contexts;
     }
-
-    @Override
-    AbstractAccessContexts withoutPrunedItems(PrunedItems prunedItems) {
-      return this;
-    }
   }
 
   public static class ConcreteAccessContexts extends AbstractAccessContexts {
@@ -209,33 +165,6 @@
       return accessesWithContexts;
     }
 
-    @Override
-    int getNumberOfAccessContexts() {
-      if (accessesWithContexts.size() == 1) {
-        return accessesWithContexts.values().iterator().next().size();
-      }
-      throw new Unreachable(
-          "Should only be querying the number of access contexts after flattening");
-    }
-
-    @Override
-    void flattenAccessContexts(DexField field) {
-      if (accessesWithContexts != null) {
-        ProgramMethodSet flattenedAccessContexts =
-            accessesWithContexts.computeIfAbsent(field, ignore -> ProgramMethodSet.create());
-        accessesWithContexts.forEach(
-            (access, contexts) -> {
-              if (access.isNotIdenticalTo(field)) {
-                flattenedAccessContexts.addAll(contexts);
-              }
-            });
-        accessesWithContexts.clear();
-        if (!flattenedAccessContexts.isEmpty()) {
-          accessesWithContexts.put(field, flattenedAccessContexts);
-        }
-      }
-    }
-
     /**
      * Returns true if this field is written by a method for which {@param predicate} returns true.
      */
@@ -252,22 +181,6 @@
     }
 
     /**
-     * Returns true if this field is only written by methods for which {@param predicate} returns
-     * true.
-     */
-    @Override
-    public boolean isAccessedOnlyInMethodSatisfying(Predicate<ProgramMethod> predicate) {
-      for (ProgramMethodSet encodedWriteContexts : accessesWithContexts.values()) {
-        for (ProgramMethod encodedWriteContext : encodedWriteContexts) {
-          if (!predicate.test(encodedWriteContext)) {
-            return false;
-          }
-        }
-      }
-      return true;
-    }
-
-    /**
      * Returns true if this field is written by a method in the program other than {@param method}.
      */
     @Override
@@ -304,54 +217,6 @@
     }
 
     @Override
-    ConcreteAccessContexts rewrittenWithLens(DexDefinitionSupplier definitions, GraphLens lens) {
-      Map<DexField, ProgramMethodSet> rewrittenAccessesWithContexts = null;
-      for (Entry<DexField, ProgramMethodSet> entry : accessesWithContexts.entrySet()) {
-        DexField field = entry.getKey();
-        DexField rewrittenField = lens.lookupField(field);
-
-        ProgramMethodSet contexts = entry.getValue();
-        ProgramMethodSet rewrittenContexts = contexts.rewrittenWithLens(definitions, lens);
-
-        if (rewrittenField.isIdenticalTo(field) && rewrittenContexts == contexts) {
-          if (rewrittenAccessesWithContexts == null) {
-            continue;
-          }
-        } else {
-          if (rewrittenAccessesWithContexts == null) {
-            rewrittenAccessesWithContexts = new IdentityHashMap<>(accessesWithContexts.size());
-            MapUtils.forEachUntilExclusive(
-                accessesWithContexts, rewrittenAccessesWithContexts::put, field);
-          }
-        }
-        merge(rewrittenAccessesWithContexts, rewrittenField, rewrittenContexts);
-      }
-      if (rewrittenAccessesWithContexts != null) {
-        rewrittenAccessesWithContexts =
-            MapUtils.trimCapacityOfIdentityHashMapIfSizeLessThan(
-                rewrittenAccessesWithContexts, accessesWithContexts.size());
-        return new ConcreteAccessContexts(rewrittenAccessesWithContexts);
-      } else {
-        return this;
-      }
-    }
-
-    private static void merge(
-        Map<DexField, ProgramMethodSet> accessesWithContexts,
-        DexField field,
-        ProgramMethodSet contexts) {
-      ProgramMethodSet existingContexts = accessesWithContexts.put(field, contexts);
-      if (existingContexts != null) {
-        if (existingContexts.size() <= contexts.size()) {
-          contexts.addAll(existingContexts);
-        } else {
-          accessesWithContexts.put(field, existingContexts);
-          existingContexts.addAll(contexts);
-        }
-      }
-    }
-
-    @Override
     public AbstractAccessContexts join(AbstractAccessContexts contexts) {
       if (contexts.isEmpty()) {
         return this;
@@ -372,30 +237,6 @@
       contexts.asConcrete().accessesWithContexts.forEach(addAllMethods);
       return new ConcreteAccessContexts(newAccessesWithContexts);
     }
-
-    @Override
-    AbstractAccessContexts withoutPrunedItems(PrunedItems prunedItems) {
-      for (ProgramMethodSet methodSet : accessesWithContexts.values()) {
-        Iterator<ProgramMethod> iterator = methodSet.iterator();
-        ProgramMethodSet newAccessContexts = null;
-        while (iterator.hasNext()) {
-          DexMethod methodReference = iterator.next().getReference();
-          if (prunedItems.isRemoved(methodReference)) {
-            iterator.remove();
-            if (prunedItems.isFullyInlined(methodReference)) {
-              if (newAccessContexts == null) {
-                newAccessContexts = ProgramMethodSet.create();
-              }
-              prunedItems.forEachFullyInlinedMethodCaller(methodReference, newAccessContexts::add);
-            }
-          }
-        }
-        if (newAccessContexts != null) {
-          methodSet.addAll(newAccessContexts);
-        }
-      }
-      return this;
-    }
   }
 
   public static class UnknownAccessContexts extends AbstractAccessContexts {
@@ -409,11 +250,6 @@
     }
 
     @Override
-    void flattenAccessContexts(DexField field) {
-      // Intentionally empty.
-    }
-
-    @Override
     void forEachAccessContext(Consumer<ProgramMethod> consumer) {
       throw new Unreachable("Should never be iterating the access contexts when they are unknown");
     }
@@ -424,22 +260,11 @@
     }
 
     @Override
-    boolean isAccessedOnlyInMethodSatisfying(Predicate<ProgramMethod> predicate) {
-      return false;
-    }
-
-    @Override
     boolean isAccessedOutside(DexEncodedMethod method) {
       return true;
     }
 
     @Override
-    int getNumberOfAccessContexts() {
-      throw new Unreachable(
-          "Should never be querying the number of access contexts when they are unknown");
-    }
-
-    @Override
     public boolean isEmpty() {
       return false;
     }
@@ -450,18 +275,8 @@
     }
 
     @Override
-    AbstractAccessContexts rewrittenWithLens(DexDefinitionSupplier definitions, GraphLens lens) {
-      return this;
-    }
-
-    @Override
     public AbstractAccessContexts join(AbstractAccessContexts contexts) {
       return this;
     }
-
-    @Override
-    AbstractAccessContexts withoutPrunedItems(PrunedItems prunedItems) {
-      return this;
-    }
   }
 }
diff --git a/src/main/java/com/android/tools/r8/graph/FieldAccessInfoCollectionImpl.java b/src/main/java/com/android/tools/r8/graph/FieldAccessInfoCollectionImpl.java
index 92ac34d..03ffe26 100644
--- a/src/main/java/com/android/tools/r8/graph/FieldAccessInfoCollectionImpl.java
+++ b/src/main/java/com/android/tools/r8/graph/FieldAccessInfoCollectionImpl.java
@@ -30,22 +30,17 @@
     this.infos = infos;
   }
 
-  @Override
-  public void destroyAccessContexts() {
-    infos.values().forEach(FieldAccessInfoImpl::destroyAccessContexts);
-  }
-
-  @Override
-  public void flattenAccessContexts() {
-    infos.values().forEach(FieldAccessInfoImpl::flattenAccessContexts);
-  }
-
   public FieldAccessInfoImpl computeIfAbsent(
       DexField field, Function<DexField, FieldAccessInfoImpl> fn) {
     return infos.computeIfAbsent(field, fn);
   }
 
   @Override
+  public void destroyUniqueWriteContexts() {
+    infos.values().forEach(FieldAccessInfoImpl::clearUniqueWriteContext);
+  }
+
+  @Override
   public FieldAccessInfoImpl get(DexField field) {
     return infos.get(field);
   }
@@ -80,13 +75,13 @@
 
   @Override
   public FieldAccessInfoCollectionImpl rewrittenWithLens(
-      DexDefinitionSupplier definitions, GraphLens lens, Timing timing) {
+      DexDefinitionSupplier definitions, GraphLens lens, GraphLens appliedLens, Timing timing) {
     timing.begin("Rewrite FieldAccessInfoCollectionImpl");
     Map<DexField, FieldAccessInfoImpl> newInfos =
         LensUtils.mutableRewriteMap(
             infos,
             IdentityHashMap::new,
-            (field, info) -> info.rewrittenWithLens(definitions, lens, timing),
+            (field, info) -> info.rewrittenWithLens(definitions, lens, appliedLens),
             (field, info, rewrittenInfo) -> rewrittenInfo.getField(),
             (field, info, rewrittenInfo) -> rewrittenInfo,
             (field, info, otherInfo) -> info.join(otherInfo));
diff --git a/src/main/java/com/android/tools/r8/graph/FieldAccessInfoImpl.java b/src/main/java/com/android/tools/r8/graph/FieldAccessInfoImpl.java
index c3a1272..06fdee6 100644
--- a/src/main/java/com/android/tools/r8/graph/FieldAccessInfoImpl.java
+++ b/src/main/java/com/android/tools/r8/graph/FieldAccessInfoImpl.java
@@ -7,10 +7,11 @@
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.AbstractAccessContexts.ConcreteAccessContexts;
 import com.android.tools.r8.graph.lens.GraphLens;
+import com.android.tools.r8.shaking.Enqueuer;
 import com.android.tools.r8.shaking.Enqueuer.FieldAccessKind;
 import com.android.tools.r8.utils.BitUtils;
+import com.android.tools.r8.utils.ObjectUtils;
 import com.android.tools.r8.utils.collections.ProgramMethodSet;
-import com.android.tools.r8.utils.timing.Timing;
 import com.google.common.collect.Sets;
 import java.util.Map;
 import java.util.Set;
@@ -35,6 +36,9 @@
   public static final int FLAG_IS_READ_FROM_RECORD_INVOKE_DYNAMIC = 1 << 6;
   public static final int FLAG_IS_READ_FROM_FIND_LITE_EXTENSION_BY_NUMBER_METHOD = 1 << 7;
   public static final int FLAG_IS_READ_FROM_NON_FIND_LITE_EXTENSION_BY_NUMBER_METHOD = 1 << 8;
+  public static final int FLAG_IS_WRITTEN_DIRECTLY = 1 << 9;
+  public static final int FLAG_IS_WRITTEN_FROM_NON_CLASS_INITIALIZER = 1 << 10;
+  public static final int FLAG_IS_WRITTEN_FROM_NON_INSTANCE_INITIALIZER = 1 << 11;
 
   // A direct reference to the definition of the field.
   private final DexField field;
@@ -50,10 +54,18 @@
   // reference appears.
   private AbstractAccessContexts writesWithContexts;
 
+  // The unique write context of this field, or null.
+  private ProgramMethod uniqueWriteContext;
+
   public FieldAccessInfoImpl(DexField field) {
     this(field, 0, AbstractAccessContexts.empty(), AbstractAccessContexts.empty());
   }
 
+  public FieldAccessInfoImpl(DexField field, int flags, ProgramMethod uniqueWriteContext) {
+    this(field, flags, AbstractAccessContexts.unknown(), AbstractAccessContexts.unknown());
+    this.uniqueWriteContext = uniqueWriteContext;
+  }
+
   public FieldAccessInfoImpl(
       DexField field,
       int flags,
@@ -65,22 +77,13 @@
     this.writesWithContexts = writesWithContexts;
   }
 
-  void destroyAccessContexts() {
-    destroyReadAccessContexts();
-    writesWithContexts = AbstractAccessContexts.unknown();
-  }
-
-  public void destroyReadAccessContexts() {
+  public void destroyAccessContexts(Enqueuer.Mode mode) {
+    assert uniqueWriteContext == null;
+    if (mode.isInitialTreeShaking()) {
+      setUniqueWriteContextFromWritesWithContexts();
+    }
     readsWithContexts = AbstractAccessContexts.unknown();
-  }
-
-  void flattenAccessContexts() {
-    flattenAccessContexts(readsWithContexts);
-    flattenAccessContexts(writesWithContexts);
-  }
-
-  private void flattenAccessContexts(AbstractAccessContexts accessesWithContexts) {
-    accessesWithContexts.flattenAccessContexts(field);
+    writesWithContexts = AbstractAccessContexts.unknown();
   }
 
   @Override
@@ -88,31 +91,36 @@
     return field;
   }
 
-  public AbstractAccessContexts getReadsWithContexts() {
-    return readsWithContexts;
-  }
-
   public void setReadsWithContexts(AbstractAccessContexts readsWithContexts) {
     this.readsWithContexts = readsWithContexts;
   }
 
-  public AbstractAccessContexts getWritesWithContexts() {
-    return writesWithContexts;
-  }
-
   public void setWritesWithContexts(AbstractAccessContexts writesWithContexts) {
     this.writesWithContexts = writesWithContexts;
   }
 
   public ProgramMethod getUniqueWriteContext() {
-    if (hasKnownWriteContexts()
-        && writesWithContexts.getNumberOfAccessContexts() == 1
-        && !isWrittenIndirectly()) {
+    // We only set the `uniqueWriteContext` when we destroy `writesWithContexts`. Therefore,
+    // disallow uses of this method if the `writesWithContexts` has not been set to top.
+    assert writesWithContexts.isTop();
+    return uniqueWriteContext;
+  }
+
+  private void setUniqueWriteContextFromWritesWithContexts() {
+    if (writesWithContexts.isConcrete() && !isWrittenIndirectly()) {
       Map<DexField, ProgramMethodSet> accessesWithContexts =
           writesWithContexts.asConcrete().getAccessesWithContexts();
-      return accessesWithContexts.values().iterator().next().iterator().next();
+      if (accessesWithContexts.size() == 1) {
+        ProgramMethodSet contexts = accessesWithContexts.values().iterator().next();
+        if (contexts.size() == 1) {
+          uniqueWriteContext = contexts.getFirst();
+        }
+      }
     }
-    return null;
+  }
+
+  void clearUniqueWriteContext() {
+    uniqueWriteContext = null;
   }
 
   @Override
@@ -196,11 +204,27 @@
 
   @Override
   public boolean isEffectivelyFinal(ProgramField field) {
-    return isWrittenOnlyInMethodSatisfying(
-        method ->
-            method.getDefinition().isInitializer()
-                && method.getAccessFlags().isStatic() == field.getAccessFlags().isStatic()
-                && method.getHolder() == field.getHolder());
+    if (field.getAccessFlags().isStatic()) {
+      return BitUtils.isBitInMaskUnset(flags, FLAG_IS_WRITTEN_FROM_NON_CLASS_INITIALIZER);
+    } else {
+      return BitUtils.isBitInMaskUnset(flags, FLAG_IS_WRITTEN_FROM_NON_INSTANCE_INITIALIZER);
+    }
+  }
+
+  public void setWrittenFromNonClassInitializer() {
+    flags |= FLAG_IS_WRITTEN_FROM_NON_CLASS_INITIALIZER;
+  }
+
+  public void clearWrittenFromNonClassInitializer() {
+    flags &= ~FLAG_IS_WRITTEN_FROM_NON_CLASS_INITIALIZER;
+  }
+
+  public void setWrittenFromNonInstanceInitializer() {
+    flags |= FLAG_IS_WRITTEN_FROM_NON_INSTANCE_INITIALIZER;
+  }
+
+  public void clearWrittenFromNonInstanceInitializer() {
+    flags &= ~FLAG_IS_WRITTEN_FROM_NON_INSTANCE_INITIALIZER;
   }
 
   /** Returns true if this field is read by the program. */
@@ -209,7 +233,7 @@
     return isReadDirectly() || isReadIndirectly();
   }
 
-  private boolean isReadDirectly() {
+  public boolean isReadDirectly() {
     return BitUtils.isBitInMaskSet(flags, FLAG_IS_READ_DIRECTLY);
   }
 
@@ -281,8 +305,16 @@
     return isWrittenDirectly() || isWrittenIndirectly();
   }
 
-  private boolean isWrittenDirectly() {
-    return !writesWithContexts.isEmpty();
+  public boolean isWrittenDirectly() {
+    return BitUtils.isBitInMaskSet(flags, FLAG_IS_WRITTEN_DIRECTLY);
+  }
+
+  public void setWrittenDirectly() {
+    flags |= FLAG_IS_WRITTEN_DIRECTLY;
+  }
+
+  private void clearWrittenDirectly() {
+    flags &= ~FLAG_IS_WRITTEN_DIRECTLY;
   }
 
   @Override
@@ -302,14 +334,6 @@
   }
 
   /**
-   * Returns true if this field is only written by methods for which {@param predicate} returns
-   * true.
-   */
-  public boolean isWrittenOnlyInMethodSatisfying(Predicate<ProgramMethod> predicate) {
-    return writesWithContexts.isAccessedOnlyInMethodSatisfying(predicate) && !isWrittenIndirectly();
-  }
-
-  /**
    * Returns true if this field is written by a method in the program other than {@param method}.
    */
   public boolean isWrittenOutside(DexEncodedMethod method) {
@@ -328,6 +352,8 @@
   }
 
   public boolean recordWrite(DexField access, ProgramMethod context) {
+    setWrittenDirectly();
+    updateEffectivelyFinalFlags(context);
     if (writesWithContexts.isBottom()) {
       writesWithContexts = new ConcreteAccessContexts();
     }
@@ -337,6 +363,20 @@
     return false;
   }
 
+  private void updateEffectivelyFinalFlags(ProgramMethod context) {
+    if (context.getAccessFlags().isConstructor()
+        && context.getHolderType().isIdenticalTo(field.getHolderType())) {
+      if (context.getAccessFlags().isStatic()) {
+        setWrittenFromNonInstanceInitializer();
+      } else {
+        setWrittenFromNonClassInitializer();
+      }
+    } else {
+      setWrittenFromNonClassInitializer();
+      setWrittenFromNonInstanceInitializer();
+    }
+  }
+
   @Override
   public void clearReads() {
     assert !hasReflectiveAccess();
@@ -349,44 +389,49 @@
 
   @Override
   public void clearWrites() {
-    writesWithContexts = AbstractAccessContexts.empty();
+    assert !isWrittenIndirectly();
+    assert writesWithContexts.isTop();
+    clearWrittenDirectly();
+    clearWrittenFromNonClassInitializer();
+    clearWrittenFromNonInstanceInitializer();
+    clearUniqueWriteContext();
   }
 
   public FieldAccessInfoImpl rewrittenWithLens(
-      DexDefinitionSupplier definitions, GraphLens lens, Timing timing) {
+      DexDefinitionSupplier definitions, GraphLens lens, GraphLens appliedLens) {
     assert readsWithContexts.isTop();
-    timing.begin("Rewrite FieldAccessInfoImpl");
-    AbstractAccessContexts rewrittenWritesWithContexts =
-        writesWithContexts.rewrittenWithLens(definitions, lens);
-    FieldAccessInfoImpl rewritten;
-    if (lens.isIdentityLensForFields(GraphLens.getIdentityLens())) {
-      if (rewrittenWritesWithContexts == writesWithContexts) {
-        rewritten = this;
-      } else {
-        rewritten =
-            new FieldAccessInfoImpl(field, flags, readsWithContexts, rewrittenWritesWithContexts);
-      }
-    } else {
-      rewritten =
-          new FieldAccessInfoImpl(
-              lens.lookupField(field), flags, readsWithContexts, rewrittenWritesWithContexts);
+    assert writesWithContexts.isTop();
+    DexField rewrittenField = lens.lookupField(field, appliedLens);
+    ProgramMethod rewrittenUniqueWriteContext = null;
+    if (uniqueWriteContext != null) {
+      rewrittenUniqueWriteContext =
+          uniqueWriteContext.rewrittenWithLens(lens, appliedLens, definitions);
     }
-    timing.end();
-    return rewritten;
+    if (rewrittenField.isIdenticalTo(field)
+        && (ObjectUtils.identical(rewrittenUniqueWriteContext, uniqueWriteContext)
+            || (rewrittenUniqueWriteContext != null
+                && rewrittenUniqueWriteContext.isStructurallyEqualTo(uniqueWriteContext)))) {
+      return this;
+    }
+    return new FieldAccessInfoImpl(rewrittenField, flags, rewrittenUniqueWriteContext);
   }
 
   public FieldAccessInfoImpl join(FieldAccessInfoImpl impl) {
     assert readsWithContexts.isTop();
-    FieldAccessInfoImpl merged = new FieldAccessInfoImpl(field);
-    merged.flags = flags | impl.flags;
-    merged.readsWithContexts = AbstractAccessContexts.unknown();
-    merged.writesWithContexts = writesWithContexts.join(impl.writesWithContexts);
-    return merged;
+    assert writesWithContexts.isTop();
+    return new FieldAccessInfoImpl(
+        field,
+        flags | impl.flags,
+        AbstractAccessContexts.unknown(),
+        AbstractAccessContexts.unknown());
   }
 
   public FieldAccessInfoImpl withoutPrunedItems(PrunedItems prunedItems) {
     assert readsWithContexts.isTop();
-    writesWithContexts = writesWithContexts.withoutPrunedItems(prunedItems);
+    assert writesWithContexts.isTop();
+    if (uniqueWriteContext != null && prunedItems.isRemoved(uniqueWriteContext.getReference())) {
+      uniqueWriteContext = null;
+    }
     return this;
   }
 }
diff --git a/src/main/java/com/android/tools/r8/graph/MutableFieldAccessInfoCollection.java b/src/main/java/com/android/tools/r8/graph/MutableFieldAccessInfoCollection.java
index 17a8c35..845c94e 100644
--- a/src/main/java/com/android/tools/r8/graph/MutableFieldAccessInfoCollection.java
+++ b/src/main/java/com/android/tools/r8/graph/MutableFieldAccessInfoCollection.java
@@ -11,17 +11,16 @@
         S extends MutableFieldAccessInfoCollection<S, T>, T extends MutableFieldAccessInfo>
     extends FieldAccessInfoCollection<T> {
 
-  void destroyAccessContexts();
+  void destroyUniqueWriteContexts();
 
   T extend(DexField field, FieldAccessInfoImpl info);
 
-  void flattenAccessContexts();
-
   void removeIf(BiPredicate<DexField, FieldAccessInfoImpl> predicate);
 
   void restrictToProgram(DexDefinitionSupplier definitions);
 
-  S rewrittenWithLens(DexDefinitionSupplier definitions, GraphLens lens, Timing timing);
+  S rewrittenWithLens(
+      DexDefinitionSupplier definitions, GraphLens lens, GraphLens appliedLens, Timing timing);
 
   S withoutPrunedItems(PrunedItems prunedItems);
 }
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
index 07ae42b..7b23ff6 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
@@ -265,10 +265,7 @@
         new FieldAccessInfoCollectionModifier.Builder();
     for (HorizontalMergeGroup group : groups) {
       if (group.hasClassIdField()) {
-        DexProgramClass target = group.getTarget();
-        target.forEachProgramInstanceInitializerMatching(
-            definition -> definition.getCode().isHorizontalClassMergerCode(),
-            method -> builder.recordFieldWrittenInContext(group.getClassIdField(), method));
+        builder.addField(group.getClassIdField());
       }
     }
     return builder.build();
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java
index 687b0fc..d60dc53 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java
@@ -117,7 +117,10 @@
     new IdentifierMinifier(appView).rewriteDexItemBasedConstStringInStaticFields(executorService);
 
     // The field access info collection is not maintained during IR processing.
-    appView.appInfoWithLiveness().getMutableFieldAccessInfoCollection().destroyAccessContexts();
+    appView
+        .appInfoWithLiveness()
+        .getMutableFieldAccessInfoCollection()
+        .destroyUniqueWriteContexts();
 
     // Assure that no more optimization feedback left after primary processing.
     assert feedback.noUpdatesLeft();
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/SharedEnumUnboxingUtilityClass.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/SharedEnumUnboxingUtilityClass.java
index 15dde02..a97f770 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/SharedEnumUnboxingUtilityClass.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/SharedEnumUnboxingUtilityClass.java
@@ -319,8 +319,7 @@
               .disableAndroidApiLevelCheckIf(
                   !appView.options().apiModelingOptions().isApiCallerIdentificationEnabled())
               .build();
-      fieldAccessInfoCollectionModifierBuilder
-          .recordFieldWriteInUnknownContext(valuesField.getReference());
+      fieldAccessInfoCollectionModifierBuilder.addField(valuesField.getReference());
       return valuesField;
     }
 
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationInfoRemover.java b/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationInfoRemover.java
index db23102..1d67dbd 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationInfoRemover.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationInfoRemover.java
@@ -48,7 +48,7 @@
     optimizationInfo.unsetDynamicType();
   }
 
-  private static void processMethod(DexEncodedMethod method) {
+  public static void processMethod(DexEncodedMethod method) {
     MutableMethodOptimizationInfo optimizationInfo =
         method.getOptimizationInfo().asMutableMethodOptimizationInfo();
     if (optimizationInfo == null) {
diff --git a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
index 74460ea..927ff82 100644
--- a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
+++ b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
@@ -1007,7 +1007,7 @@
         lens.rewriteReferences(bootstrapMethods),
         lens.rewriteReferences(virtualMethodsTargetedByInvokeDirect),
         lens.rewriteReferences(liveMethods),
-        fieldAccessInfoCollection.rewrittenWithLens(definitionSupplier, lens, timing),
+        fieldAccessInfoCollection.rewrittenWithLens(definitionSupplier, lens, appliedLens, timing),
         objectAllocationInfoCollection.rewrittenWithLens(
             definitionSupplier, lens, appliedLens, timing),
         lens.rewriteCallSites(callSites, definitionSupplier, timing),
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index b26877d..f6440dd 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -4483,7 +4483,7 @@
           if (field != info.getField() || info == MISSING_FIELD_ACCESS_INFO) {
             return true;
           }
-          info.destroyReadAccessContexts();
+          info.destroyAccessContexts(mode);
           return false;
         });
     assert fieldAccessInfoCollection.verifyMappingIsOneToOne();
diff --git a/src/main/java/com/android/tools/r8/shaking/EnqueuerDeferredTracingImpl.java b/src/main/java/com/android/tools/r8/shaking/EnqueuerDeferredTracingImpl.java
index 551d4c6..20313a2 100644
--- a/src/main/java/com/android/tools/r8/shaking/EnqueuerDeferredTracingImpl.java
+++ b/src/main/java/com/android/tools/r8/shaking/EnqueuerDeferredTracingImpl.java
@@ -27,6 +27,7 @@
 import com.android.tools.r8.ir.conversion.passes.ThrowCatchOptimizer;
 import com.android.tools.r8.ir.optimize.AssumeInserter;
 import com.android.tools.r8.ir.optimize.CodeRewriter;
+import com.android.tools.r8.ir.optimize.info.OptimizationInfoRemover;
 import com.android.tools.r8.ir.optimize.membervaluepropagation.assume.AssumeInfo;
 import com.android.tools.r8.shaking.Enqueuer.FieldAccessKind;
 import com.android.tools.r8.shaking.Enqueuer.FieldAccessMetadata;
@@ -203,7 +204,7 @@
 
     // If the field is now both read and written, then we cannot optimize the field unless the field
     // type is an uninstantiated class type.
-    if (info.getReadsWithContexts().hasAccesses() && info.getWritesWithContexts().hasAccesses()) {
+    if (info.isReadDirectly() && info.isWrittenDirectly()) {
       if (!fieldType.isClassType()) {
         return false;
       }
@@ -262,9 +263,15 @@
     // Rewrite application.
     Map<DexProgramClass, ProgramMethodSet> initializedClassesWithContexts =
         new ConcurrentHashMap<>();
+    ProgramMethodSet instanceInitializers = ProgramMethodSet.createConcurrent();
     ThreadUtils.processItems(
         methodsToProcess,
-        (method, ignored) -> rewriteMethod(method, initializedClassesWithContexts, prunedFields),
+        (method, ignored) -> {
+          rewriteMethod(method, initializedClassesWithContexts, prunedFields);
+          if (method.getDefinition().isInstanceInitializer()) {
+            instanceInitializers.add(method);
+          }
+        },
         appView.options().getThreadingModule(),
         executorService,
         WorkLoad.HEAVY);
@@ -275,6 +282,11 @@
             contexts.forEach(context -> enqueuer.traceInitClass(clazz.getType(), context)));
     assert enqueuer.getWorklist().isEmpty();
 
+    // Clear the optimization info of instance initializers since the instance initializer
+    // optimization info may change when removing field puts.
+    instanceInitializers.forEach(
+        method -> OptimizationInfoRemover.processMethod(method.getDefinition()));
+
     // Prune field access info collection.
     prunedFields.values().forEach(field -> fieldAccessInfoCollection.remove(field.getReference()));
   }
diff --git a/src/main/java/com/android/tools/r8/shaking/FieldAccessInfoCollectionModifier.java b/src/main/java/com/android/tools/r8/shaking/FieldAccessInfoCollectionModifier.java
index ef4c7f7..0e0ad5d 100644
--- a/src/main/java/com/android/tools/r8/shaking/FieldAccessInfoCollectionModifier.java
+++ b/src/main/java/com/android/tools/r8/shaking/FieldAccessInfoCollectionModifier.java
@@ -4,45 +4,20 @@
 
 package com.android.tools.r8.shaking;
 
-import com.android.tools.r8.graph.AbstractAccessContexts;
-import com.android.tools.r8.graph.AbstractAccessContexts.ConcreteAccessContexts;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.FieldAccessInfoImpl;
 import com.android.tools.r8.graph.MutableFieldAccessInfo;
 import com.android.tools.r8.graph.MutableFieldAccessInfoCollection;
-import com.android.tools.r8.graph.ProgramMethod;
-import java.util.IdentityHashMap;
-import java.util.Map;
+import com.google.common.collect.Sets;
+import java.util.Set;
 
 public class FieldAccessInfoCollectionModifier {
 
-  private static class FieldAccessContexts {
+  private final Set<DexField> newFields;
 
-    private AbstractAccessContexts writesWithContexts = AbstractAccessContexts.empty();
-
-    void addWriteContext(DexField field, ProgramMethod context) {
-      if (writesWithContexts.isBottom()) {
-        ConcreteAccessContexts concreteWriteContexts = new ConcreteAccessContexts();
-        concreteWriteContexts.recordAccess(field, context);
-        writesWithContexts = concreteWriteContexts;
-      } else if (writesWithContexts.isConcrete()) {
-        writesWithContexts.asConcrete().recordAccess(field, context);
-      } else {
-        assert writesWithContexts.isTop();
-      }
-    }
-
-    void recordWriteInUnknownContext() {
-      writesWithContexts = AbstractAccessContexts.unknown();
-    }
-  }
-
-  private final Map<DexField, FieldAccessContexts> newFieldAccessContexts;
-
-  private FieldAccessInfoCollectionModifier(
-      Map<DexField, FieldAccessContexts> newFieldAccessContexts) {
-    this.newFieldAccessContexts = newFieldAccessContexts;
+  private FieldAccessInfoCollectionModifier(Set<DexField> newFields) {
+    this.newFields = newFields;
   }
 
   public static Builder builder() {
@@ -52,37 +27,28 @@
   public void modify(AppView<AppInfoWithLiveness> appView) {
     MutableFieldAccessInfoCollection<?, ? extends MutableFieldAccessInfo>
         mutableFieldAccessInfoCollection = appView.appInfo().getMutableFieldAccessInfoCollection();
-    newFieldAccessContexts.forEach(
-        (field, accessContexts) -> {
-          FieldAccessInfoImpl fieldAccessInfo = new FieldAccessInfoImpl(field);
-          fieldAccessInfo.setReadDirectly();
-          fieldAccessInfo.setReadsWithContexts(AbstractAccessContexts.unknown());
-          fieldAccessInfo.setWritesWithContexts(accessContexts.writesWithContexts);
-          mutableFieldAccessInfoCollection.extend(field, fieldAccessInfo);
-        });
+    for (DexField field : newFields) {
+      FieldAccessInfoImpl fieldAccessInfo = new FieldAccessInfoImpl(field, 0, null);
+      assert !fieldAccessInfo.hasKnownReadContexts();
+      assert !fieldAccessInfo.hasKnownWriteContexts();
+      fieldAccessInfo.setReadDirectly();
+      fieldAccessInfo.setWrittenDirectly();
+      mutableFieldAccessInfoCollection.extend(field, fieldAccessInfo);
+    }
   }
 
   public static class Builder {
 
-    private final Map<DexField, FieldAccessContexts> newFieldAccessContexts =
-        new IdentityHashMap<>();
+    private final Set<DexField> newFields = Sets.newIdentityHashSet();
 
     public Builder() {}
 
-    private FieldAccessContexts getFieldAccessContexts(DexField field) {
-      return newFieldAccessContexts.computeIfAbsent(field, ignore -> new FieldAccessContexts());
-    }
-
-    public void recordFieldWrittenInContext(DexField field, ProgramMethod context) {
-      getFieldAccessContexts(field).addWriteContext(field, context);
-    }
-
-    public void recordFieldWriteInUnknownContext(DexField field) {
-      getFieldAccessContexts(field).recordWriteInUnknownContext();
+    public void addField(DexField field) {
+      newFields.add(field);
     }
 
     public FieldAccessInfoCollectionModifier build() {
-      return new FieldAccessInfoCollectionModifier(newFieldAccessContexts);
+      return new FieldAccessInfoCollectionModifier(newFields);
     }
   }
 }