Require a keep reason to mark types live.

Bug: 120959039
Change-Id: Ib22f11fb1ad2717ce406053f82de7a06b58cf875
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index 43bdad1..b42aa19 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -73,7 +73,6 @@
 import com.google.common.base.Equivalence.Wrapper;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableList.Builder;
-import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.ImmutableSortedSet;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
@@ -98,6 +97,7 @@
 import java.util.concurrent.ExecutorService;
 import java.util.function.BiConsumer;
 import java.util.function.BiPredicate;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
 /**
@@ -193,7 +193,7 @@
    * Set of types that are mentioned in the program. We at least need an empty abstract class item
    * for these.
    */
-  private final Set<DexProgramClass> liveTypes = Sets.newIdentityHashSet();
+  private final SetWithReason<DexProgramClass> liveTypes = new SetWithReason<>(this::registerClass);
 
   /** Set of live types defined in the library and classpath. Used to avoid duplicate tracing. */
   private final Set<DexClass> liveNonProgramTypes = Sets.newIdentityHashSet();
@@ -472,7 +472,7 @@
       return;
     }
     populateInstantiatedTypesCache(clazz);
-    markTypeAsLive(clazz.type);
+    markTypeAsLive(clazz, reason);
   }
 
   private void enqueueFirstNonSerializableClassInitializer(
@@ -520,7 +520,7 @@
       Map<DexMethod, Set<DexEncodedMethod>> seen, DexMethod method, DexEncodedMethod context) {
     DexType baseHolder = method.holder.toBaseType(appView.dexItemFactory());
     if (baseHolder.isClassType()) {
-      markTypeAsLive(baseHolder);
+      markTypeAsLive(baseHolder, clazz -> graphReporter.reportClassReferencedFrom(clazz, context));
       return seen.computeIfAbsent(method, ignore -> Sets.newIdentityHashSet()).add(context);
     }
     return false;
@@ -695,7 +695,7 @@
       }
 
       // Must mark the field as targeted even if it does not exist.
-      markFieldAsTargeted(field);
+      markFieldAsTargeted(field, currentMethod);
 
       DexEncodedField encodedField = appInfo.resolveField(field);
       if (encodedField == null) {
@@ -712,17 +712,20 @@
         Log.verbose(getClass(), "Register Iput `%s`.", field);
       }
 
+
       // If unused interface removal is enabled, then we won't necessarily mark the actual holder of
       // the field as live, if the holder is an interface.
       if (appView.options().enableUnusedInterfaceRemoval) {
         if (encodedField.field != field) {
-          markTypeAsLive(encodedField.field.holder);
-          markTypeAsLive(encodedField.field.type);
+          markTypeAsLive(clazz, graphReporter.reportClassReferencedFrom(clazz, currentMethod));
+          markTypeAsLive(
+              encodedField.field.type,
+              type -> graphReporter.reportClassReferencedFrom(type, currentMethod));
         }
       }
 
-      workList.enqueueMarkReachableFieldAction(
-          clazz, encodedField, KeepReason.fieldReferencedIn(currentMethod));
+      KeepReason reason = KeepReason.fieldReferencedIn(currentMethod);
+      workList.enqueueMarkReachableFieldAction(clazz, encodedField, reason);
       return true;
     }
 
@@ -733,7 +736,7 @@
       }
 
       // Must mark the field as targeted even if it does not exist.
