[TraceRef] Model covariant return types in tracereferences

Bug: b/271005787
Change-Id: I6f641fb63100e17eb41d685d8bc1f096fd32fd98
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index 7673d69..b46adc0 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -11,13 +11,12 @@
 import static com.android.tools.r8.naming.IdentifierNameStringUtils.identifyIdentifier;
 import static com.android.tools.r8.naming.IdentifierNameStringUtils.isReflectionMethod;
 import static com.android.tools.r8.shaking.KeepInfo.Joiner.asFieldJoinerOrNull;
+import static com.android.tools.r8.utils.CovariantReturnTypeUtils.modelLibraryMethodsWithCovariantReturnTypes;
 import static com.android.tools.r8.utils.FunctionUtils.ignoreArgument;
 import static com.android.tools.r8.utils.MapUtils.ignoreKey;
 import static java.util.Collections.emptySet;
 
 import com.android.tools.r8.Diagnostic;
-import com.android.tools.r8.androidapi.ComputedApiLevel;
-import com.android.tools.r8.androidapi.CovariantReturnTypeMethods;
 import com.android.tools.r8.cf.code.CfInstruction;
 import com.android.tools.r8.cf.code.CfInvoke;
 import com.android.tools.r8.contexts.CompilationContext.MethodProcessingContext;
@@ -75,7 +74,6 @@
 import com.android.tools.r8.graph.LookupMethodTarget;
 import com.android.tools.r8.graph.LookupResult;
 import com.android.tools.r8.graph.LookupTarget;
-import com.android.tools.r8.graph.MethodAccessFlags;
 import com.android.tools.r8.graph.MethodAccessInfoCollection;
 import com.android.tools.r8.graph.MethodResolutionResult;
 import com.android.tools.r8.graph.MethodResolutionResult.FailedResolutionResult;
@@ -3586,7 +3584,7 @@
     if (mode.isInitialTreeShaking()) {
       // Amend library methods with covariant return types.
       timing.begin("Model library");
-      modelLibraryMethodsWithCovariantReturnTypes();
+      modelLibraryMethodsWithCovariantReturnTypes(appView);
       timing.end();
     } else if (appView.getKeepInfo() != null) {
       timing.begin("Retain keep info");
@@ -3645,30 +3643,6 @@
             this::recordDependentMinimumKeepInfo);
   }
 
-  private void modelLibraryMethodsWithCovariantReturnTypes() {
-    CovariantReturnTypeMethods.registerMethodsWithCovariantReturnType(
-        appView.dexItemFactory(),
-        method -> {
-          DexLibraryClass libraryClass =
-              DexLibraryClass.asLibraryClassOrNull(
-                  appView.appInfo().definitionForWithoutExistenceAssert(method.getHolderType()));
-          if (libraryClass == null) {
-            return;
-          }
-          // Check if the covariant method exists on the class.
-          DexEncodedMethod covariantMethod = libraryClass.lookupMethod(method);
-          if (covariantMethod != null) {
-            return;
-          }
-          libraryClass.addVirtualMethod(
-              DexEncodedMethod.builder()
-                  .setMethod(method)
-                  .setAccessFlags(MethodAccessFlags.builder().setPublic().build())
-                  .setApiLevelForDefinition(ComputedApiLevel.notSet())
-                  .build());
-        });
-  }
-
   private void applyMinimumKeepInfo(DexProgramClass clazz) {
     EnqueuerEvent preconditionEvent = UnconditionalKeepInfoEvent.get();
     KeepClassInfo.Joiner minimumKeepInfoForClass =
diff --git a/src/main/java/com/android/tools/r8/tracereferences/TraceReferences.java b/src/main/java/com/android/tools/r8/tracereferences/TraceReferences.java
index 416a9a4..e1000f6 100644
--- a/src/main/java/com/android/tools/r8/tracereferences/TraceReferences.java
+++ b/src/main/java/com/android/tools/r8/tracereferences/TraceReferences.java
@@ -3,6 +3,8 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.tracereferences;
 
+import static com.android.tools.r8.utils.CovariantReturnTypeUtils.modelLibraryMethodsWithCovariantReturnTypes;
+
 import com.android.tools.r8.CompilationFailedException;
 import com.android.tools.r8.Keep;
 import com.android.tools.r8.ProgramResource;
@@ -11,8 +13,14 @@
 import com.android.tools.r8.ResourceException;
 import com.android.tools.r8.Version;
 import com.android.tools.r8.dex.ApplicationReader;
+import com.android.tools.r8.experimental.startup.StartupOrder;
+import com.android.tools.r8.features.ClassToFeatureSplitMap;
+import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
+import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.origin.CommandLineOrigin;
+import com.android.tools.r8.shaking.MainDexInfo;
+import com.android.tools.r8.synthesis.SyntheticItems.GlobalSyntheticsStrategy;
 import com.android.tools.r8.utils.AndroidApp;
 import com.android.tools.r8.utils.ExceptionUtils;
 import com.android.tools.r8.utils.InternalOptions;
@@ -73,7 +81,20 @@
     for (ProgramResourceProvider provider : command.getSource()) {
       forEachDescriptor(provider, targetDescriptors::remove);
     }
-    Tracer tracer = new Tracer(targetDescriptors, builder.build(), command.getReporter(), options);
+    AppView<AppInfoWithClassHierarchy> appView =
+        AppView.createForTracer(
+            AppInfoWithClassHierarchy.createInitialAppInfoWithClassHierarchy(
+                new ApplicationReader(builder.build(), options, Timing.empty()).read().toDirect(),
+                ClassToFeatureSplitMap.createEmptyClassToFeatureSplitMap(),
+                MainDexInfo.none(),
+                GlobalSyntheticsStrategy.forSingleOutputMode(),
+                StartupOrder.empty()));
+    modelLibraryMethodsWithCovariantReturnTypes(appView);
+    Tracer tracer =
+        new Tracer(
+            appView,
+            command.getReporter(),
+            type -> targetDescriptors.contains(type.toDescriptorString()));
     tracer.run(command.getConsumer());
   }
 
