Always class inline proto builders

Change-Id: I8799c34cc9e0200caa82646bb9fda4d71003c91e
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteBuilderShrinker.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteBuilderShrinker.java
index d1cf2e5..4660f87 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteBuilderShrinker.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteBuilderShrinker.java
@@ -12,6 +12,7 @@
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.ir.conversion.CallGraph.Node;
+import com.android.tools.r8.utils.PredicateSet;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
@@ -39,10 +40,13 @@
 
   public static void addInliningHeuristicsForBuilderInlining(
       AppView<? extends AppInfoWithSubtyping> appView,
+      PredicateSet<DexType> alwaysClassInline,
       Set<DexMethod> alwaysInline,
       Set<DexMethod> neverInline,
       Set<DexMethod> bypassClinitforInlining) {
-    new RootSetExtension(appView, alwaysInline, neverInline, bypassClinitforInlining).extend();
+    new RootSetExtension(
+            appView, alwaysClassInline, alwaysInline, neverInline, bypassClinitforInlining)
+        .extend();
   }
 
   public void preprocessCallGraphBeforeCycleElimination(Map<DexMethod, Node> nodes) {
@@ -65,23 +69,28 @@
     private final AppView<? extends AppInfoWithSubtyping> appView;
     private final ProtoReferences references;
 
+    private final PredicateSet<DexType> alwaysClassInline;
     private final Set<DexMethod> alwaysInline;
     private final Set<DexMethod> neverInline;
     private final Set<DexMethod> bypassClinitforInlining;
 
     RootSetExtension(
         AppView<? extends AppInfoWithSubtyping> appView,
+        PredicateSet<DexType> alwaysClassInline,
         Set<DexMethod> alwaysInline,
         Set<DexMethod> neverInline,
         Set<DexMethod> bypassClinitforInlining) {
       this.appView = appView;
       this.references = appView.protoShrinker().references;
+      this.alwaysClassInline = alwaysClassInline;
       this.alwaysInline = alwaysInline;
       this.neverInline = neverInline;
       this.bypassClinitforInlining = bypassClinitforInlining;
     }
 
     void extend() {
+      alwaysClassInlineGeneratedMessageLiteBuilders();
+
       // GeneratedMessageLite heuristics.
       alwaysInlineCreateBuilderFromGeneratedMessageLite();
       neverInlineIsInitializedFromGeneratedMessageLite();
@@ -94,6 +103,14 @@
       alwaysInlineBuildPartialFromGeneratedMessageLiteBuilder();
     }
 
+    private void alwaysClassInlineGeneratedMessageLiteBuilders() {
+      alwaysClassInline.addPredicate(
+          type ->
+              appView
+                  .appInfo()
+                  .isStrictSubtypeOf(type, references.generatedMessageLiteBuilderType));
+    }
+
     private void bypassClinitforInliningNewBuilderMethods() {
       for (DexType type : appView.appInfo().subtypes(references.generatedMessageLiteType)) {
         DexProgramClass clazz = appView.definitionFor(type).asProgramClass();
diff --git a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
index 33e4eed..ac4f95f 100644
--- a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
+++ b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
@@ -29,6 +29,7 @@
 import com.android.tools.r8.ir.analysis.type.ClassTypeLatticeElement;
 import com.android.tools.r8.ir.code.Invoke.Type;
 import com.android.tools.r8.utils.CollectionUtils;
+import com.android.tools.r8.utils.PredicateSet;
 import com.android.tools.r8.utils.SetUtils;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
@@ -139,7 +140,7 @@
   /** All methods that may not have any unused arguments removed. */
   public final Set<DexMethod> keepUnusedArguments;
   /** All types that should be inlined if possible due to a configuration directive. */
-  public final Set<DexType> alwaysClassInline;
+  public final PredicateSet<DexType> alwaysClassInline;
   /** All types that *must* never be inlined due to a configuration directive (testing only). */
   public final Set<DexType> neverClassInline;
   /** All types that *must* never be merged due to a configuration directive (testing only). */
@@ -207,7 +208,7 @@
       Set<DexMethod> whyAreYouNotInlining,
       Set<DexMethod> keepConstantArguments,
       Set<DexMethod> keepUnusedArguments,
-      Set<DexType> alwaysClassInline,
+      PredicateSet<DexType> alwaysClassInline,
       Set<DexType> neverClassInline,
       Set<DexType> neverMerge,
       Set<DexReference> neverPropagateValue,
@@ -286,7 +287,7 @@
       Set<DexMethod> whyAreYouNotInlining,
       Set<DexMethod> keepConstantArguments,
       Set<DexMethod> keepUnusedArguments,
-      Set<DexType> alwaysClassInline,
+      PredicateSet<DexType> alwaysClassInline,
       Set<DexType> neverClassInline,
       Set<DexType> neverMerge,
       Set<DexReference> neverPropagateValue,
@@ -496,7 +497,7 @@
             .map(this::definitionFor)
             .filter(Objects::nonNull)
             .collect(Collectors.toList()));
-    this.alwaysClassInline = rewriteItems(previous.alwaysClassInline, lense::lookupType);
+    this.alwaysClassInline = previous.alwaysClassInline.rewriteItems(lense::lookupType);
     this.neverClassInline = rewriteItems(previous.neverClassInline, lense::lookupType);
     this.neverMerge = rewriteItems(previous.neverMerge, lense::lookupType);
     this.neverPropagateValue = lense.rewriteReferencesConservatively(previous.neverPropagateValue);
diff --git a/src/main/java/com/android/tools/r8/shaking/RootSetBuilder.java b/src/main/java/com/android/tools/r8/shaking/RootSetBuilder.java
index 8555b93..62a0553 100644
--- a/src/main/java/com/android/tools/r8/shaking/RootSetBuilder.java
+++ b/src/main/java/com/android/tools/r8/shaking/RootSetBuilder.java
@@ -32,6 +32,7 @@
 import com.android.tools.r8.utils.Consumer3;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.MethodSignatureEquivalence;
+import com.android.tools.r8.utils.PredicateSet;
 import com.android.tools.r8.utils.StringDiagnostic;
 import com.android.tools.r8.utils.ThreadUtils;
 import com.google.common.base.Equivalence.Wrapper;
@@ -82,7 +83,7 @@
   private final Set<DexMethod> whyAreYouNotInlining = Sets.newIdentityHashSet();
   private final Set<DexMethod> keepParametersWithConstantValue = Sets.newIdentityHashSet();
   private final Set<DexMethod> keepUnusedArguments = Sets.newIdentityHashSet();
-  private final Set<DexType> alwaysClassInline = Sets.newIdentityHashSet();
+  private final PredicateSet<DexType> alwaysClassInline = new PredicateSet<>();
   private final Set<DexType> neverClassInline = Sets.newIdentityHashSet();
   private final Set<DexType> neverMerge = Sets.newIdentityHashSet();
   private final Set<DexReference> neverPropagateValue = Sets.newIdentityHashSet();
@@ -295,7 +296,7 @@
     }
     if (appView.options().protoShrinking().enableGeneratedMessageLiteBuilderShrinking) {
       GeneratedMessageLiteBuilderShrinker.addInliningHeuristicsForBuilderInlining(
-          appView, alwaysInline, neverInline, bypassClinitforInlining);
+          appView, alwaysClassInline, alwaysInline, neverInline, bypassClinitforInlining);
     }
     assert Sets.intersection(neverInline, alwaysInline).isEmpty()
             && Sets.intersection(neverInline, forceInline).isEmpty()