-      markFieldAsTargeted(field);
+      markFieldAsTargeted(field, currentMethod);
 
       DexEncodedField encodedField = appInfo.resolveField(field);
       if (encodedField == null) {
@@ -754,8 +757,10 @@
       // the field as live, if the holder is an interface.
       if (appView.options().enableUnusedInterfaceRemoval) {
         if (encodedField.field != field) {
-          markTypeAsLive(encodedField.field.holder);
-          markTypeAsLive(encodedField.field.type);
+          markTypeAsLive(clazz, graphReporter.reportClassReferencedFrom(clazz, currentMethod));
+          markTypeAsLive(
+              encodedField.field.type,
+              type -> graphReporter.reportClassReferencedFrom(type, currentMethod));
         }
       }
 
@@ -773,7 +778,7 @@
       DexProgramClass clazz = getProgramClassOrNull(type);
       if (clazz != null) {
         if (clazz.isInterface()) {
-          markTypeAsLive(clazz.type);
+          markTypeAsLive(clazz, keepReason);
         } else {
           markInstantiated(clazz, keepReason);
         }
@@ -790,7 +795,7 @@
       DexEncodedField encodedField = appInfo.resolveField(field);
       if (encodedField == null) {
         // Must mark the field as targeted even if it does not exist.
-        markFieldAsTargeted(field);
+        markFieldAsTargeted(field, currentMethod);
         reportMissingField(field);
         return false;
       }
@@ -819,7 +824,7 @@
       if (encodedField.field != field) {
         // Mark the non-rebound field access as targeted. Note that this should only be done if the
         // field is not a dead proto field (in which case we bail-out above).
-        markFieldAsTargeted(field);
+        markFieldAsTargeted(field, currentMethod);
       }
 
       markStaticFieldAsLive(encodedField, KeepReason.fieldReferencedIn(currentMethod));
@@ -835,7 +840,7 @@
       DexEncodedField encodedField = appInfo.resolveField(field);
       if (encodedField == null) {
         // Must mark the field as targeted even if it does not exist.
-        markFieldAsTargeted(field);
+        markFieldAsTargeted(field, currentMethod);
         reportMissingField(field);
         return false;
       }
@@ -872,7 +877,7 @@
       if (encodedField.field != field) {
         // Mark the non-rebound field access as targeted. Note that this should only be done if the
         // field is not a dead proto field (in which case we bail-out above).
-        markFieldAsTargeted(field);
+        markFieldAsTargeted(field, currentMethod);
       }
 
       markStaticFieldAsLive(encodedField, KeepReason.fieldReferencedIn(currentMethod));
@@ -891,7 +896,7 @@
 
     @Override
     public boolean registerTypeReference(DexType type) {
-      markTypeAsLive(type);
+      markTypeAsLive(type, clazz -> graphReporter.reportClassReferencedFrom(clazz, currentMethod));
       return true;
     }
 
@@ -1024,7 +1029,8 @@
           markClassAsInstantiatedWithCompatRule(baseClass);
         } else {
           // This also handles reporting of missing classes.
-          markTypeAsLive(baseType);
+          markTypeAsLive(
+              baseType, clazz -> graphReporter.reportClassReferencedFrom(clazz, currentMethod));
         }
         return true;
       }
@@ -1068,14 +1074,9 @@
     return true;
   }
 