diff --git a/src/main/java/com/android/tools/r8/tracereferences/Tracer.java b/src/main/java/com/android/tools/r8/tracereferences/Tracer.java
index 97c4499..f182a0b 100644
--- a/src/main/java/com/android/tools/r8/tracereferences/Tracer.java
+++ b/src/main/java/com/android/tools/r8/tracereferences/Tracer.java
@@ -4,11 +4,8 @@
 package com.android.tools.r8.tracereferences;
 
 import com.android.tools.r8.DiagnosticsHandler;
-import com.android.tools.r8.dex.ApplicationReader;
 import com.android.tools.r8.diagnostic.DefinitionContext;
 import com.android.tools.r8.diagnostic.internal.DefinitionContextUtils;
-import com.android.tools.r8.experimental.startup.StartupOrder;
-import com.android.tools.r8.features.ClassToFeatureSplitMap;
 import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.ClassResolutionResult;
@@ -41,17 +38,11 @@
 import com.android.tools.r8.references.FieldReference;
 import com.android.tools.r8.references.MethodReference;
 import com.android.tools.r8.references.Reference;
-import com.android.tools.r8.shaking.MainDexInfo;
-import com.android.tools.r8.synthesis.SyntheticItems.GlobalSyntheticsStrategy;
 import com.android.tools.r8.tracereferences.TraceReferencesConsumer.TracedReference;
 import com.android.tools.r8.tracereferences.internal.TracedClassImpl;
 import com.android.tools.r8.tracereferences.internal.TracedFieldImpl;
 import com.android.tools.r8.tracereferences.internal.TracedMethodImpl;
-import com.android.tools.r8.utils.AndroidApp;
 import com.android.tools.r8.utils.BooleanBox;
-import com.android.tools.r8.utils.InternalOptions;
-import com.android.tools.r8.utils.Timing;
-import java.io.IOException;
 import java.util.HashSet;
 import java.util.Set;
 import java.util.function.Function;
@@ -63,24 +54,6 @@
   private final DiagnosticsHandler diagnostics;
   private final Predicate<DexType> targetPredicate;
 
-  Tracer(
-      Set<String> targetDescriptors,
-      AndroidApp inputApp,
-      DiagnosticsHandler diagnostics,
-      InternalOptions options)
-      throws IOException {
-    this(
-        AppView.createForTracer(
-            AppInfoWithClassHierarchy.createInitialAppInfoWithClassHierarchy(
-                new ApplicationReader(inputApp, options, Timing.empty()).read().toDirect(),
-                ClassToFeatureSplitMap.createEmptyClassToFeatureSplitMap(),
-                MainDexInfo.none(),
-                GlobalSyntheticsStrategy.forSingleOutputMode(),
-                StartupOrder.empty())),
-        diagnostics,
-        type -> targetDescriptors.contains(type.toDescriptorString()));
-  }
-
   public Tracer(
       AppView<? extends AppInfoWithClassHierarchy> appView,
       DiagnosticsHandler diagnostics,
diff --git a/src/main/java/com/android/tools/r8/utils/CovariantReturnTypeUtils.java b/src/main/java/com/android/tools/r8/utils/CovariantReturnTypeUtils.java
new file mode 100644
index 0000000..202132f
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/utils/CovariantReturnTypeUtils.java
@@ -0,0 +1,39 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.utils;
+
+import com.android.tools.r8.androidapi.ComputedApiLevel;
+import com.android.tools.r8.androidapi.CovariantReturnTypeMethods;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexLibraryClass;
+import com.android.tools.r8.graph.MethodAccessFlags;
+
+public class CovariantReturnTypeUtils {
+
+  public static void modelLibraryMethodsWithCovariantReturnTypes(AppView<?> appView) {
+    CovariantReturnTypeMethods.registerMethodsWithCovariantReturnType(
+        appView.dexItemFactory(),
+        method -> {
+          DexLibraryClass libraryClass =
+              DexLibraryClass.asLibraryClassOrNull(
+                  appView.appInfo().definitionForWithoutExistenceAssert(method.getHolderType()));
+          if (libraryClass == null) {
+            return;
+          }
+          // Check if the covariant method exists on the class.
+          DexEncodedMethod covariantMethod = libraryClass.lookupMethod(method);
+          if (covariantMethod != null) {
+            return;
+          }
+          libraryClass.addVirtualMethod(
+              DexEncodedMethod.builder()
+                  .setMethod(method)
+                  .setAccessFlags(MethodAccessFlags.builder().setPublic().build())
+                  .setApiLevelForDefinition(ComputedApiLevel.notSet())
+                  .build());
+        });
+  }
+}