Towards supporting @CovariantReturnType in R8

Bug: b/211362069
Change-Id: I5cc4cf91ae903c911e8a7d371b1800c54ef2aced
diff --git a/src/main/java/com/android/tools/r8/desugar/covariantreturntype/CovariantReturnTypeAnnotationTransformer.java b/src/main/java/com/android/tools/r8/desugar/covariantreturntype/CovariantReturnTypeAnnotationTransformer.java
index 55f4d78..ea193a5 100644
--- a/src/main/java/com/android/tools/r8/desugar/covariantreturntype/CovariantReturnTypeAnnotationTransformer.java
+++ b/src/main/java/com/android/tools/r8/desugar/covariantreturntype/CovariantReturnTypeAnnotationTransformer.java
@@ -4,6 +4,7 @@
 package com.android.tools.r8.desugar.covariantreturntype;
 
 import com.android.tools.r8.errors.CompilationError;
+import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexAnnotation;
 import com.android.tools.r8.graph.DexAnnotationElement;
@@ -24,9 +25,13 @@
 import com.android.tools.r8.ir.conversion.MethodProcessorEventConsumer;
 import com.android.tools.r8.ir.synthetic.ForwardMethodBuilder;
 import com.android.tools.r8.utils.ForEachable;
+import com.android.tools.r8.utils.ListUtils;
+import com.android.tools.r8.utils.ThreadUtils;
 import java.util.ArrayList;
+import java.util.Comparator;
 import java.util.LinkedHashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
@@ -59,6 +64,11 @@
   private final DexItemFactory factory;
   private final CovariantReturnTypeReferences references;
 
+  public CovariantReturnTypeAnnotationTransformer(
+      AppView<? extends AppInfoWithClassHierarchy> appView) {
+    this(appView, null);
+  }
+
   private CovariantReturnTypeAnnotationTransformer(AppView<?> appView, IRConverter converter) {
     this.appView = appView;
     this.converter = converter;
@@ -104,6 +114,25 @@
         executorService);
   }
 
+  public void processMethods(
+      Map<DexProgramClass, List<ProgramMethod>> methodsToProcess,
+      CovariantReturnTypeAnnotationTransformerEventConsumer eventConsumer,
+      ExecutorService executorService)
+      throws ExecutionException {
+    if (methodsToProcess.isEmpty()) {
+      return;
+    }
+    ThreadUtils.processMap(
+        methodsToProcess,
+        (clazz, methods) -> {
+          List<ProgramMethod> sortedMethods =
+              ListUtils.destructiveSort(methods, Comparator.comparing(ProgramMethod::getReference));
+          processClass(clazz, sortedMethods::forEach, eventConsumer);
+        },
+        appView.options().getThreadingModule(),
+        executorService);
+  }
+
   // Processes all the dalvik.annotation.codegen.CovariantReturnType and dalvik.annotation.codegen.
   // CovariantReturnTypes annotations in the given DexClass. Adds the newly constructed, synthetic
   // methods to the list covariantReturnTypeMethods.
@@ -208,6 +237,14 @@
     return covariantReturnTypes;
   }
 
