Mark program types referenced from failed resolutions as live

Bug: 214079809
Change-Id: I801aac3b819bc3b27a8fd8909f155ac668f0c451
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index 68dba3e..2cb4329 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -608,16 +608,17 @@
   }
 
   private void recordTypeReference(DexType type, ProgramDefinition context) {
-    recordTypeReference(type, context, this::reportMissingClass);
+    recordTypeReference(type, context, this::recordNonProgramClass, this::reportMissingClass);
   }
 
   private void recordTypeReference(DexType type, ProgramDerivedContext context) {
-    recordTypeReference(type, context, this::reportMissingClass);
+    recordTypeReference(type, context, this::recordNonProgramClass, this::reportMissingClass);
   }
 
   private void recordTypeReference(
       DexType type,
       ProgramDerivedContext context,
+      BiConsumer<DexClass, ProgramDerivedContext> foundClassConsumer,
       BiConsumer<DexType, ProgramDerivedContext> missingClassConsumer) {
     if (type == null) {
       return;
@@ -629,21 +630,22 @@
       return;
     }
     // Lookup the definition, ignoring the result. This populates the missing and referenced sets.
-    definitionFor(type, context, missingClassConsumer);
+    definitionFor(type, context, foundClassConsumer, missingClassConsumer);
   }
 
   private void recordMethodReference(DexMethod method, ProgramDerivedContext context) {
-    recordMethodReference(method, context, this::reportMissingClass);
+    recordMethodReference(method, context, this::recordNonProgramClass, this::reportMissingClass);
   }
 
   private void recordMethodReference(
       DexMethod method,
       ProgramDerivedContext context,
+      BiConsumer<DexClass, ProgramDerivedContext> foundClassConsumer,
       BiConsumer<DexType, ProgramDerivedContext> missingClassConsumer) {
-    recordTypeReference(method.holder, context, missingClassConsumer);
-    recordTypeReference(method.proto.returnType, context, missingClassConsumer);
+    recordTypeReference(method.holder, context, foundClassConsumer, missingClassConsumer);
+    recordTypeReference(method.proto.returnType, context, foundClassConsumer, missingClassConsumer);
     for (DexType type : method.proto.parameters.values) {
-      recordTypeReference(type, context, missingClassConsumer);
+      recordTypeReference(type, context, foundClassConsumer, missingClassConsumer);
     }
   }
 
@@ -661,31 +663,28 @@
   }
 
   public DexClass definitionFor(DexType type, ProgramDefinition context) {
-    return definitionFor(type, context, this::reportMissingClass);
+    return definitionFor(type, context, this::recordNonProgramClass, this::reportMissingClass);
   }
 
   private DexClass definitionFor(
       DexType type,
       ProgramDerivedContext context,
+      BiConsumer<DexClass, ProgramDerivedContext> foundClassConsumer,
       BiConsumer<DexType, ProgramDerivedContext> missingClassConsumer) {
-    return internalDefinitionFor(type, context, missingClassConsumer);
+    return internalDefinitionFor(type, context, foundClassConsumer, missingClassConsumer);
   }
 
   private DexClass internalDefinitionFor(
       DexType type,
       ProgramDerivedContext context,
+      BiConsumer<DexClass, ProgramDerivedContext> foundClassConsumer,
       BiConsumer<DexType, ProgramDerivedContext> missingClassConsumer) {
     DexClass clazz = appInfo().definitionFor(type);
     if (clazz == null) {
       missingClassConsumer.accept(type, context);
       return null;
     }
-    if (clazz.isNotProgramClass()) {
-      addLiveNonProgramType(
-          clazz.asClasspathOrLibraryClass(),
-          (missingType, derivedContext) ->
-              reportMissingClass(missingType, derivedContext.asProgramDerivedContext(context)));
-    }
+    foundClassConsumer.accept(clazz, context);
     return clazz;
   }
 
@@ -772,7 +771,8 @@
   private DexProgramClass getProgramClassOrNullFromReflectiveAccess(
       DexType type, ProgramDefinition context) {
     // To avoid that we report reflectively accessed types as missing.
-    DexClass clazz = definitionFor(type, context, this::ignoreMissingClass);
+    DexClass clazz =
+        definitionFor(type, context, this::recordNonProgramClass, this::ignoreMissingClass);
     return clazz != null && clazz.isProgramClass() ? clazz.asProgramClass() : null;
   }
 
@@ -1805,8 +1805,16 @@
               ? this::reportMissingClass
               : this::ignoreMissingClass;
       for (InnerClassAttribute innerClassAttribute : clazz.getInnerClasses()) {
-        recordTypeReference(innerClassAttribute.getInner(), clazz, missingClassConsumer);
-        recordTypeReference(innerClassAttribute.getOuter(), clazz, missingClassConsumer);
+        recordTypeReference(
+            innerClassAttribute.getInner(),
+            clazz,
+            this::recordNonProgramClass,
+            missingClassConsumer);
+        recordTypeReference(
+            innerClassAttribute.getOuter(),
+            clazz,
+            this::recordNonProgramClass,
+            missingClassConsumer);
       }
     }
 
