Initial support for retaining proto fields that contain required/map fields

Bug: 112437944
Change-Id: I1ed78c18c712cfbbdd736e5bf34ec72694e865ba
diff --git a/src/main/java/com/android/tools/r8/R8.java b/src/main/java/com/android/tools/r8/R8.java
index 0ac2984..7b34833 100644
--- a/src/main/java/com/android/tools/r8/R8.java
+++ b/src/main/java/com/android/tools/r8/R8.java
@@ -634,7 +634,7 @@
             }
           }
 
-          Enqueuer enqueuer = EnqueuerFactory.createForPostTreeShaking(appView, keptGraphConsumer);
+          Enqueuer enqueuer = EnqueuerFactory.createForFinalTreeShaking(appView, keptGraphConsumer);
           appView.setAppInfo(
               enqueuer.traceApplication(
                   appView.rootSet(),
diff --git a/src/main/java/com/android/tools/r8/graph/DexClass.java b/src/main/java/com/android/tools/r8/graph/DexClass.java
index a500d17..d0cbbbe 100644
--- a/src/main/java/com/android/tools/r8/graph/DexClass.java
+++ b/src/main/java/com/android/tools/r8/graph/DexClass.java
@@ -10,6 +10,7 @@
 import com.android.tools.r8.kotlin.KotlinInfo;
 import com.android.tools.r8.origin.Origin;
 import com.android.tools.r8.utils.InternalOptions;
+import com.android.tools.r8.utils.PredicateUtils;
 import com.google.common.base.MoreObjects;
 import com.google.common.base.Predicates;
 import com.google.common.collect.Iterables;
@@ -575,6 +576,11 @@
     return lookupTarget(virtualMethods, method);
   }
 
+  /** Find virtual method in this class matching {@param predicate}. */
+  public DexEncodedMethod lookupVirtualMethod(Predicate<DexEncodedMethod> predicate) {
+    return PredicateUtils.findFirst(virtualMethods, predicate);
+  }
+
   /** Find method in this class matching {@param method}. */
   public DexEncodedMethod lookupMethod(DexMethod method) {
     DexEncodedMethod result = lookupDirectMethod(method);
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java
index 4429aa0..3b842eb 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java
@@ -7,7 +7,6 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexClass;
 import com.android.tools.r8.graph.DexEncodedMethod;
-import com.android.tools.r8.ir.analysis.proto.schema.ProtoFieldTypeFactory;
 import com.android.tools.r8.ir.analysis.proto.schema.ProtoMessageInfo;
 import com.android.tools.r8.ir.analysis.proto.schema.ProtoObject;
 import com.android.tools.r8.ir.analysis.type.Nullability;
@@ -38,11 +37,9 @@
   private final TypeLatticeElement objectArrayType;
   private final TypeLatticeElement stringType;
 
-  private final ProtoFieldTypeFactory factory = new ProtoFieldTypeFactory();
-
   public GeneratedMessageLiteShrinker(AppView<AppInfoWithLiveness> appView) {
     this.appView = appView;
-    this.decoder = new RawMessageInfoDecoder(factory);
+    this.decoder = appView.protoShrinker().decoder;
     this.encoder = new RawMessageInfoEncoder(appView.dexItemFactory());
     this.references = appView.protoShrinker().references;
     this.throwingInfo = ThrowingInfo.defaultForConstString(appView.options());
@@ -74,7 +71,7 @@
       return;
     }
 
-    InvokeMethod newMessageInfoInvoke = getNewMessageInfoInvoke(code);
+    InvokeMethod newMessageInfoInvoke = getNewMessageInfoInvoke(code, references);
     if (newMessageInfoInvoke != null) {
       // If this invoke is targeting RawMessageInfo.<init>(...) then `info` and `objects` is at
       // positions 2 and 3, respectively, and not position 1 and 2 as when calling the static method
@@ -88,7 +85,7 @@
       Value objectsValue = newMessageInfoInvoke.inValues().get(2 + adjustment).getAliasedValue();
 
       // Decode the arguments passed to newMessageInfo().
-      ProtoMessageInfo protoMessageInfo = decoder.run(infoValue, objectsValue, context);
+      ProtoMessageInfo protoMessageInfo = decoder.run(context, infoValue, objectsValue);
       if (protoMessageInfo != null) {
         // Rewrite the arguments to newMessageInfo().
         rewriteArgumentsToNewMessageInfo(
@@ -166,7 +163,7 @@
     }
   }
 
-  private InvokeMethod getNewMessageInfoInvoke(IRCode code) {
+  public static InvokeMethod getNewMessageInfoInvoke(IRCode code, ProtoReferences references) {
     for (Instruction instruction : code.instructions()) {
       if (instruction.isInvokeMethod()) {
         InvokeMethod invoke = instruction.asInvokeMethod();
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoReferences.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoReferences.java
index da66929..0919db9 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoReferences.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoReferences.java
@@ -4,6 +4,7 @@
 
 package com.android.tools.r8.ir.analysis.proto;
 
+import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProto;
@@ -78,6 +79,10 @@
     return method.name == dynamicMethodName && method.proto == dynamicMethodProto;
   }
 
+  public boolean isDynamicMethod(DexEncodedMethod encodedMethod) {
+    return isDynamicMethod(encodedMethod.method);
+  }
+
   public boolean isFindLiteExtensionByNumberMethod(DexMethod method) {
     return method.proto == findLiteExtensionByNumberProto
         && method.name.startsWith(findLiteExtensionByNumberName);
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoShrinker.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoShrinker.java
index 81d89fe..599bb84 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoShrinker.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoShrinker.java
@@ -5,15 +5,21 @@
 package com.android.tools.r8.ir.analysis.proto;
 
 import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.ir.analysis.proto.schema.ProtoFieldTypeFactory;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 
 public class ProtoShrinker {
 
+  public final RawMessageInfoDecoder decoder;
+  public final ProtoFieldTypeFactory factory;
   public final GeneratedExtensionRegistryShrinker generatedExtensionRegistryShrinker;
   public final ProtoReferences references;
 
   public ProtoShrinker(AppView<AppInfoWithLiveness> appView) {
+    ProtoFieldTypeFactory factory = new ProtoFieldTypeFactory();
     ProtoReferences references = new ProtoReferences(appView.dexItemFactory());
+    this.decoder = new RawMessageInfoDecoder(factory, references);
+    this.factory = factory;
     this.generatedExtensionRegistryShrinker =
         appView.options().enableGeneratedExtensionRegistryShrinking
             ? new GeneratedExtensionRegistryShrinker(appView, references)
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoUtils.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoUtils.java
new file mode 100644
index 0000000..06a7bfe
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoUtils.java
@@ -0,0 +1,33 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.analysis.proto;
+
+import com.android.tools.r8.ir.code.InvokeMethod;
+import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.utils.BooleanUtils;
+
+public class ProtoUtils {
+
+  static Value getInfoValueFromMessageInfoConstructionInvoke(
+      InvokeMethod invoke, ProtoReferences references) {
+    assert references.isMessageInfoConstructionMethod(invoke.getInvokedMethod());
+    int adjustment = BooleanUtils.intValue(invoke.isInvokeDirect());
+    return invoke.inValues().get(1 + adjustment).getAliasedValue();
+  }
+
+  static Value getObjectsValueFromMessageInfoConstructionInvoke(
+      InvokeMethod invoke, ProtoReferences references) {
+    assert references.isMessageInfoConstructionMethod(invoke.getInvokedMethod());
+    int adjustment = BooleanUtils.intValue(invoke.isInvokeDirect());
+    return invoke.inValues().get(2 + adjustment).getAliasedValue();
+  }
+
+  static void setObjectsValueForMessageInfoConstructionInvoke(
+      InvokeMethod invoke, Value newObjectsValue, ProtoReferences references) {
+    assert references.isMessageInfoConstructionMethod(invoke.getInvokedMethod());
+    int adjustment = BooleanUtils.intValue(invoke.isInvokeDirect());
+    invoke.replaceValue(2 + adjustment, newObjectsValue);
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/RawMessageInfoDecoder.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/RawMessageInfoDecoder.java
index ddc98e2..f9b1d62 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/RawMessageInfoDecoder.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/RawMessageInfoDecoder.java
@@ -4,6 +4,9 @@
 
 package com.android.tools.r8.ir.analysis.proto;
 
+import static com.android.tools.r8.ir.analysis.proto.ProtoUtils.getInfoValueFromMessageInfoConstructionInvoke;
+import static com.android.tools.r8.ir.analysis.proto.ProtoUtils.getObjectsValueFromMessageInfoConstructionInvoke;
+
 import com.android.tools.r8.graph.DexClass;
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexReference;
@@ -23,6 +26,7 @@
 import com.android.tools.r8.ir.code.DexItemBasedConstString;
 import com.android.tools.r8.ir.code.Instruction;
 import com.android.tools.r8.ir.code.InstructionIterator;
+import com.android.tools.r8.ir.code.InvokeMethod;
 import com.android.tools.r8.ir.code.InvokeStatic;
 import com.android.tools.r8.ir.code.NewArrayEmpty;
 import com.android.tools.r8.ir.code.StaticGet;
@@ -61,15 +65,24 @@
  */
 public class RawMessageInfoDecoder {
 
-  private static final int IS_PROTO_2_MASK = 0x1;
+  public static final int IS_PROTO_2_MASK = 0x1;
 
   private final ProtoFieldTypeFactory factory;
+  private final ProtoReferences references;
 
-  RawMessageInfoDecoder(ProtoFieldTypeFactory factory) {
+  RawMessageInfoDecoder(ProtoFieldTypeFactory factory, ProtoReferences references) {
     this.factory = factory;
+    this.references = references;
   }
 
-  public ProtoMessageInfo run(Value infoValue, Value objectsValue, DexClass context) {
+  public ProtoMessageInfo run(DexClass context, InvokeMethod invoke) {
+    assert references.isMessageInfoConstructionMethod(invoke.getInvokedMethod());
+    Value infoValue = getInfoValueFromMessageInfoConstructionInvoke(invoke, references);
+    Value objectsValue = getObjectsValueFromMessageInfoConstructionInvoke(invoke, references);
+    return run(context, infoValue, objectsValue);
+  }
+
+  public ProtoMessageInfo run(DexClass context, Value infoValue, Value objectsValue) {
     try {
       ProtoMessageInfo.Builder builder = ProtoMessageInfo.builder();
       ThrowingIntIterator<InvalidRawMessageInfoException> infoIterator =
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoEnqueuerExtension.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoEnqueuerExtension.java
new file mode 100644
index 0000000..1d77c62
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoEnqueuerExtension.java
@@ -0,0 +1,272 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.analysis.proto.schema;
+
+import com.android.tools.r8.OptionalBool;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexEncodedField;
+import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.analysis.EnqueuerAnalysis;
+import com.android.tools.r8.ir.analysis.proto.GeneratedMessageLiteShrinker;
+import com.android.tools.r8.ir.analysis.proto.ProtoReferences;
+import com.android.tools.r8.ir.analysis.proto.RawMessageInfoDecoder;
+import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.InvokeMethod;
+import com.android.tools.r8.shaking.Enqueuer;
+import com.android.tools.r8.shaking.EnqueuerWorklist;
+import com.android.tools.r8.shaking.KeepReason;
+import java.util.IdentityHashMap;
+import java.util.Map;
+
+// TODO(b/112437944): Handle cycles in the graph + add a test that fails with the current
+//  implementation. The current caching mechanism is unsafe, because we may mark a message as not
+//  containing a map/required field in presence of cycles, although it does.
+
+// TODO(b/112437944): Handle extensions in the map/required field detection + add a test that fails
+//  with the current implementation. If there is a field whose type is an extension, then we should
+//  look if any of the applicable extensions could contain a map/required field.
+
+// TODO(b/112437944): Handle incomplete information about extensions + add a test that fails with
+//  the current implementation. If there are some extensions that cannot be resolved, then we should
+//  keep fields that could reach extensions to be conservative.
+public class ProtoEnqueuerExtension extends EnqueuerAnalysis {
+
+  private final AppView<?> appView;
+  private final RawMessageInfoDecoder decoder;
+  private final ProtoFieldTypeFactory factory;
+  private final ProtoReferences references;
+
+  // Mapping for the set of proto message that have already become live.
+  private final Map<DexType, ProtoMessageInfo> liveProtos = new IdentityHashMap<>();
+
+  // Mapping for additional proto messages that have not yet become live. If there is a proto that
+  // has become live, then its schema may refer to another proto message that has not yet become
+  // live. In that case, we still need to decode the schema of the not-yet-live proto message,
+  // because we need to check if it has a required/map field.
+  private final Map<DexType, ProtoMessageInfo> seenButNotLiveProtos = new IdentityHashMap<>();
+
+  // To cache whether a proto message contains a map/required field directly or indirectly.
+  private final Map<ProtoMessageInfo, OptionalBool> reachesMapOrRequiredFieldFromMessageCache =
+      new IdentityHashMap<>();
+
+  public ProtoEnqueuerExtension(AppView<?> appView) {
+    this.appView = appView;
+    this.decoder = appView.protoShrinker().decoder;
+    this.factory = appView.protoShrinker().factory;
+    this.references = appView.protoShrinker().references;
+  }
+
+  /**
+   * When a dynamicMethod() of a proto message becomes live, then build the corresponding {@link
+   * ProtoMessageInfo} object, and create a mapping from the holder to it.
+   */
+  @Override
+  public void processNewlyLiveMethod(DexEncodedMethod encodedMethod) {
+    if (!references.isDynamicMethod(encodedMethod)) {
+      return;
+    }
+
+    DexType holder = encodedMethod.method.holder;
+    if (seenButNotLiveProtos.containsKey(holder)) {
+      // The proto is now live instead of dead.
+      liveProtos.put(holder, seenButNotLiveProtos.remove(holder));
+      return;
+    }
+
+    // Since this dynamicMethod() only becomes live once, and it has just become live, it must be
+    // the case that the proto is not already live.
+    assert !liveProtos.containsKey(holder);
+    createProtoMessageInfoFromDynamicMethod(encodedMethod, liveProtos);
+  }
+
+  private void createProtoMessageInfoFromDynamicMethod(
+      DexEncodedMethod dynamicMethod, Map<DexType, ProtoMessageInfo> protos) {
+    DexType holder = dynamicMethod.method.holder;
+    assert !protos.containsKey(holder);
+
+    DexClass context = appView.definitionFor(holder);
+    if (context == null || !context.isProgramClass()) {
+      // TODO(b/112437944): What if a proto message references a proto message on the classpath or
+      //  library path? We should treat them as having a map/required field to be conservative.
+      assert false; // Should generally not happen.
+      return;
+    }
+
+    IRCode code = dynamicMethod.buildIR(appView, context.origin);
+    InvokeMethod newMessageInfoInvoke =
+        GeneratedMessageLiteShrinker.getNewMessageInfoInvoke(code, references);
+    ProtoMessageInfo protoMessageInfo =
+        newMessageInfoInvoke != null ? decoder.run(context, newMessageInfoInvoke) : null;
+    protos.put(holder, protoMessageInfo);
+  }
+
+  @Override
+  public void notifyFixpoint(Enqueuer enqueuer, EnqueuerWorklist worklist) {
+    // TODO(b/112437944): We only need to check if a given field can reach a map/required field
+    //  once. Maybe maintain a map `newlyLiveProtos` that store the set of proto messages that have
+    //  become live since the last intermediate fixpoint.
+
+    // TODO(b/112437944): We only need to visit the subset of protos in `liveProtos` that has at
+    //  least one field that is not yet live. Maybe split `liveProtos` into `partiallyLiveProtos`
+    //  and `fullyLiveProtos`.
+    for (ProtoMessageInfo protoMessageInfo : liveProtos.values()) {
+      if (protoMessageInfo == null || !protoMessageInfo.hasFields()) {
+        continue;
+      }
+
+      for (ProtoFieldInfo protoFieldInfo : protoMessageInfo.getFields()) {
+        DexField valueStorage = protoFieldInfo.getValueStorage(protoMessageInfo);
+        DexEncodedField encodedValueStorage = appView.appInfo().resolveField(valueStorage);
+        if (encodedValueStorage == null) {
+          continue;
+        }
+
+        DexClass clazz = appView.definitionFor(encodedValueStorage.field.holder);
+        if (clazz == null || !clazz.isProgramClass()) {
+          assert false; // Should generally not happen.
+          continue;
+        }
+
+        DexEncodedMethod dynamicMethod = clazz.lookupVirtualMethod(references::isDynamicMethod);
+        if (dynamicMethod == null) {
+          assert false; // Should generally not happen.
+          continue;
+        }
+
+        boolean encodedValueStorageIsLive;
+        if (enqueuer.isFieldLive(encodedValueStorage)) {
+          // Mark the field as both read and written, since it is used reflectively.
+          enqueuer.registerFieldAccess(encodedValueStorage.field, dynamicMethod);
+          encodedValueStorageIsLive = true;
+        } else if (reachesMapOrRequiredField(protoFieldInfo)) {
+          enqueuer.registerFieldAccess(encodedValueStorage.field, dynamicMethod);
+          worklist.enqueueMarkReachableFieldAction(
+              encodedValueStorage.field, KeepReason.reflectiveUseIn(dynamicMethod));
+          encodedValueStorageIsLive = true;
+        } else {
+          encodedValueStorageIsLive = false;
+        }
+
+        DexField newlyLiveField = null;
+        if (encodedValueStorageIsLive) {
+          // For one-of fields, mark the corresponding one-of-case field as live, and for proto2
+          // singular fields, mark the corresponding hazzer-bit field as live.
+          if (protoFieldInfo.getType().isOneOf()) {
+            newlyLiveField = protoFieldInfo.getOneOfCaseField(protoMessageInfo);
+          } else if (protoFieldInfo.hasHazzerBitField(protoMessageInfo)) {
+            newlyLiveField = protoFieldInfo.getHazzerBitField(protoMessageInfo);
+          }
+        } else {
+          // For one-of fields, mark the one-of field as live if the one-of-case field is live, and
+          // for proto2 singular fields, mark the field as live if the corresponding hazzer-bit
+          // field is live.
+          if (protoFieldInfo.getType().isOneOf()) {
+            DexField oneOfCaseField = protoFieldInfo.getOneOfCaseField(protoMessageInfo);
+            DexEncodedField encodedOneOfCaseField = appView.appInfo().resolveField(oneOfCaseField);
+            if (encodedOneOfCaseField != null && enqueuer.isFieldLive(encodedOneOfCaseField)) {
+              newlyLiveField = encodedValueStorage.field;
+            }
+          } else if (protoFieldInfo.hasHazzerBitField(protoMessageInfo)) {
+            DexField hazzerBitField = protoFieldInfo.getHazzerBitField(protoMessageInfo);
+            DexEncodedField encodedHazzerBitField = appView.appInfo().resolveField(hazzerBitField);
+            if (encodedHazzerBitField != null && enqueuer.isFieldLive(encodedHazzerBitField)) {
+              newlyLiveField = encodedValueStorage.field;
+            }
+          }
+        }
+
+        if (newlyLiveField != null) {
+          if (enqueuer.registerFieldAccess(newlyLiveField, dynamicMethod)) {
+            worklist.enqueueMarkReachableFieldAction(
+                newlyLiveField, KeepReason.reflectiveUseIn(dynamicMethod));
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Traverses the proto schema graph.
+   *
+   * @return true if this proto field contains a map/required field directly or indirectly.
+   */
+  private boolean reachesMapOrRequiredField(ProtoFieldInfo protoFieldInfo) {
+    ProtoFieldType protoFieldType = protoFieldInfo.getType();
+
+    // If it is a map/required field, return true.
+    if (protoFieldType.isMap() || protoFieldType.isRequired()) {
+      return true;
+    }
+
+    // Otherwise, check if the type of the field may contain a map/required field.
+    DexType baseMessageType = protoFieldInfo.getBaseMessageType(factory);
+    if (baseMessageType != null) {
+      ProtoMessageInfo protoMessageInfo = getOrCreateProtoMessageInfo(baseMessageType);
+      assert protoMessageInfo != null;
+      return reachesMapOrRequiredField(protoMessageInfo);
+    }
+    return false;
+  }
+
+  /**
+   * Traverses the proto schema graph.
+   *
+   * @return true if this proto message contains a map/required field directly or indirectly.
+   */
+  private boolean reachesMapOrRequiredField(ProtoMessageInfo protoMessageInfo) {
+    if (!protoMessageInfo.hasFields()) {
+      return false;
+    }
+    OptionalBool cache =
+        reachesMapOrRequiredFieldFromMessageCache.getOrDefault(
+            protoMessageInfo, OptionalBool.unknown());
+    if (!cache.isUnknown()) {
+      return cache.isTrue();
+    }
+
+    // To guard against infinite recursion, we set the cache for this message to false, although
+    // we may later find out that this message actually contains a map/required field.
+    reachesMapOrRequiredFieldFromMessageCache.put(protoMessageInfo, OptionalBool.of(false));
+
+    // Check if any of the fields contains a map/required field.
+    for (ProtoFieldInfo protoFieldInfo : protoMessageInfo.getFields()) {
+      if (reachesMapOrRequiredField(protoFieldInfo)) {
+        reachesMapOrRequiredFieldFromMessageCache.put(protoMessageInfo, OptionalBool.of(true));
+        return true;
+      }
+    }
+
+    return false;
+  }
+
+  private ProtoMessageInfo getOrCreateProtoMessageInfo(DexType type) {
+    if (liveProtos.containsKey(type)) {
+      return liveProtos.get(type);
+    }
+
+    if (seenButNotLiveProtos.containsKey(type)) {
+      return seenButNotLiveProtos.get(type);
+    }
+
+    DexClass clazz = appView.definitionFor(type);
+    if (clazz == null || !clazz.isProgramClass()) {
+      seenButNotLiveProtos.put(type, null);
+      return null;
+    }
+
+    DexEncodedMethod dynamicMethod = clazz.lookupVirtualMethod(references::isDynamicMethod);
+    if (dynamicMethod == null) {
+      seenButNotLiveProtos.put(type, null);
+      return null;
+    }
+
+    createProtoMessageInfoFromDynamicMethod(dynamicMethod, seenButNotLiveProtos);
+
+    return seenButNotLiveProtos.get(type);
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldInfo.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldInfo.java
index 153ea66..d8af9b6 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldInfo.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldInfo.java
@@ -4,6 +4,10 @@
 
 package com.android.tools.r8.ir.analysis.proto.schema;
 
+import static com.android.tools.r8.ir.analysis.proto.schema.ProtoMessageInfo.BITS_PER_HAS_BITS_WORD;
+
+import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.graph.DexType;
 import java.util.List;
 import java.util.OptionalInt;
 
@@ -43,4 +47,120 @@
   public ProtoFieldType getType() {
     return type;
   }
+
+  /**
+   * For singular/repeated message-type fields, the type of the message.
+   *
+   * <p>This isn't populated for map-type fields because that's a bit difficult, but this doesn't
+   * matter in practice. We only use this for determining whether to retain a field based on whether
+   * it reaches a map/required field, but there's no need to go through that when we're already at
+   * one.
+   */
+  public DexType getBaseMessageType(ProtoFieldTypeFactory factory) {
+    if (type.isOneOf()) {
+      ProtoFieldType actualFieldType = type.asOneOf().getActualFieldType(factory);
+      if (actualFieldType.isGroup() || actualFieldType.isMessage()) {
+        ProtoObject object = objects.get(0);
+        assert object.isProtoTypeObject();
+        return object.asProtoTypeObject().getType();
+      }
+      return null;
+    }
+    if (type.isMessage() || type.isGroup()) {
+      ProtoObject object = objects.get(0);
+      assert object.isProtoFieldObject() : object.toString();
+      return object.asProtoFieldObject().getField().type;
+    }
+    if (type.isMessageList() || type.isGroupList()) {
+      ProtoObject object = objects.get(1);
+      assert object.isProtoTypeObject();
+      return object.asProtoTypeObject().getType();
+    }
+    return null;
+  }
+
+  /**
+   * (Proto2 singular fields only.)
+   *
+   * <p>Java field for denoting the presence of a protobuf field.
+   *
+   * <p>The generated Java code for:
+   *
+   * <pre>
+   *   message MyMessage {
+   *     optional int32 foo = 123;
+   *     optional int32 bar = 456;
+   *   }
+   * </pre>
+   *
+   * looks like:
+   *
+   * <pre>
+   *   boolean hasFoo() { return bitField0_ & 0x1; }
+   *   boolean hasBar() { return bitField0_ & 0x2; }
+   * </pre>
+   */
+  public boolean hasHazzerBitField(ProtoMessageInfo protoMessageInfo) {
+    return protoMessageInfo.isProto2() && type.isSingular();
+  }
+
+  public DexField getHazzerBitField(ProtoMessageInfo protoMessageInfo) {
+    assert hasHazzerBitField(protoMessageInfo);
+
+    int hasBitsIndex = getAuxData() / BITS_PER_HAS_BITS_WORD;
+    assert hasBitsIndex < protoMessageInfo.numberOfHasBitsObjects();
+
+    ProtoObject object = protoMessageInfo.getHasBitsObjects().get(hasBitsIndex);
+    assert object.isProtoFieldObject();
+    return object.asProtoFieldObject().getField();
+  }
+
+  /**
+   * (One-of fields only.)
+   *
+   * <p>Java field identifying what the containing oneof is currently being used for.
+   *
+   * <p>The generated Java code for:
+   *
+   * <pre>
+   *   message MyMessage {
+   *     oneof my_oneof {
+   *       int32 x = 123;
+   *       ...
+   *     }
+   *   }
+   * </pre>
+   *
+   * looks like:
+   *
+   * <pre>
+   *   ... getMyOneofCase() { return myOneofCase_; }
+   *   int getX() {
+   *     if (myOneofCase_ == 123) return (Integer) myOneof_;
+   *     return 0;
+   *   }
+   * </pre>
+   */
+  public DexField getOneOfCaseField(ProtoMessageInfo protoMessageInfo) {
+    assert type.isOneOf();
+
+    ProtoObject object = protoMessageInfo.getOneOfObjects().get(getAuxData()).getSecond();
+    assert object.isProtoFieldObject();
+    return object.asProtoFieldObject().getField();
+  }
+
+  /**
+   * Data about the field as referenced from the Java implementation.
+   *
+   * <p>Java field into which the value is stored; constituents of a oneof all share the same
+   * storage.
+   */
+  public DexField getValueStorage(ProtoMessageInfo protoMessageInfo) {
+    ProtoObject object =
+        type.isOneOf()
+            ? protoMessageInfo.getOneOfObjects().get(getAuxData()).getFirst()
+            : objects.get(0);
+    assert object.isProtoFieldObject();
+    return object.asProtoFieldObject().getField();
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldObject.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldObject.java
index 809791f..dbd3c7f 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldObject.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldObject.java
@@ -24,6 +24,10 @@
     this.field = field;
   }
 
+  public DexField getField() {
+    return field;
+  }
+
   @Override
   public Instruction buildIR(AppView<?> appView, IRCode code) {
     Value value =
@@ -36,4 +40,14 @@
     }
     return new ConstString(value, field.name, throwingInfo);
   }
+
+  @Override
+  public boolean isProtoFieldObject() {
+    return true;
+  }
+
+  @Override
+  public ProtoFieldObject asProtoFieldObject() {
+    return this;
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldType.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldType.java
index 4174eac..9566a91 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldType.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldType.java
@@ -67,14 +67,38 @@
     return id;
   }
 
+  public boolean isGroup() {
+    return id == GROUP_ID;
+  }
+
+  public boolean isGroupList() {
+    return id == GROUP_LIST_ID;
+  }
+
+  public boolean isMap() {
+    return id == MAP_ID;
+  }
+
   public boolean isMapFieldWithProto2EnumValue() {
     return isMapFieldWithProto2EnumValue;
   }
 
+  public boolean isMessage() {
+    return id == MESSAGE_ID;
+  }
+
+  public boolean isMessageList() {
+    return id == MESSAGE_LIST_ID;
+  }
+
   public boolean isOneOf() {
     return false;
   }
 
+  public ProtoOneOfFieldType asOneOf() {
+    return null;
+  }
+
   public boolean isRequired() {
     return isRequired;
   }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoMessageInfo.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoMessageInfo.java
index 9098d21..d0871df 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoMessageInfo.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoMessageInfo.java
@@ -4,12 +4,16 @@
 
 package com.android.tools.r8.ir.analysis.proto.schema;
 
+import static com.android.tools.r8.ir.analysis.proto.RawMessageInfoDecoder.IS_PROTO_2_MASK;
+
 import com.android.tools.r8.utils.Pair;
 import java.util.ArrayList;
 import java.util.List;
 
 public class ProtoMessageInfo {
 
+  public static final int BITS_PER_HAS_BITS_WORD = 32;
+
   public static class Builder {
 
     private int flags;
@@ -81,6 +85,10 @@
     return new ProtoMessageInfo.Builder();
   }
 
+  public boolean isProto2() {
+    return (flags & IS_PROTO_2_MASK) != 0;
+  }
+
   public List<ProtoFieldInfo> getFields() {
     return fields;
   }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoObject.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoObject.java
index ad3a916..31ef09f 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoObject.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoObject.java
@@ -11,4 +11,20 @@
 public abstract class ProtoObject {
 
   public abstract Instruction buildIR(AppView<?> appView, IRCode code);
+
+  public boolean isProtoFieldObject() {
+    return false;
+  }
+
+  public ProtoFieldObject asProtoFieldObject() {
+    return null;
+  }
+
+  public boolean isProtoTypeObject() {
+    return false;
+  }
+
+  public ProtoTypeObject asProtoTypeObject() {
+    return null;
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoOneOfFieldType.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoOneOfFieldType.java
index afafcca..1cc7b97 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoOneOfFieldType.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoOneOfFieldType.java
@@ -33,6 +33,11 @@
   }
 
   @Override
+  public ProtoOneOfFieldType asOneOf() {
+    return this;
+  }
+
+  @Override
   public boolean isSingular() {
     return true;
   }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoTypeObject.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoTypeObject.java
index 21c37e0..f785a41 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoTypeObject.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoTypeObject.java
@@ -17,8 +17,22 @@
     this.type = type;
   }
 
+  public DexType getType() {
+    return type;
+  }
+
   @Override
   public Instruction buildIR(AppView<?> appView, IRCode code) {
     return code.createConstClass(appView, type);
   }
+
+  @Override
+  public boolean isProtoTypeObject() {
+    return true;
+  }
+
+  @Override
+  public ProtoTypeObject asProtoTypeObject() {
+    return this;
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index 217d655..67056fe 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -49,6 +49,7 @@
 import com.android.tools.r8.graph.PresortedComparable;
 import com.android.tools.r8.graph.TopDownClassHierarchyTraversal;
 import com.android.tools.r8.graph.analysis.EnqueuerAnalysis;
+import com.android.tools.r8.ir.analysis.proto.schema.ProtoEnqueuerExtension;
 import com.android.tools.r8.ir.code.ArrayPut;
 import com.android.tools.r8.ir.code.ConstantValueUtils;
 import com.android.tools.r8.ir.code.IRCode;
@@ -110,17 +111,25 @@
  */
 public class Enqueuer {
 
-  enum Mode {
+  public enum Mode {
     INITIAL_TREE_SHAKING,
-    POST_TREE_SHAKING,
+    FINAL_TREE_SHAKING,
     MAIN_DEX_TRACING,
     WHY_ARE_YOU_KEEPING;
 
-    boolean isInitialTreeShaking() {
+    public boolean isInitialTreeShaking() {
       return this == INITIAL_TREE_SHAKING;
     }
 
-    boolean isTracingMainDex() {
+    public boolean isFinalTreeShaking() {
+      return this == FINAL_TREE_SHAKING;
+    }
+
+    public boolean isInitialOrFinalTreeShaking() {
+      return isInitialTreeShaking() || isFinalTreeShaking();
+    }
+
+    public boolean isTracingMainDex() {
       return this == MAIN_DEX_TRACING;
     }
   }
@@ -304,13 +313,22 @@
       ProguardConfiguration.Builder compatibility,
       Mode mode) {
     assert appView.appServices() != null;
+    InternalOptions options = appView.options();
     this.appInfo = appView.appInfo();
     this.appView = appView;
     this.compatibility = compatibility;
-    this.forceProguardCompatibility = appView.options().forceProguardCompatibility;
+    this.forceProguardCompatibility = options.forceProguardCompatibility;
     this.keptGraphConsumer = keptGraphConsumer;
     this.mode = mode;
-    this.options = appView.options();
+    this.options = options;
+
+    if (options.enableGeneratedMessageLiteShrinking && mode.isInitialOrFinalTreeShaking()) {
+      registerAnalysis(new ProtoEnqueuerExtension(appView));
+    }
+  }
+
+  public Mode getMode() {
+    return mode;
   }
 
   public Enqueuer registerAnalysis(EnqueuerAnalysis analysis) {
@@ -452,6 +470,12 @@
     return registerFieldAccess(field, context, false);
   }
 
+  public boolean registerFieldAccess(DexField field, DexEncodedMethod context) {
+    boolean changed = registerFieldAccess(field, context, true);
+    changed |= registerFieldAccess(field, context, false);
+    return changed;
+  }
+
   private boolean registerFieldAccess(DexField field, DexEncodedMethod context, boolean isRead) {
     FieldAccessInfoImpl info = fieldAccessInfoCollection.get(field);
     if (info == null) {
@@ -1481,6 +1505,10 @@
     }
   }
 
+  public boolean isFieldLive(DexEncodedField field) {
+    return liveFields.contains(field);
+  }
+
   private boolean isInstantiatedOrHasInstantiatedSubtype(DexType type) {
     return instantiatedTypes.contains(type)
         || instantiatedLambdas.contains(type)
diff --git a/src/main/java/com/android/tools/r8/shaking/EnqueuerFactory.java b/src/main/java/com/android/tools/r8/shaking/EnqueuerFactory.java
index 1a9e801..340b469 100644
--- a/src/main/java/com/android/tools/r8/shaking/EnqueuerFactory.java
+++ b/src/main/java/com/android/tools/r8/shaking/EnqueuerFactory.java
@@ -22,9 +22,9 @@
     return new Enqueuer(appView, null, compatibility, Mode.INITIAL_TREE_SHAKING);
   }
 
-  public static Enqueuer createForPostTreeShaking(
+  public static Enqueuer createForFinalTreeShaking(
       AppView<? extends AppInfoWithSubtyping> appView, GraphConsumer keptGraphConsumer) {
-    return new Enqueuer(appView, keptGraphConsumer, null, Mode.POST_TREE_SHAKING);
+    return new Enqueuer(appView, keptGraphConsumer, null, Mode.FINAL_TREE_SHAKING);
   }
 
   public static Enqueuer createForMainDexTracing(AppView<? extends AppInfoWithSubtyping> appView) {
diff --git a/src/main/java/com/android/tools/r8/utils/PredicateUtils.java b/src/main/java/com/android/tools/r8/utils/PredicateUtils.java
new file mode 100644
index 0000000..2c6fac0
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/utils/PredicateUtils.java
@@ -0,0 +1,19 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.utils;
+
+import java.util.function.Predicate;
+
+public class PredicateUtils {
+
+  public static <T> T findFirst(T[] items, Predicate<T> predicate) {
+    for (T entry : items) {
+      if (predicate.test(entry)) {
+        return entry;
+      }
+    }
+    return null;
+  }
+}