+  public boolean hasCovariantReturnTypeAnnotation(ProgramMethod method) {
+    return method
+        .getAnnotations()
+        .hasAnnotation(
+            annotation ->
+                references.isOneOfCovariantReturnTypeAnnotations(annotation.getAnnotationType()));
+  }
+
   private void getCovariantReturnTypesFromAnnotation(
       ProgramMethod method, DexEncodedAnnotation annotation, Set<DexType> covariantReturnTypes) {
     boolean hasPresentAfterElement = false;
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index 5ca4dbe..bf8be36 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -22,9 +22,11 @@
 import com.android.tools.r8.cf.code.CfInvoke;
 import com.android.tools.r8.contexts.CompilationContext.MethodProcessingContext;
 import com.android.tools.r8.contexts.CompilationContext.ProcessorContext;
+import com.android.tools.r8.desugar.covariantreturntype.CovariantReturnTypeAnnotationTransformer;
 import com.android.tools.r8.dex.IndexedItemCollection;
 import com.android.tools.r8.dex.code.CfOrDexInstruction;
 import com.android.tools.r8.errors.InterfaceDesugarMissingTypeDiagnostic;
+import com.android.tools.r8.errors.Unimplemented;
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.experimental.graphinfo.GraphConsumer;
 import com.android.tools.r8.features.IsolatedFeatureSplitsChecker;
@@ -478,6 +480,10 @@
   private final CfInstructionDesugaringCollection desugaring;
   private final ProgramMethodSet pendingCodeDesugaring = ProgramMethodSet.create();
 
+  private final CovariantReturnTypeAnnotationTransformer covariantReturnTypeAnnotationTransformer;
+  private final Map<DexProgramClass, List<ProgramMethod>> pendingCovariantReturnTypeDesugaring =
+      new IdentityHashMap<>();
+
   // Collections for tracing progress on interface method desugaring.
 
   // The pending method move set is all the methods that need to be moved to companions.
@@ -551,9 +557,12 @@
     if (mode.isInitialTreeShaking()) {
       desugaring = CfInstructionDesugaringCollection.create(appView, appView.apiLevelCompute());
       interfaceProcessor = InterfaceProcessor.create(appView);
+      covariantReturnTypeAnnotationTransformer =
+          new CovariantReturnTypeAnnotationTransformer(appView);
     } else {
       desugaring = CfInstructionDesugaringCollection.empty();
       interfaceProcessor = null;
+      covariantReturnTypeAnnotationTransformer = null;
     }
 
     objectAllocationInfoCollection =
@@ -4303,6 +4312,7 @@
     // registered first and no dependencies may exist among them.
     SyntheticAdditions additions = new SyntheticAdditions(appView.createProcessorContext());
     desugar(additions);
+    processCovariantReturnTypeAnnotations();
     synthesizeInterfaceMethodBridges();
     if (additions.isEmpty()) {
       return;
@@ -4325,6 +4335,11 @@
   }
 
   private boolean addToPendingDesugaring(ProgramMethod method) {
+    if (covariantReturnTypeAnnotationTransformer.hasCovariantReturnTypeAnnotation(method)) {
+      pendingCovariantReturnTypeDesugaring
+          .computeIfAbsent(method.getHolder(), ignoreKey(ArrayList::new))
+          .add(method);
+    }
     if (options.isInterfaceMethodDesugaringEnabled()) {
       if (mustMoveToInterfaceCompanionMethod(method)) {
         // TODO(b/199043500): Once "live moved methods" are tracked this can avoid the code check.
@@ -4459,6 +4474,38 @@
     }
   }
 
+  private void processCovariantReturnTypeAnnotations() throws ExecutionException {
+    covariantReturnTypeAnnotationTransformer.processMethods(
+        pendingCovariantReturnTypeDesugaring,
+        (bridge, target) -> {
+          KeepMethodInfo.Joiner bridgeKeepInfo = getKeepInfoForCovariantReturnTypeBridge(target);
+          keepInfo.registerCompilerSynthesizedMethod(bridge);
+          applyMinimumKeepInfoWhenLiveOrTargeted(bridge, bridgeKeepInfo);
+          profileCollectionAdditions.addMethodIfContextIsInProfile(bridge, target);
+        },
+        executorService);
+    pendingCovariantReturnTypeDesugaring.clear();
+  }
+
+  private KeepMethodInfo.Joiner getKeepInfoForCovariantReturnTypeBridge(ProgramMethod target) {
+    KeepInfo.Joiner<?, ?, ?> targetKeepInfo =
+        appView
+            .rootSet()
+            .getDependentMinimumKeepInfo()
+            .getUnconditionalMinimumKeepInfoOrDefault(MinimumKeepInfoCollection.empty())
+            .getOrDefault(target.getReference(), null);
+    if (targetKeepInfo == null) {
+      targetKeepInfo = KeepMethodInfo.newEmptyJoiner();
+    }
+    if ((options.isMinifying() && targetKeepInfo.isMinificationAllowed())
+        || (options.isOptimizing() && targetKeepInfo.isOptimizationAllowed())
+        || (options.isShrinking() && targetKeepInfo.isShrinkingAllowed())) {
+      // TODO(b/211362069): Report a fatal diagnostic explaining the problem.
+      throw new Unimplemented();
+    }
+    return targetKeepInfo.asMethodJoiner();
+  }
+
   private void synthesizeInterfaceMethodBridges() {
     for (InterfaceMethodSyntheticBridgeAction action : syntheticInterfaceMethodBridges.values()) {
       ProgramMethod bridge = action.getMethodToKeep();
diff --git a/src/main/java/com/android/tools/r8/shaking/KeepInfo.java b/src/main/java/com/android/tools/r8/shaking/KeepInfo.java
index 23be490..d37185a 100644
--- a/src/main/java/com/android/tools/r8/shaking/KeepInfo.java
+++ b/src/main/java/com/android/tools/r8/shaking/KeepInfo.java
@@ -544,6 +544,10 @@
       return builder.isCheckDiscardedEnabled();
     }
 
+    public boolean isMinificationAllowed() {
+      return builder.isMinificationAllowed();
+    }
+
     public boolean isOptimizationAllowed() {
       return builder.isOptimizationAllowed();
     }
diff --git a/src/test/java/com/android/tools/r8/ir/desugar/annotations/CovariantReturnTypeAnnotationTransformerR8Test.java b/src/test/java/com/android/tools/r8/ir/desugar/annotations/CovariantReturnTypeAnnotationTransformerR8Test.java
new file mode 100644
index 0000000..e8ff92c
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/desugar/annotations/CovariantReturnTypeAnnotationTransformerR8Test.java
@@ -0,0 +1,154 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.desugar.annotations;
+
+import static com.android.tools.r8.utils.codeinspector.AssertUtils.assertFailsCompilation;
+
+import com.android.tools.r8.R8FullTestBuilder;
+import com.android.tools.r8.R8TestCompileResult;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.ThrowableConsumer;
+import com.android.tools.r8.ToolHelper;
+import com.android.tools.r8.ir.desugar.annotations.CovariantReturnType.CovariantReturnTypes;
+import com.google.common.collect.ImmutableMap;
+import java.util.Map;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class CovariantReturnTypeAnnotationTransformerR8Test extends TestBase {
+
+  private static final String covariantReturnTypeDescriptor =
+      "Ldalvik/annotation/codegen/CovariantReturnType;";
+  private static final String covariantReturnTypesDescriptor =
+      "Ldalvik/annotation/codegen/CovariantReturnType$CovariantReturnTypes;";
+
+  private static final Map<String, String> descriptorTransformation =
+      ImmutableMap.of(
+          descriptor(com.android.tools.r8.ir.desugar.annotations.version2.B.class),
+          descriptor(B.class),
+          descriptor(com.android.tools.r8.ir.desugar.annotations.version2.C.class),
+          descriptor(C.class),
+          descriptor(com.android.tools.r8.ir.desugar.annotations.version2.E.class),
+          descriptor(E.class),
+          descriptor(com.android.tools.r8.ir.desugar.annotations.version2.F.class),
+          descriptor(F.class),
+          descriptor(CovariantReturnType.class),
+          covariantReturnTypeDescriptor,
+          descriptor(CovariantReturnTypes.class),
+          covariantReturnTypesDescriptor);
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withDexRuntimes().withMaximumApiLevel().build();
+  }
+
+  @Test
+  public void testDontObfuscateDontOptimizeDontShrink() throws Exception {
+    R8TestCompileResult r8CompileResult =
+        compileWithR8(
+            testBuilder -> testBuilder.addDontObfuscate().addDontOptimize().addDontShrink());
+    testOnRuntime(r8CompileResult);
+  }
+
+  @Test
+  public void testUnconditionalKeepAllPublicMethods() throws Exception {
+    R8TestCompileResult r8CompileResult =
+        compileWithR8(
+            testBuilder -> testBuilder.addKeepRules("-keep public class * { public <methods>; }"));
+    testOnRuntime(r8CompileResult);
+  }
+
+  @Test
+  public void testUnconditionalKeepAllPublicMethodsAllowObfuscation() throws Exception {
+    assertFailsCompilation(
+        () ->
+            compileWithR8(
+                testBuilder ->
+                    testBuilder.addKeepRules(
+                        "-keep,allowobfuscation public class * { public <methods>; }")));
+  }
+
+  @Test
+  public void testUnconditionalKeepAllPublicMethodsAllowOptimization() throws Exception {
+    assertFailsCompilation(
+        () ->
+            compileWithR8(
+                testBuilder ->
+                    testBuilder.addKeepRules(
+                        "-keep,allowoptimization public class * { public <methods>; }")));
+  }
+
+  @Test
+  public void testConditionalKeepAllPublicMethods() throws Exception {
+    assertFailsCompilation(
+        () ->
+            compileWithR8(
+                testBuilder ->
+                    testBuilder.addKeepRules(
+                        "-if public class * -keep class <1> { public <methods>; }",
+                        "-keep public class *")));
+  }
+
+  private R8TestCompileResult compileWithR8(
+      ThrowableConsumer<? super R8FullTestBuilder> configuration) throws Exception {
+    return testForR8(parameters.getBackend())
+        .addProgramClasses(A.class, D.class)
+        .addProgramClassFileData(
+            transformer(com.android.tools.r8.ir.desugar.annotations.version2.B.class)
+                .replaceClassDescriptorInAnnotations(descriptorTransformation)
+                .replaceClassDescriptorInMethodInstructions(descriptorTransformation)
+                .setClassDescriptor(descriptor(B.class))
+                .transform(),
+            transformer(com.android.tools.r8.ir.desugar.annotations.version2.C.class)
+                .replaceClassDescriptorInAnnotations(descriptorTransformation)
+                .replaceClassDescriptorInMethodInstructions(descriptorTransformation)
+                .setClassDescriptor(descriptor(C.class))
+                .transform(),
+            transformer(com.android.tools.r8.ir.desugar.annotations.version2.E.class)
+                .replaceClassDescriptorInAnnotations(descriptorTransformation)
+                .replaceClassDescriptorInMethodInstructions(descriptorTransformation)
+                .setClassDescriptor(descriptor(E.class))
+                .transform(),
+            transformer(com.android.tools.r8.ir.desugar.annotations.version2.F.class)
+                .replaceClassDescriptorInAnnotations(descriptorTransformation)
+                .replaceClassDescriptorInMethodInstructions(descriptorTransformation)
+                .setClassDescriptor(descriptor(F.class))
+                .transform(),
+            transformer(CovariantReturnType.class)
+                .replaceClassDescriptorInAnnotations(descriptorTransformation)
+                .setClassDescriptor(covariantReturnTypeDescriptor)
+                .transform(),
+            transformer(CovariantReturnTypes.class)
+                .replaceClassDescriptorInAnnotations(descriptorTransformation)
+                .replaceClassDescriptorInMembers(
+                    descriptor(CovariantReturnType.class), covariantReturnTypeDescriptor)
+                .setClassDescriptor(covariantReturnTypesDescriptor)
+                .transform())
+        .addLibraryFiles(ToolHelper.getMostRecentAndroidJar())
+        .addOptionsModification(options -> options.processCovariantReturnTypeAnnotations = true)
+        .apply(configuration)
+        .setMinApi(parameters)
+        .compile();
+  }
+
+  private void testOnRuntime(R8TestCompileResult r8CompileResult) throws Exception {
+    testForD8()
+        .addProgramClasses(Client.class)
+        .addClasspathClasses(A.class, B.class, C.class, D.class, E.class, F.class)
+        .setMinApi(parameters)
+        .compile()
+        .addRunClasspathFiles(r8CompileResult.writeToZip())
+        .run(parameters.getRuntime(), Client.class)
+        .assertSuccessWithOutputLines("a=A", "b=B", "c=C", "d=F", "e=F", "f=F");
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/ir/desugar/annotations/version2/C.java b/src/test/java/com/android/tools/r8/ir/desugar/annotations/version2/C.java
index 52fe17f..058e13e 100644
--- a/src/test/java/com/android/tools/r8/ir/desugar/annotations/version2/C.java
+++ b/src/test/java/com/android/tools/r8/ir/desugar/annotations/version2/C.java
@@ -5,6 +5,7 @@
 package com.android.tools.r8.ir.desugar.annotations.version2;
 
 import com.android.tools.r8.ir.desugar.annotations.A;
+import com.android.tools.r8.ir.desugar.annotations.B;
 import com.android.tools.r8.ir.desugar.annotations.CovariantReturnType;
 
 public class C extends B {
diff --git a/src/test/java/com/android/tools/r8/ir/desugar/annotations/version2/F.java b/src/test/java/com/android/tools/r8/ir/desugar/annotations/version2/F.java
index a133826..5423b42 100644
--- a/src/test/java/com/android/tools/r8/ir/desugar/annotations/version2/F.java
+++ b/src/test/java/com/android/tools/r8/ir/desugar/annotations/version2/F.java
@@ -6,6 +6,7 @@
 
 import com.android.tools.r8.ir.desugar.annotations.CovariantReturnType;
 import com.android.tools.r8.ir.desugar.annotations.D;
+import com.android.tools.r8.ir.desugar.annotations.E;
 
 public class F extends E {
   @CovariantReturnType(returnType = E.class, presentAfter = 25)
diff --git a/src/test/java/com/android/tools/r8/repackage/RepackageMissingTypeCollisionTest.java b/src/test/java/com/android/tools/r8/repackage/RepackageMissingTypeCollisionTest.java
index e1f0cb0..fff87b6 100644
--- a/src/test/java/com/android/tools/r8/repackage/RepackageMissingTypeCollisionTest.java
+++ b/src/test/java/com/android/tools/r8/repackage/RepackageMissingTypeCollisionTest.java
@@ -43,7 +43,7 @@
             transformer(A.class).setClassDescriptor(newADescriptor).transform(),
             transformer(Anno.class)
                 .replaceClassDescriptorInMembers(descriptor(Missing.class), newMissingDescriptor)
-                .replaceClassDescriptorInAnnotationDefault(
+                .replaceClassDescriptorInAnnotations(
                     descriptor(Missing.class), newMissingDescriptor)
                 .transform(),
             transformer(Main.class)
@@ -96,7 +96,7 @@
             transformer(A.class).setClassDescriptor(newADescriptor).transform(),
             transformer(Anno.class)
                 .replaceClassDescriptorInMembers(descriptor(Missing.class), newMissingDescriptor)
-                .replaceClassDescriptorInAnnotationDefault(
+                .replaceClassDescriptorInAnnotations(
                     descriptor(Missing.class), newMissingDescriptor)
                 .transform(),
             transformer(Main.class)
diff --git a/src/test/testbase/java/com/android/tools/r8/transformers/ClassFileTransformer.java b/src/test/testbase/java/com/android/tools/r8/transformers/ClassFileTransformer.java
index 72bc22a..69bd9b2 100644
--- a/src/test/testbase/java/com/android/tools/r8/transformers/ClassFileTransformer.java
+++ b/src/test/testbase/java/com/android/tools/r8/transformers/ClassFileTransformer.java
@@ -1089,27 +1089,67 @@
         });
   }
 
-  public ClassFileTransformer replaceClassDescriptorInAnnotationDefault(
+  public ClassFileTransformer replaceClassDescriptorInAnnotations(
       String oldDescriptor, String newDescriptor) {
-    return addMethodTransformer(
-        new MethodTransformer() {
+    return replaceClassDescriptorInAnnotations(ImmutableMap.of(oldDescriptor, newDescriptor));
+  }
 
-          @Override
-          public AnnotationVisitor visitAnnotationDefault() {
-            return new AnnotationVisitor(ASM_VERSION, super.visitAnnotationDefault()) {
-              @Override
-              public void visit(String name, Object value) {
-                super.visit(name, value);
-              }
+  public ClassFileTransformer replaceClassDescriptorInAnnotations(Map<String, String> map) {
+    class AnnotationTransformer extends AnnotationVisitor {
 
-              @Override
-              public void visitEnum(String name, String descriptor, String value) {
-                super.visitEnum(
-                    name, descriptor.equals(oldDescriptor) ? newDescriptor : descriptor, value);
-              }
-            };
+      protected AnnotationTransformer(AnnotationVisitor annotationVisitor) {
+        super(ASM_VERSION, annotationVisitor);
+      }
+
+      @Override
+      public AnnotationVisitor visitAnnotation(String name, String descriptor) {
+        return new AnnotationTransformer(
+            super.visitAnnotation(name, map.getOrDefault(descriptor, descriptor)));
+      }
+
+      @Override
+      public AnnotationVisitor visitArray(String name) {
+        return new AnnotationTransformer(super.visitArray(name));
+      }
+
+      @Override
+      public void visitEnum(String name, String descriptor, String value) {
+        super.visitEnum(name, map.getOrDefault(descriptor, descriptor), value);
+      }
+
+      @Override
+      public void visit(String name, Object value) {
+        if (value instanceof Type) {
+          Type type = (Type) value;
+          if (map.containsKey(type.getDescriptor())) {
+            value = Type.getType(map.get(type.getDescriptor()));
           }
-        });
+        }
+        super.visit(name, value);
+      }
+    }
+    return addClassTransformer(
+            new ClassTransformer() {
+              @Override
+              public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
+                return new AnnotationTransformer(
+                    super.visitAnnotation(map.getOrDefault(descriptor, descriptor), visible));
+              }
+            })
+        .addMethodTransformer(
+            new MethodTransformer() {
+
+              @Override
+              public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
+                return new AnnotationTransformer(
+                    super.visitAnnotation(map.getOrDefault(descriptor, descriptor), visible));
+              }
+
+              @Override
+              public AnnotationVisitor visitAnnotationDefault() {
+                return new AnnotationTransformer(super.visitAnnotationDefault());
+              }
+            });
   }
 
   public ClassFileTransformer replaceClassDescriptorInMethodInstructions(