@@ -1828,10 +1836,12 @@
               ? this::reportMissingClass
               : this::ignoreMissingClass;
       if (enclosingMethod != null) {
-        recordMethodReference(enclosingMethod, clazz, missingClassConsumer);
+        recordMethodReference(
+            enclosingMethod, clazz, this::recordNonProgramClass, missingClassConsumer);
       } else {
         DexType enclosingClass = enclosingMethodAttribute.getEnclosingClass();
-        recordTypeReference(enclosingClass, clazz, missingClassConsumer);
+        recordTypeReference(
+            enclosingClass, clazz, this::recordNonProgramClass, missingClassConsumer);
       }
     }
 
@@ -2072,11 +2082,13 @@
   private SingleResolutionResult resolveMethod(
       DexMethod method, ProgramDefinition context, KeepReason reason) {
     // Record the references in case they are not program types.
-    recordMethodReference(method, context);
     MethodResolutionResult resolutionResult = appInfo.unsafeResolveMethodDueToDexFormat(method);
     if (resolutionResult.isFailedResolution()) {
       markFailedMethodResolutionTargets(
           method, resolutionResult.asFailedResolution(), context, reason);
+      recordMethodReference(method, context, this::recordFoundClass, this::reportMissingClass);
+    } else {
+      recordMethodReference(method, context);
     }
     return resolutionResult.asSingleResolution();
   }
@@ -2092,7 +2104,7 @@
       assert resolutionResult.isFailedResolution();
       markFailedMethodResolutionTargets(
           method, resolutionResult.asFailedResolution(), context, reason);
-      recordMethodReference(method, context);
+      recordMethodReference(method, context, this::recordFoundClass, this::reportMissingClass);
     }
     return resolutionResult.asSingleResolution();
   }
@@ -2327,6 +2339,25 @@
         });
   }
 
+  private void recordFoundClass(DexClass clazz, ProgramDerivedContext context) {
+    if (clazz.isProgramClass()) {
+      if (context.isProgramContext()) {
+        markTypeAsLive(clazz, context.getContext().asProgramDefinition());
+      }
+    } else {
+      recordNonProgramClass(clazz, context);
+    }
+  }
+
+  private void recordNonProgramClass(DexClass clazz, ProgramDerivedContext context) {
+    if (!clazz.isProgramClass()) {
+      addLiveNonProgramType(
+          clazz.asClasspathOrLibraryClass(),
+          (missingType, derivedContext) ->
+              reportMissingClass(missingType, derivedContext.asProgramDerivedContext(context)));
+    }
+  }
+
   private void ignoreMissingClass(DexType clazz) {
     missingClassesBuilder.ignoreNewMissingClass(clazz);
   }
@@ -2968,19 +2999,18 @@
       return;
     }
 
-    // Note that all virtual methods derived from library methods are kept regardless of being
-    // reachable, so the following only needs to consider reachable targets in the program.
-    // TODO(b/70160030): Revise this to support tree shaking library methods on non-escaping types.
-    DexProgramClass holder = getProgramClassOrNull(method.holder, context);
-    if (holder == null) {
-      // TODO(b/139464956): clean this.
-      // Ensure that the full proto of the targeted method is referenced.
-      recordMethodReference(method, context);
+    SingleResolutionResult resolution = resolveMethod(method, context, reason, interfaceInvoke);
+    if (resolution == null) {
       return;
     }
 
-    SingleResolutionResult resolution = resolveMethod(method, context, reason, interfaceInvoke);
-    if (resolution == null) {
+    // Note that all virtual methods derived from library methods are kept regardless of being
+    // reachable, so the following only needs to consider reachable targets in the program.
+    // TODO(b/70160030): Revise this to support tree shaking library methods on non-escaping types.
+    DexProgramClass initialResolutionHolder =
+        resolution.getInitialResolutionHolder().asProgramClass();
+    if (initialResolutionHolder == null) {
+      recordMethodReference(method, context);
       return;
     }
 
@@ -2996,7 +3026,8 @@
     // If the method has already been marked, just report the new reason for the resolved target and
     // save the context to ensure correct lookup of virtual dispatch targets.
     ResolutionSearchKey resolutionSearchKey = new ResolutionSearchKey(method, interfaceInvoke);
-    ProgramMethodSet seenContexts = getReachableVirtualTargets(holder).get(resolutionSearchKey);
+    ProgramMethodSet seenContexts =
+        getReachableVirtualTargets(initialResolutionHolder).get(resolutionSearchKey);
     if (seenContexts != null) {
       seenContexts.add(context);
       graphReporter.registerMethod(resolution.getResolvedMethod(), reason);
@@ -3019,7 +3050,7 @@
 
     // The method resolved and is accessible, so currently live overrides become live.
     reachableVirtualTargets
-        .computeIfAbsent(holder, ignoreArgument(HashMap::new))
+        .computeIfAbsent(initialResolutionHolder, ignoreArgument(HashMap::new))
         .computeIfAbsent(resolutionSearchKey, ignoreArgument(ProgramMethodSet::create))
         .add(context);
 
@@ -4987,7 +5018,8 @@
     }
 
     public DexClass definitionFor(DexType type, ProgramDefinition context) {
-      return enqueuer.definitionFor(type, context, enqueuer::ignoreMissingClass);
+      return enqueuer.definitionFor(
+          type, context, enqueuer::recordNonProgramClass, enqueuer::ignoreMissingClass);
     }
   }