@@ -1123,7 +1124,7 @@
       }
       switch (classInlineRule.getType()) {
         case ALWAYS:
-          alwaysClassInline.add(item.asDexClass().type);
+          alwaysClassInline.addElement(item.asDexClass().type);
           break;
         case NEVER:
           neverClassInline.add(item.asDexClass().type);
@@ -1202,7 +1203,7 @@
     public final Set<DexMethod> whyAreYouNotInlining;
     public final Set<DexMethod> keepConstantArguments;
     public final Set<DexMethod> keepUnusedArguments;
-    public final Set<DexType> alwaysClassInline;
+    public final PredicateSet<DexType> alwaysClassInline;
     public final Set<DexType> neverClassInline;
     public final Set<DexType> neverMerge;
     public final Set<DexReference> neverPropagateValue;
@@ -1229,7 +1230,7 @@
         Set<DexMethod> whyAreYouNotInlining,
         Set<DexMethod> keepConstantArguments,
         Set<DexMethod> keepUnusedArguments,
-        Set<DexType> alwaysClassInline,
+        PredicateSet<DexType> alwaysClassInline,
         Set<DexType> neverClassInline,
         Set<DexType> neverMerge,
         Set<DexReference> neverPropagateValue,
diff --git a/src/main/java/com/android/tools/r8/utils/PredicateSet.java b/src/main/java/com/android/tools/r8/utils/PredicateSet.java
new file mode 100644
index 0000000..da15974
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/utils/PredicateSet.java
@@ -0,0 +1,49 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.utils;
+
+import com.google.common.collect.Sets;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
+import java.util.function.Function;
+import java.util.function.Predicate;
+
+public class PredicateSet<T> {
+
+  private final Set<T> elements = Sets.newIdentityHashSet();
+  private final List<Predicate<T>> predicates = new ArrayList<>();
+
+  public boolean addElement(T element) {
+    return elements.add(element);
+  }
+
+  public void addPredicate(Predicate<T> predicate) {
+    predicates.add(predicate);
+  }
+
+  public PredicateSet<T> rewriteItems(Function<T, T> mapping) {
+    PredicateSet<T> set = new PredicateSet<>();
+    for (T item : elements) {
+      set.elements.add(mapping.apply(item));
+    }
+    // It is assumed that the predicates do not need rewriting. Otherwise, this method must be
+    // overwritten.
+    set.predicates.addAll(predicates);
+    return set;
+  }
+
+  public boolean contains(T element) {
+    if (elements.contains(element)) {
+      return true;
+    }
+    for (Predicate<T> predicate : predicates) {
+      if (predicate.test(element)) {
+        return true;
+      }
+    }
+    return false;
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/internal/proto/Proto2BuilderShrinkingTest.java b/src/test/java/com/android/tools/r8/internal/proto/Proto2BuilderShrinkingTest.java
index 7a158d3..a717189 100644
--- a/src/test/java/com/android/tools/r8/internal/proto/Proto2BuilderShrinkingTest.java
+++ b/src/test/java/com/android/tools/r8/internal/proto/Proto2BuilderShrinkingTest.java
@@ -27,7 +27,6 @@
 @RunWith(Parameterized.class)
 public class Proto2BuilderShrinkingTest extends ProtoShrinkingTestBase {
 
-  private static final String LITE_BUILDER = "com.google.protobuf.GeneratedMessageLite$Builder";
   private static final String METHOD_TO_INVOKE_ENUM =
       "com.google.protobuf.GeneratedMessageLite$MethodToInvoke";
 
@@ -164,14 +163,9 @@
   }
 
   private void verifyBuildersAreAbsent(CodeInspector outputInspector) {
-    boolean primitivesBuilderShouldBeLive =
-        mains.contains("proto2.BuilderWithReusedSettersTestClass");
-    assertThat(
-        outputInspector.clazz(LITE_BUILDER),
-        primitivesBuilderShouldBeLive ? isPresent() : not(isPresent()));
     assertThat(
         outputInspector.clazz("com.android.tools.r8.proto2.TestProto$Primitives$Builder"),
-        primitivesBuilderShouldBeLive ? isPresent() : not(isPresent()));
+        not(isPresent()));
     assertThat(
         outputInspector.clazz("com.android.tools.r8.proto2.TestProto$OuterMessage$Builder"),
         not(isPresent()));