Maintain the set of dead proto types for definitionFor

Bug: 149363884, 150736225
Change-Id: I6de89063c80ecf2d35981aa7fe37bbd2b85176e5
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteBuilderShrinker.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteBuilderShrinker.java
index feca06c..0801f16 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteBuilderShrinker.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteBuilderShrinker.java
@@ -36,6 +36,7 @@
 import com.android.tools.r8.ir.optimize.inliner.FixedInliningReasonStrategy;
 import com.android.tools.r8.origin.Origin;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.shaking.Enqueuer;
 import com.android.tools.r8.utils.PredicateSet;
 import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.Timing;
@@ -64,11 +65,16 @@
 
   /** Returns true if an action was deferred. */
   public boolean deferDeadProtoBuilders(
-      DexProgramClass clazz, DexEncodedMethod context, BooleanSupplier register) {
+      DexProgramClass clazz,
+      DexEncodedMethod context,
+      BooleanSupplier register,
+      Enqueuer enqueuer) {
     if (references.isDynamicMethod(context) && references.isGeneratedMessageLiteBuilder(clazz)) {
       if (register.getAsBoolean()) {
-        assert builders.getOrDefault(clazz, context) == context;
-        builders.put(clazz, context);
+        if (enqueuer.getMode().isFinalTreeShaking()) {
+          assert builders.getOrDefault(clazz, context) == context;
+          builders.put(clazz, context);
+        }
         return true;
       }
     }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoEnqueuerUseRegistry.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoEnqueuerUseRegistry.java
index 71750f1..79dea35 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoEnqueuerUseRegistry.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoEnqueuerUseRegistry.java
@@ -43,6 +43,7 @@
   @Override
   public boolean registerConstClass(DexType type) {
     if (references.isDynamicMethod(getContextMethod())) {
+      enqueuer.addDeadProtoTypeCandidate(type);
       return false;
     }
     return super.registerConstClass(type);
@@ -58,6 +59,7 @@
   @Override
   public boolean registerStaticFieldRead(DexField field) {
     if (references.isDynamicMethod(getContextMethod())) {
+      enqueuer.addDeadProtoTypeCandidate(field.holder);
       return false;
     }
     return super.registerStaticFieldRead(field);
diff --git a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
index ff68a8d..ea314bf 100644
--- a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
+++ b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
@@ -73,7 +73,8 @@
 
 /** Encapsulates liveness and reachability information for an application. */
 public class AppInfoWithLiveness extends AppInfoWithSubtyping implements InstantiatedSubTypeInfo {
-
+  /** Set of reachable proto types that will be dead code eliminated. */
+  private final Set<DexType> deadProtoTypes;
   /** Set of types that are mentioned in the program, but for which no definition exists. */
   private final Set<DexType> missingTypes;
   /**
@@ -199,6 +200,7 @@
   // TODO(zerny): Clean up the constructors so we have just one.
   AppInfoWithLiveness(
       DirectMappedDexApplication application,
+      Set<DexType> deadProtoTypes,
       Set<DexType> missingTypes,
       Set<DexType> liveTypes,
       Set<DexType> instantiatedAppServices,
@@ -240,6 +242,7 @@
       Set<DexType> constClassReferences,
       Map<DexType, Visibility> initClassReferences) {
     super(application);
+    this.deadProtoTypes = deadProtoTypes;
     this.missingTypes = missingTypes;
     this.liveTypes = liveTypes;
     this.instantiatedAppServices = instantiatedAppServices;
@@ -284,6 +287,7 @@
 
   public AppInfoWithLiveness(
       AppInfoWithSubtyping appInfoWithSubtyping,
+      Set<DexType> deadProtoTypes,
       Set<DexType> missingTypes,
       Set<DexType> liveTypes,
       Set<DexType> instantiatedAppServices,
@@ -325,6 +329,7 @@
       Set<DexType> constClassReferences,
       Map<DexType, Visibility> initClassReferences) {
     super(appInfoWithSubtyping);
+    this.deadProtoTypes = deadProtoTypes;
     this.missingTypes = missingTypes;
     this.liveTypes = liveTypes;
     this.instantiatedAppServices = instantiatedAppServices;
@@ -370,6 +375,7 @@
   private AppInfoWithLiveness(AppInfoWithLiveness previous) {
     this(
         previous,
+        previous.deadProtoTypes,
         previous.missingTypes,
         previous.liveTypes,
         previous.instantiatedAppServices,
@@ -420,6 +426,7 @@
       Collection<DexReference> additionalPinnedItems) {
     this(
         application,
+        previous.deadProtoTypes,
         previous.missingTypes,
         previous.liveTypes,
         previous.instantiatedAppServices,
@@ -473,6 +480,7 @@
       Map<DexField, Int2ReferenceMap<DexField>> switchMaps,
       EnumValueInfoMapCollection enumValueInfoMaps) {
     super(previous);
+    this.deadProtoTypes = previous.deadProtoTypes;
     this.missingTypes = previous.missingTypes;
     this.liveTypes = previous.liveTypes;
     this.instantiatedAppServices = previous.instantiatedAppServices;
@@ -535,6 +543,7 @@
     DexClass definition = super.definitionFor(type);
     assert !assertDefinitionFor
             || definition != null
+            || deadProtoTypes.contains(type)
             || missingTypes.contains(type)
             // TODO(b/150693139): Remove these exceptions once fixed.
             || InterfaceMethodRewriter.isCompanionClassType(type)
@@ -1057,6 +1066,7 @@
 
     return new AppInfoWithLiveness(
         application,
+        deadProtoTypes,
         missingTypes,
         rewriteItems(liveTypes, lens::lookupType),
         rewriteItems(instantiatedAppServices, lens::lookupType),
diff --git a/src/main/java/com/android/tools/r8/shaking/DefaultEnqueuerUseRegistry.java b/src/main/java/com/android/tools/r8/shaking/DefaultEnqueuerUseRegistry.java
index db955c9..c04ddf5 100644
--- a/src/main/java/com/android/tools/r8/shaking/DefaultEnqueuerUseRegistry.java
+++ b/src/main/java/com/android/tools/r8/shaking/DefaultEnqueuerUseRegistry.java
@@ -18,7 +18,7 @@
 public class DefaultEnqueuerUseRegistry extends UseRegistry {
 
   private final ProgramMethod context;
-  private final Enqueuer enqueuer;
+  protected final Enqueuer enqueuer;
 
   public DefaultEnqueuerUseRegistry(
       AppView<?> appView, DexProgramClass holder, DexEncodedMethod method, Enqueuer enqueuer) {
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index c4dcb14..05cac5f 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -221,6 +221,9 @@
    */
   private final Set<DexClass> liveNonProgramTypes = Sets.newIdentityHashSet();
 
+  /** Set of reachable proto types that will be dead code eliminated. */
+  private final Set<DexProgramClass> deadProtoTypeCandidates = Sets.newIdentityHashSet();
+
   /** Set of missing types. */
   private final Set<DexType> missingTypes = Sets.newIdentityHashSet();
 
@@ -412,6 +415,15 @@
     this.annotationRemoverBuilder = annotationRemoverBuilder;
   }
 
+  public void addDeadProtoTypeCandidate(DexType type) {
+    assert type.isProgramType(appView);
+    addDeadProtoTypeCandidate(appView.definitionFor(type).asProgramClass());
+  }
+
+  public void addDeadProtoTypeCandidate(DexProgramClass clazz) {
+    deadProtoTypeCandidates.add(clazz);
+  }
+
   private boolean isProgramClass(DexType type) {
     return getProgramClassOrNull(type) != null;
   }
@@ -926,6 +938,7 @@
                 workList.enqueueTraceInvokeDirectAction(
                     invokedMethod, currentHolder, currentMethod));
     if (skipTracing) {
+      addDeadProtoTypeCandidate(invokedMethod.holder);
       return false;
     }
 
@@ -941,7 +954,10 @@
       return appView.withGeneratedMessageLiteBuilderShrinker(
           shrinker ->
               shrinker.deferDeadProtoBuilders(
-                  clazz, currentMethod, () -> liveTypes.registerDeferredAction(clazz, action)),
+                  clazz,
+                  currentMethod,
+                  () -> liveTypes.registerDeferredAction(clazz, action),
+                  this),
           false);
     }
     return false;