-  private void markTypeAsLive(DexType type) {
-    markTypeAsLive(
-        type, scopedMethodsForLiveTypes.computeIfAbsent(type, ignore -> new ScopedDexMethodSet()));
-  }
-
-  private void markTypeAsLive(DexType type, ScopedDexMethodSet seen) {
+  private void markTypeAsLive(DexType type, KeepReason reason) {
     if (type.isArrayType()) {
-      markTypeAsLive(type.toBaseType(appView.dexItemFactory()));
+      markTypeAsLive(type.toBaseType(appView.dexItemFactory()), reason);
       return;
     }
     if (!type.isClassType()) {
@@ -1083,16 +1084,55 @@
       return;
     }
     DexProgramClass holder = getProgramClassOrNull(type);
-    if (holder == null || !liveTypes.add(holder)) {
+    if (holder == null) {
+      return;
+    }
+    markTypeAsLive(
+        holder,
+        scopedMethodsForLiveTypes.computeIfAbsent(type, ignore -> new ScopedDexMethodSet()),
+        reason);
+  }
+
+  private void markTypeAsLive(DexType type, Function<DexProgramClass, KeepReason> reason) {
+    if (type.isArrayType()) {
+      markTypeAsLive(type.toBaseType(appView.dexItemFactory()), reason);
+      return;
+    }
+    if (!type.isClassType()) {
+      // Ignore primitive types.
+      return;
+    }
+    DexProgramClass holder = getProgramClassOrNull(type);
+    if (holder == null) {
+      return;
+    }
+    markTypeAsLive(
+        holder,
+        scopedMethodsForLiveTypes.computeIfAbsent(type, ignore -> new ScopedDexMethodSet()),
+        reason.apply(holder));
+  }
+
+  private void markTypeAsLive(DexProgramClass clazz, KeepReason reason) {
+    markTypeAsLive(
+        clazz,
+        scopedMethodsForLiveTypes.computeIfAbsent(clazz.type, ignore -> new ScopedDexMethodSet()),
+        reason);
+  }
+
+  private void markTypeAsLive(
+      DexProgramClass holder, ScopedDexMethodSet seen, KeepReason reasonForType) {
+    if (!liveTypes.add(holder, reasonForType)) {
       return;
     }
 
     if (Log.ENABLED) {
-      Log.verbose(getClass(), "Type `%s` has become live.", type);
+      Log.verbose(getClass(), "Type `%s` has become live.", holder.type);
     }
 
+    KeepReason reason = KeepReason.reachableFromLiveType(holder.type);
+
     for (DexType iface : holder.interfaces.values) {
-      markInterfaceTypeAsLiveViaInheritanceClause(iface);
+      markInterfaceTypeAsLiveViaInheritanceClause(iface, reason);
     }
 
     if (holder.superType != null) {
@@ -1100,15 +1140,9 @@
           scopedMethodsForLiveTypes.computeIfAbsent(
               holder.superType, ignore -> new ScopedDexMethodSet());
       seen.setParent(seenForSuper);
-      // TODO(b/120959039): The keep reason should be passed to markTypeAsLive.
-      DexClass holderSuper = appView.definitionFor(holder.superType);
-      if (holderSuper != null && holderSuper.isProgramClass()) {
-        registerType(holder.superType, KeepReason.reachableFromLiveType(type));
-      }
-      markTypeAsLive(holder.superType, seenForSuper);
+      markTypeAsLive(holder.superType, reason);
     }
 
-    KeepReason reason = KeepReason.reachableFromLiveType(type);
 
     // We cannot remove virtual methods defined earlier in the type hierarchy if it is widening
     // access and is defined in an interface:
@@ -1145,7 +1179,7 @@
       processAnnotations(holder, holder.annotations.annotations);
     }
     // If this type has deferred annotations, we have to process those now, too.
-    Set<DexAnnotation> annotations = deferredAnnotations.remove(type);
+    Set<DexAnnotation> annotations = deferredAnnotations.remove(holder.type);
     if (annotations != null && !annotations.isEmpty()) {
       assert holder.accessFlags.isAnnotation();
       assert annotations.stream().allMatch(a -> a.annotation.type == holder.type);
@@ -1175,7 +1209,7 @@
     }
   }
 
-  private void markInterfaceTypeAsLiveViaInheritanceClause(DexType type) {
+  private void markInterfaceTypeAsLiveViaInheritanceClause(DexType type, KeepReason reason) {
     if (appView.options().enableUnusedInterfaceRemoval && !mode.isTracingMainDex()) {
       DexProgramClass clazz = getProgramClassOrNull(type);
       if (clazz == null) {
@@ -1185,13 +1219,13 @@
       assert clazz.isInterface();
 
       if (!clazz.interfaces.isEmpty()) {
-        markTypeAsLive(type);
+        markTypeAsLive(type, reason);
         return;
       }
 
       for (DexEncodedMethod method : clazz.virtualMethods()) {
         if (!method.accessFlags.isAbstract()) {
-          markTypeAsLive(type);
+          markTypeAsLive(type, reason);
           return;
         }
       }
@@ -1200,7 +1234,7 @@
       // inheritance clause of another type, and the interface only has abstract methods, it can
       // simply be removed from the inheritance clause.
     } else {
-      markTypeAsLive(type);
+      markTypeAsLive(type, reason);
     }
   }
 
@@ -1232,9 +1266,10 @@
       }
       return;
     }
-    liveAnnotations.add(annotation, KeepReason.annotatedOn(holder));
+    KeepReason reason = KeepReason.annotatedOn(holder);
+    liveAnnotations.add(annotation, reason);
     AnnotationReferenceMarker referenceMarker =
-        new AnnotationReferenceMarker(annotation.annotation.type, appView.dexItemFactory());
+        new AnnotationReferenceMarker(annotation.annotation.type, appView.dexItemFactory(), reason);
     annotation.annotation.collectIndexedItems(referenceMarker);
   }
 
@@ -1363,7 +1398,8 @@
       // Already targeted.
       return;
     }
-    markTypeAsLive(method.method.holder);
+    markTypeAsLive(method.method.holder,
+        holder -> graphReporter.reportClassReferencedFrom(holder, method));
     markParameterAndReturnTypesAsLive(method);
     processAnnotations(method, method.annotations.annotations);
     method.parameterAnnotationsList.forEachAnnotation(
@@ -1404,7 +1440,7 @@
       Log.verbose(getClass(), "Class `%s` is instantiated, processing...", clazz);
     }
     // This class becomes live, so it and all its supertypes become live types.
-    markTypeAsLive(clazz.type);
+    markTypeAsLive(clazz, reason);
     // For all methods of the class, if we have seen a call, mark the method live.
     // We only do this for virtual calls, as the other ones will be done directly.
     transitionMethodsForInstantiatedClass(clazz);
@@ -1601,16 +1637,16 @@
         && !instantiatedTypes.contains(current.asProgramClass()));
   }
 
