Add support for kotlin.Metadata removal in R8 partial

Fixes: b/414304059
Change-Id: Id0baa34c5cdca46d6ab4bb812214b5aa0f007246
diff --git a/src/main/java/com/android/tools/r8/kotlin/KotlinMetadataEnqueuerExtension.java b/src/main/java/com/android/tools/r8/kotlin/KotlinMetadataEnqueuerExtension.java
index 4cad91a..cbed787 100644
--- a/src/main/java/com/android/tools/r8/kotlin/KotlinMetadataEnqueuerExtension.java
+++ b/src/main/java/com/android/tools/r8/kotlin/KotlinMetadataEnqueuerExtension.java
@@ -78,9 +78,9 @@
     boolean keepKotlinMetadata =
         KeepClassInfo.isKotlinMetadataClassKept(
             factory,
+            enqueuer.getKeepInfo(),
             appView.options(),
-            appView.appInfo()::definitionForWithoutExistenceAssert,
-            enqueuer::getKeepInfo);
+            appView.appInfo()::definitionForWithoutExistenceAssert);
     // In the first round of tree shaking build up all metadata such that it can be traced later.
     if (enqueuer.getMode().isInitialTreeShaking()) {
       Set<DexMethod> keepByteCodeFunctions = Sets.newIdentityHashSet();
diff --git a/src/main/java/com/android/tools/r8/shaking/KeepClassInfo.java b/src/main/java/com/android/tools/r8/shaking/KeepClassInfo.java
index 308e29c..02025b4 100644
--- a/src/main/java/com/android/tools/r8/shaking/KeepClassInfo.java
+++ b/src/main/java/com/android/tools/r8/shaking/KeepClassInfo.java
@@ -5,7 +5,6 @@
 
 import com.android.tools.r8.graph.DexClass;
 import com.android.tools.r8.graph.DexItemFactory;
-import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.utils.InternalOptions;
 import java.util.List;
@@ -117,14 +116,14 @@
 
   public static boolean isKotlinMetadataClassKept(
       DexItemFactory factory,
+      KeepInfoCollection keepInfo,
       InternalOptions options,
-      Function<DexType, DexClass> definitionForWithoutExistenceAssert,
-      Function<DexProgramClass, KeepClassInfo> getClassInfo) {
+      Function<DexType, DexClass> definitionForWithoutExistenceAssert) {
     DexType kotlinMetadataType = factory.kotlinMetadataType;
     DexClass kotlinMetadataClass = definitionForWithoutExistenceAssert.apply(kotlinMetadataType);
     return kotlinMetadataClass == null
         || kotlinMetadataClass.isNotProgramClass()
-        || !getClassInfo.apply(kotlinMetadataClass.asProgramClass()).isShrinkingAllowed(options);
+        || !keepInfo.isKotlinMetadataRemovalAllowed(options);
   }
 
   public boolean isPermittedSubclassesRemovalAllowed(GlobalKeepInfoConfiguration configuration) {
diff --git a/src/main/java/com/android/tools/r8/shaking/KeepInfoCollection.java b/src/main/java/com/android/tools/r8/shaking/KeepInfoCollection.java
index 4ef518f..11ebdfb 100644
--- a/src/main/java/com/android/tools/r8/shaking/KeepInfoCollection.java
+++ b/src/main/java/com/android/tools/r8/shaking/KeepInfoCollection.java
@@ -18,6 +18,7 @@
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexItem;
+import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexReference;
@@ -31,10 +32,12 @@
 import com.android.tools.r8.shaking.KeepFieldInfo.Joiner;
 import com.android.tools.r8.shaking.rules.ApplicableRulesEvaluator;
 import com.android.tools.r8.shaking.rules.MaterializedRules;
+import com.android.tools.r8.shaking.rules.ReferencedFromExcludedClassInR8PartialRule;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.ListUtils;
 import com.android.tools.r8.utils.MapUtils;
 import com.android.tools.r8.utils.timing.Timing;
+import com.google.common.collect.Iterables;
 import com.google.common.collect.Streams;
 import java.io.IOException;
 import java.nio.file.Files;
@@ -219,6 +222,8 @@
     throw new Unreachable();
   }
 
+  public abstract boolean isKotlinMetadataRemovalAllowed(GlobalKeepInfoConfiguration configuration);
+
   public final boolean isPinned(
       ProgramDefinition definition, GlobalKeepInfoConfiguration configuration) {
     return getInfo(definition).isPinned(configuration);
@@ -279,6 +284,8 @@
   // Mutation interface for building up the keep info.
   public static class MutableKeepInfoCollection extends KeepInfoCollection {
 
+    private final DexItemFactory factory;
+
     // These are typed at signatures but the interface should make sure never to allow access
     // directly with a signature. See the comment in KeepInfoCollection.
     private final Map<DexType, KeepClassInfo> keepClassInfo;
@@ -290,6 +297,8 @@
     private final Map<DexField, KeepFieldInfo.Joiner> fieldRuleInstances;
     private final Map<DexMethod, KeepMethodInfo.Joiner> methodRuleInstances;
 
+    private boolean allowKotlinMetadataRemoval = true;
+
     // Collection of materialized rules.
     private MaterializedRules materializedRules;
 
@@ -297,6 +306,7 @@
 
     MutableKeepInfoCollection(InternalOptions options) {
       this(
+          options,
           new IdentityHashMap<>(),
           new IdentityHashMap<>(),
           new IdentityHashMap<>(),
@@ -310,6 +320,7 @@
     }
 
     private MutableKeepInfoCollection(
+        InternalOptions options,
         Map<DexType, KeepClassInfo> keepClassInfo,
         Map<DexMethod, KeepMethodInfo> keepMethodInfo,
         Map<DexField, KeepFieldInfo> keepFieldInfo,
@@ -318,6 +329,7 @@
         Map<DexMethod, KeepMethodInfo.Joiner> methodRuleInstances,
         MaterializedRules materializedRules,
         KeepInfoCanonicalizer keepInfoCanonicalizer) {
+      this.factory = options.dexItemFactory();
       this.keepClassInfo = keepClassInfo;
       this.keepMethodInfo = keepMethodInfo;
       this.keepFieldInfo = keepFieldInfo;
@@ -376,6 +388,7 @@
       Map<DexField, KeepFieldInfo> newFieldInfo = rewriteFieldInfo(lens, options, timing);
       MutableKeepInfoCollection result =
           new MutableKeepInfoCollection(
+              options,
               newClassInfo,
               newMethodInfo,
               newFieldInfo,
@@ -532,24 +545,6 @@
           });
     }
 
-    void evaluateClassRule(DexProgramClass clazz, KeepClassInfo.Joiner minimumKeepInfo) {
-      if (!minimumKeepInfo.isBottom()) {
-        joinClass(clazz, joiner -> joiner.merge(minimumKeepInfo));
-        classRuleInstances
-            .computeIfAbsent(clazz.getType(), ignoreKey(KeepClassInfo::newEmptyJoiner))
-            .merge(minimumKeepInfo);
-      }
-    }
-
-    void evaluateFieldRule(ProgramField field, KeepFieldInfo.Joiner minimumKeepInfo) {
-      if (!minimumKeepInfo.isBottom()) {
-        joinField(field, joiner -> joiner.merge(minimumKeepInfo));
-        fieldRuleInstances
-            .computeIfAbsent(field.getReference(), ignoreKey(KeepFieldInfo::newEmptyJoiner))
-            .merge(minimumKeepInfo);
-      }
-    }
-
     void evaluateMethodRule(ProgramMethod method, KeepMethodInfo.Joiner minimumKeepInfo) {
       if (!minimumKeepInfo.isBottom()) {
         joinMethod(method, joiner -> joiner.merge(minimumKeepInfo));
@@ -595,16 +590,14 @@
     }
 
     @Override
-    @SuppressWarnings("ReferenceEquality")
     public KeepMethodInfo getMethodInfo(DexEncodedMethod method, DexProgramClass holder) {
-      assert method.getHolderType() == holder.type;
+      assert method.getHolderType().isIdenticalTo(holder.getType());
       return keepMethodInfo.getOrDefault(method.getReference(), KeepMethodInfo.bottom());
     }
 
     @Override
-    @SuppressWarnings("ReferenceEquality")
     public KeepFieldInfo getFieldInfo(DexEncodedField field, DexProgramClass holder) {
-      assert field.getHolderType() == holder.type;
+      assert field.getHolderType().isIdenticalTo(holder.getType());
       return keepFieldInfo.getOrDefault(field.getReference(), KeepFieldInfo.bottom());
     }
 
@@ -619,9 +612,33 @@
       KeepClassInfo joined = joiner.join();
       if (!info.equals(joined)) {
         keepClassInfo.put(clazz.type, canonicalizer.canonicalizeKeepClassInfo(joined));
+        maybeDisallowKotlinMetadataRemoval(clazz, info, joined, joiner);
       }
     }
 
+    private void maybeDisallowKotlinMetadataRemoval(
+        DexProgramClass clazz,
+        KeepClassInfo oldInfo,
+        KeepClassInfo newInfo,
+        KeepClassInfo.Joiner joiner) {
+      if (!oldInfo.internalIsShrinkingAllowed()
+          || newInfo.internalIsShrinkingAllowed()
+          || clazz.getType().isNotIdenticalTo(factory.kotlinMetadataType)) {
+        return;
+      }
+      // The kotlin.Metadata class went from not being kept to being kept.
+      if (joiner.getReasons().isEmpty()
+          && !joiner.getRules().isEmpty()
+          && Iterables.all(
+              joiner.getRules(),
+              rule -> rule instanceof ReferencedFromExcludedClassInR8PartialRule)) {
+        // We do not want to disallow kotlin.Metadata removal in R8 partial simply because
+        // kotlin.Metadata is referenced from D8.
+        return;
+      }
+      allowKotlinMetadataRemoval = false;
+    }
+
     public void keepClass(DexProgramClass clazz) {
       joinClass(clazz, KeepInfo.Joiner::top);
     }
@@ -663,6 +680,11 @@
     }
 
     @Override
+    public boolean isKotlinMetadataRemovalAllowed(GlobalKeepInfoConfiguration configuration) {
+      return allowKotlinMetadataRemoval && configuration.isTreeShakingEnabled();
+    }
+
+    @Override
     public KeepInfoCollection mutate(Consumer<MutableKeepInfoCollection> mutator) {
       mutator.accept(this);
       return this;