@@ -1084,6 +1100,7 @@
         registerDeferredActionForDeadProtoBuilder(
             type, currentMethod, () -> workList.enqueueTraceNewInstanceAction(type, context));
     if (skipTracing) {
+      addDeadProtoTypeCandidate(type);
       return false;
     }
 
@@ -1243,7 +1260,8 @@
       fieldAccessInfoCollection.get(encodedField.field).setReadFromMethodHandle();
     }
 
-    if (!isProgramClass(encodedField.holder())) {
+    DexProgramClass holder = getProgramClassOrNull(encodedField.holder());
+    if (holder == null) {
       // No need to trace into the non-program code.
       return false;
     }
@@ -1261,6 +1279,7 @@
                       encodedField, fieldAccessInfoCollection, pinnedItems),
               false);
       if (skipTracing) {
+        addDeadProtoTypeCandidate(holder);
         return false;
       }
     }
@@ -1300,7 +1319,8 @@
       fieldAccessInfoCollection.get(encodedField.field).setWrittenFromMethodHandle();
     }
 
-    if (!isProgramClass(encodedField.holder())) {
+    DexProgramClass holder = getProgramClassOrNull(encodedField.holder());
+    if (holder == null) {
       // No need to trace into the non-program code.
       return false;
     }
@@ -1318,6 +1338,7 @@
                       encodedField, fieldAccessInfoCollection, pinnedItems),
               false);
       if (skipTracing) {
+        addDeadProtoTypeCandidate(holder);
         return false;
       }
     }
@@ -2754,6 +2775,9 @@
   }
 
   private AppInfoWithLiveness createAppInfo(AppInfoWithSubtyping appInfo) {
+    // Compute the set of dead proto types.
+    deadProtoTypeCandidates.removeIf(this::isTypeLive);
+
     // Remove the temporary mappings that have been inserted into the field access info collection
     // and verify that the mapping is then one-to-one.
     fieldAccessInfoCollection.removeIf(
@@ -2796,6 +2820,7 @@
     AppInfoWithLiveness appInfoWithLiveness =
         new AppInfoWithLiveness(
             app,
+            SetUtils.mapIdentityHashSet(deadProtoTypeCandidates, DexProgramClass::getType),
             missingTypes,
             SetUtils.mapIdentityHashSet(liveTypes.getItems(), DexProgramClass::getType),
             Collections.unmodifiableSet(instantiatedAppServices),