-  private void markFieldAsTargeted(DexField field) {
-    markTypeAsLive(field.type);
-    markTypeAsLive(field.holder);
+  private void markFieldAsTargeted(DexField field, DexEncodedMethod context) {
+    markTypeAsLive(field.type, clazz -> graphReporter.reportClassReferencedFrom(clazz, context));
+    markTypeAsLive(field.holder, clazz -> graphReporter.reportClassReferencedFrom(clazz, context));
   }
 
   private void markStaticFieldAsLive(DexEncodedField encodedField, KeepReason reason) {
     // Mark the type live here, so that the class exists at runtime.
     DexField field = encodedField.field;
-    markTypeAsLive(field.holder);
-    markTypeAsLive(field.type);
+    markTypeAsLive(field.holder, reason);
+    markTypeAsLive(field.type, reason);
 
     DexProgramClass clazz = getProgramClassOrNull(field.holder);
     if (clazz == null) {
@@ -1646,8 +1682,8 @@
   private void markInstanceFieldAsLive(DexEncodedField field, KeepReason reason) {
     assert field != null;
     assert field.isProgramField(appView);
-    markTypeAsLive(field.field.holder);
-    markTypeAsLive(field.field.type);
+    markTypeAsLive(field.field.holder, reason);
+    markTypeAsLive(field.field.type, reason);
     if (Log.ENABLED) {
       Log.verbose(getClass(), "Adding instance field `%s` to live set.", field.field);
     }
@@ -1744,8 +1780,8 @@
       Log.verbose(getClass(), "Marking instance field `%s` as reachable.", field);
     }
 
-    markTypeAsLive(field.holder);
-    markTypeAsLive(field.type);
+    markTypeAsLive(field.holder, reason);
+    markTypeAsLive(field.type, reason);
 
     DexProgramClass clazz = getProgramClassOrNull(field.holder);
     if (clazz == null) {
@@ -1792,7 +1828,7 @@
       // like an invoke on a direct subtype of java.lang.Object that has no further subtypes.
       // As it has no subtypes, it cannot affect liveness of the program we are processing.
       // Ergo, we can ignore it. We need to make sure that the element type is available, though.
-      markTypeAsLive(method.holder);
+      markTypeAsLive(method.holder, reason);
       return;
     }
     DexClass holder = appView.definitionFor(method.holder);
@@ -2036,7 +2072,7 @@
     enqueueRootItems(rootSet.noShrinking);
     trace(executorService, timing);
     options.reporter.failIfPendingErrors();
-    return Collections.unmodifiableSet(liveTypes);
+    return liveTypes.getItems();
   }
 
   public AppInfoWithLiveness traceApplication(
@@ -2069,7 +2105,7 @@
     AppInfoWithLiveness appInfoWithLiveness =
         new AppInfoWithLiveness(
             appInfo,
-            SetUtils.mapIdentityHashSet(liveTypes, DexProgramClass::getType),
+            SetUtils.mapIdentityHashSet(liveTypes.getItems(), DexProgramClass::getType),
             SetUtils.mapIdentityHashSet(
                 liveAnnotations.getItems(), DexAnnotation::getAnnotationType),
             Collections.unmodifiableSet(instantiatedAppServices),
@@ -2140,7 +2176,7 @@
     timing.begin("Grow the tree.");
     try {
       while (true) {
-        long numOfLiveItems = (long) liveTypes.size();
+        long numOfLiveItems = (long) liveTypes.items.size();
         numOfLiveItems += (long) liveMethods.items.size();
         numOfLiveItems += (long) liveFields.items.size();
         while (!workList.isEmpty()) {
@@ -2180,7 +2216,7 @@
         }
 
         // Continue fix-point processing if -if rules are enabled by items that newly became live.
-        long numOfLiveItemsAfterProcessing = (long) liveTypes.size();
+        long numOfLiveItemsAfterProcessing = (long) liveTypes.items.size();
         numOfLiveItemsAfterProcessing += (long) liveMethods.items.size();
         numOfLiveItemsAfterProcessing += (long) liveFields.items.size();
         if (numOfLiveItemsAfterProcessing > numOfLiveItems) {
@@ -2202,7 +2238,7 @@
                   activeIfRules,
                   liveFields.getItems(),
                   liveMethods.getItems(),
-                  liveTypes,
+                  liveTypes.getItems(),
                   mode,
                   consequentSetBuilder,
                   targetedMethods.getItems());
@@ -2261,7 +2297,7 @@
         Log.debug(getClass(), "%s methods are reachable but not live", reachableNotLive.size());
         Log.info(getClass(), "Only reachable: %s", reachableNotLive);
         Set<DexProgramClass> liveButNotInstantiated =
-            Sets.difference(liveTypes, instantiatedTypes.getItems());
+            Sets.difference(liveTypes.getItems(), instantiatedTypes.getItems());
         Log.debug(getClass(), "%s classes are live but not instantiated",
             liveButNotInstantiated.size());
         Log.info(getClass(), "Live but not instantiated: %s", liveButNotInstantiated);
@@ -2309,7 +2345,7 @@
           DexEncodedMethod implementation = target.getDefaultInterfaceMethodImplementation();
           if (implementation != null) {
             DexProgramClass companion = getProgramClassOrNull(implementation.method.holder);
-            markTypeAsLive(companion.type);
+            markTypeAsLive(companion, reason);
             markVirtualMethodAsLive(companion, implementation, reason);
           }
         }
@@ -2416,9 +2452,12 @@
 
   private void markParameterAndReturnTypesAsLive(DexEncodedMethod method) {
     for (DexType parameterType : method.method.proto.parameters.values) {
-      markTypeAsLive(parameterType);
+      markTypeAsLive(
+          parameterType, clazz -> graphReporter.reportClassReferencedFrom(clazz, method));
     }
-    markTypeAsLive(method.method.proto.returnType);
+    markTypeAsLive(
+        method.method.proto.returnType,
+        clazz -> graphReporter.reportClassReferencedFrom(clazz, method));
   }
 
   private void collectProguardCompatibilityRule(KeepReason reason) {
@@ -2813,7 +2852,7 @@
     }
 
     Set<T> getItems() {
-      return ImmutableSet.copyOf(items);
+      return Collections.unmodifiableSet(items);
     }
   }
 
@@ -2878,10 +2917,13 @@
 
     private final DexItem annotationHolder;
     private final DexItemFactory dexItemFactory;
+    private final KeepReason reason;
 
-    private AnnotationReferenceMarker(DexItem annotationHolder, DexItemFactory dexItemFactory) {
+    private AnnotationReferenceMarker(
+        DexItem annotationHolder, DexItemFactory dexItemFactory, KeepReason reason) {
       this.annotationHolder = annotationHolder;
       this.dexItemFactory = dexItemFactory;
+      this.reason = reason;
     }
 
     @Override
@@ -2971,7 +3013,7 @@
       // Annotations can also contain the void type, which is not a class type, so filter it out
       // here.
       if (type != dexItemFactory.voidType) {
-        markTypeAsLive(type);
+        markTypeAsLive(type, reason);
       }
       return false;
     }
@@ -3054,6 +3096,16 @@
       return KeepReasonWitness.INSTANCE;
     }
 
+    public KeepReasonWitness reportClassReferencedFrom(
+        DexProgramClass clazz, DexEncodedMethod method) {
+      if (keptGraphConsumer != null) {
+        MethodGraphNode source = getMethodGraphNode(method.method);
+        ClassGraphNode target = getClassGraphNode(clazz.type);
+        return reportEdge(source, target, EdgeKind.ReferencedFrom);
+      }
+      return KeepReasonWitness.INSTANCE;
+    }
+
     private KeepReasonWitness reportEdge(
         GraphNode source, GraphNode target, GraphEdgeInfo.EdgeKind kind) {
       assert keptGraphConsumer != null;
diff --git a/src/test/java/com/android/tools/r8/shaking/keptgraph/WhyAreYouKeepingAllTest.java b/src/test/java/com/android/tools/r8/shaking/keptgraph/WhyAreYouKeepingAllTest.java
index 4f0ba9e..f4ed50c 100644
--- a/src/test/java/com/android/tools/r8/shaking/keptgraph/WhyAreYouKeepingAllTest.java
+++ b/src/test/java/com/android/tools/r8/shaking/keptgraph/WhyAreYouKeepingAllTest.java
@@ -8,6 +8,8 @@
 import static org.hamcrest.MatcherAssert.assertThat;
 
 import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.ToolHelper;
 import com.android.tools.r8.utils.StringUtils;
 import java.io.ByteArrayOutputStream;
@@ -15,11 +17,15 @@
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
 
 /**
  * Run compiling R8 with R8 using a match-all -whyareyoukeeping rule to check that it does not cause
  * compilation to fail.
  */
+@RunWith(Parameterized.class)
 public class WhyAreYouKeepingAllTest extends TestBase {
 
   private static final Path MAIN_KEEP = Paths.get("src/main/keep.txt");
@@ -29,6 +35,17 @@
       "-whyareyoukeeping @interface ** { *; }"
   );
 
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withNoneRuntime().build();
+  }
+
+  final TestParameters parameters;
+
+  public WhyAreYouKeepingAllTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
   @Test
   public void test() throws Throwable {
     ByteArrayOutputStream baos = new ByteArrayOutputStream();