Processing of @CovariantReturnType annotations

This CL is responsible for processing the annotations dalvik.annotation.codegen.CovariantReturnType and dalvik.annotation.codegen.CovariantReturnType$CovariantReturnTypes. If such an annotation is attached to a method m, then the compiler should insert a new synthetic method that is equivalent to method m, but has the return type specified by the annotation.

Bug: 78618808
Change-Id: Id4ea6c1e3f5c97a57954b1ededca473a022736a3
diff --git a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
index 86ca12a..4e4a78f 100644
--- a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
+++ b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
@@ -230,6 +230,10 @@
   public final DexType annotationThrows = createType("Ldalvik/annotation/Throws;");
   public final DexType annotationSynthesizedClassMap =
       createType("Lcom/android/tools/r8/annotations/SynthesizedClassMap;");
+  public final DexType annotationCovariantReturnType =
+      createType("Ldalvik/annotation/codegen/CovariantReturnType;");
+  public final DexType annotationCovariantReturnTypes =
+      createType("Ldalvik/annotation/codegen/CovariantReturnType$CovariantReturnTypes;");
 
   private static final String METAFACTORY_METHOD_NAME = "metafactory";
   private static final String METAFACTORY_ALT_METHOD_NAME = "altMetafactory";
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index fc5e95b..fa00cd1 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -6,6 +6,7 @@
 import static com.android.tools.r8.ir.desugar.InterfaceMethodRewriter.Flavor.ExcludeDexResources;
 import static com.android.tools.r8.ir.desugar.InterfaceMethodRewriter.Flavor.IncludeAllResources;
 
+import com.android.tools.r8.ApiLevelException;
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.AppInfo;
 import com.android.tools.r8.graph.AppInfoWithSubtyping;
@@ -32,6 +33,7 @@
 import com.android.tools.r8.ir.code.InstructionListIterator;
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.code.ValueType;
+import com.android.tools.r8.ir.desugar.CovariantReturnTypeAnnotationTransformer;
 import com.android.tools.r8.ir.desugar.InterfaceMethodRewriter;
 import com.android.tools.r8.ir.desugar.LambdaRewriter;
 import com.android.tools.r8.ir.desugar.StringConcatRewriter;
@@ -99,6 +101,7 @@
   private final ProtoLitePruner protoLiteRewriter;
   private final IdentifierNameStringMarker identifierNameStringMarker;
   private final Devirtualizer devirtualizer;
+  private final CovariantReturnTypeAnnotationTransformer covariantReturnTypeAnnotationTransformer;
 
   private final OptimizationFeedback ignoreOptimizationFeedback = new OptimizationFeedbackIgnore();
   private DexString highestSortingString;
@@ -126,6 +129,10 @@
             ? new InterfaceMethodRewriter(this, options) : null;
     this.lambdaMerger = options.enableLambdaMerging
         ? new LambdaMerger(appInfo.dexItemFactory, options.reporter) : null;
+    this.covariantReturnTypeAnnotationTransformer =
+        options.processCovariantReturnTypeAnnotations
+            ? new CovariantReturnTypeAnnotationTransformer(this, appInfo.dexItemFactory)
+            : null;
     if (enableWholeProgramOptimizations) {
       assert appInfo.hasLiveness();
       this.nonNullTracker = new NonNullTracker();
@@ -246,6 +253,12 @@
     }
   }
 
+  private void processCovariantReturnTypeAnnotations(Builder<?> builder) throws ApiLevelException {
+    if (covariantReturnTypeAnnotationTransformer != null) {
+      covariantReturnTypeAnnotationTransformer.process(builder);
+    }
+  }
+
   public DexApplication convertToDex(DexApplication application, ExecutorService executor)
       throws ExecutionException {
     removeLambdaDeserializationMethods();
@@ -259,6 +272,7 @@
 
     synthesizeLambdaClasses(builder);
     desugarInterfaceMethods(builder, ExcludeDexResources);
+    processCovariantReturnTypeAnnotations(builder);
 
     handleSynthesizedClassMapping(builder);
     timing.end();
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/CovariantReturnTypeAnnotationTransformer.java b/src/main/java/com/android/tools/r8/ir/desugar/CovariantReturnTypeAnnotationTransformer.java
new file mode 100644
index 0000000..64fe450
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/desugar/CovariantReturnTypeAnnotationTransformer.java
@@ -0,0 +1,269 @@
+// Copyright (c) 2018, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.desugar;
+
+import com.android.tools.r8.ApiLevelException;
+import com.android.tools.r8.errors.CompilationError;
+import com.android.tools.r8.graph.DexAnnotation;
+import com.android.tools.r8.graph.DexAnnotationElement;
+import com.android.tools.r8.graph.DexApplication;
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexEncodedAnnotation;
+import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexProto;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.DexValue;
+import com.android.tools.r8.graph.MethodAccessFlags;
+import com.android.tools.r8.ir.code.Invoke;
+import com.android.tools.r8.ir.conversion.IRConverter;
+import com.android.tools.r8.ir.synthetic.ForwardMethodSourceCode;
+import com.android.tools.r8.ir.synthetic.SynthesizedCode;
+import com.google.common.base.Predicates;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Set;
+
+// Responsible for processing the annotations dalvik.annotation.codegen.CovariantReturnType and
+// dalvik.annotation.codegen.CovariantReturnType$CovariantReturnTypes.
+//
+// Consider the following class:
+//   public class B extends A {
+//     @CovariantReturnType(returnType = B.class, presentAfter = 25)
+//     @Override
+//     public A m(...) { ... return new B(); }
+//   }
+//
+// The annotation is used to indicate that the compiler should insert a synthetic method that is
+// equivalent to method m, but has return type B instead of A. Thus, for this example, this
+// component is responsible for inserting the following method in class B (in addition to the
+// existing method m):
+//   public B m(...) { A result = "invoke B.m(...)A;"; return (B) result; }
+//
+// Note that a method may be annotated with more than one CovariantReturnType annotation. In this
+// case there will be a CovariantReturnType$CovariantReturnTypes annotation on the method that wraps
+// several CovariantReturnType annotations. In this case, a new method is synthesized for each of
+// the contained CovariantReturnType annotations.
+public final class CovariantReturnTypeAnnotationTransformer {
+  private final IRConverter converter;
+  private final DexItemFactory factory;
+
+  public CovariantReturnTypeAnnotationTransformer(IRConverter converter, DexItemFactory factory) {
+    this.converter = converter;
+    this.factory = factory;
+  }
+
+  public void process(DexApplication.Builder<?> builder) throws ApiLevelException {
+    // List of methods that should be added to the next class.
+    List<DexEncodedMethod> methodsWithCovariantReturnTypeAnnotation = new LinkedList<>();
+    List<DexEncodedMethod> covariantReturnTypeMethods = new LinkedList<>();
+    for (DexClass clazz : builder.getProgramClasses()) {
+      // Construct the methods that should be added to clazz.
+      buildCovariantReturnTypeMethodsForClass(
+          clazz, methodsWithCovariantReturnTypeAnnotation, covariantReturnTypeMethods);
+      if (covariantReturnTypeMethods.isEmpty()) {
+        continue;
+      }
+      updateClass(clazz, methodsWithCovariantReturnTypeAnnotation, covariantReturnTypeMethods);
+      // Reset lists for the next class that will have a CovariantReturnType or
+      // CovariantReturnType$CovariantReturnTypes annotation.
+      methodsWithCovariantReturnTypeAnnotation.clear();
+      covariantReturnTypeMethods.clear();
+    }
+  }
+
+  private void updateClass(
+      DexClass clazz,
+      List<DexEncodedMethod> methodsWithCovariantReturnTypeAnnotation,
+      List<DexEncodedMethod> covariantReturnTypeMethods) {
+    // It is a compilation error if the class already has a method with a signature similar to one
+    // of the methods in covariantReturnTypeMethods.
+    for (DexEncodedMethod syntheticMethod : covariantReturnTypeMethods) {
+      if (hasVirtualMethodWithSignature(clazz, syntheticMethod)) {
+        throw new CompilationError(
+            String.format(
+                "Cannot process CovariantReturnType annotation: Class %s already "
+                    + "has a method \"%s\"",
+                clazz.getType(), syntheticMethod.toSourceString()));
+      }
+    }
+    // Remove the CovariantReturnType annotations.
+    for (DexEncodedMethod method : methodsWithCovariantReturnTypeAnnotation) {
+      method.annotations =
+          method.annotations.keepIf(x -> !isCovariantReturnTypeAnnotation(x.annotation));
+    }
+    // Add the newly constructed methods to the class.
+    DexEncodedMethod[] oldVirtualMethods = clazz.virtualMethods();
+    DexEncodedMethod[] newVirtualMethods =
+        new DexEncodedMethod[oldVirtualMethods.length + covariantReturnTypeMethods.size()];
+    System.arraycopy(oldVirtualMethods, 0, newVirtualMethods, 0, oldVirtualMethods.length);
+    int i = oldVirtualMethods.length;
+    for (DexEncodedMethod syntheticMethod : covariantReturnTypeMethods) {
+      newVirtualMethods[i] = syntheticMethod;
+      i++;
+    }
+    clazz.setVirtualMethods(newVirtualMethods);
+  }
+
+  // Processes all the dalvik.annotation.codegen.CovariantReturnType and dalvik.annotation.codegen.
+  // CovariantReturnTypes annotations in the given DexClass. Adds the newly constructed, synthetic
+  // methods to the list covariantReturnTypeMethods.
+  private void buildCovariantReturnTypeMethodsForClass(
+      DexClass clazz,
+      List<DexEncodedMethod> methodsWithCovariantReturnTypeAnnotation,
+      List<DexEncodedMethod> covariantReturnTypeMethods)
+      throws ApiLevelException {
+    for (DexEncodedMethod method : clazz.virtualMethods()) {
+      if (methodHasCovariantReturnTypeAnnotation(method)) {
+        methodsWithCovariantReturnTypeAnnotation.add(method);
+        buildCovariantReturnTypeMethodsForMethod(clazz, method, covariantReturnTypeMethods);
+      }
+    }
+  }
+
+  private boolean methodHasCovariantReturnTypeAnnotation(DexEncodedMethod method) {
+    for (DexAnnotation annotation : method.annotations.annotations) {
+      if (isCovariantReturnTypeAnnotation(annotation.annotation)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  // Processes all the dalvik.annotation.codegen.CovariantReturnType and dalvik.annotation.Co-
+  // variantReturnTypes annotations on the given method. Adds the newly constructed, synthetic
+  // methods to the list covariantReturnTypeMethods.
+  private void buildCovariantReturnTypeMethodsForMethod(
+      DexClass clazz, DexEncodedMethod method, List<DexEncodedMethod> covariantReturnTypeMethods)
+      throws ApiLevelException {
+    assert methodHasCovariantReturnTypeAnnotation(method);
+    for (DexType covariantReturnType : getCovariantReturnTypes(clazz, method)) {
+      DexEncodedMethod covariantReturnTypeMethod =
+          buildCovariantReturnTypeMethod(clazz, method, covariantReturnType);
+      covariantReturnTypeMethods.add(covariantReturnTypeMethod);
+    }
+  }
+
+  // Builds a synthetic method that invokes the given method, casts the result to
+  // covariantReturnType, and then returns the result. The newly created method will have return
+  // type covariantReturnType.
+  //
+  // Note: any "synchronized" or "strictfp" modifier could be dropped safely.
+  private DexEncodedMethod buildCovariantReturnTypeMethod(
+      DexClass clazz, DexEncodedMethod method, DexType covariantReturnType)
+      throws ApiLevelException {
+    DexProto newProto =
+        factory.createProto(
+            covariantReturnType, method.method.proto.shorty, method.method.proto.parameters);
+    MethodAccessFlags newAccessFlags = method.accessFlags.copy();
+    newAccessFlags.setSynthetic();
+    DexEncodedMethod newVirtualMethod =
+        new DexEncodedMethod(
+            factory.createMethod(method.method.holder, newProto, method.method.name),
+            newAccessFlags,
+            method.annotations.keepIf(x -> !isCovariantReturnTypeAnnotation(x.annotation)),
+            method.parameterAnnotationsList.keepIf(Predicates.alwaysTrue()),
+            new SynthesizedCode(
+                new ForwardMethodSourceCode(
+                    clazz.type,
+                    newProto,
+                    method.method.holder,
+                    method.method,
+                    Invoke.Type.VIRTUAL,
+                    true)));
+    // Optimize to generate DexCode instead of SynthesizedCode.
+    converter.optimizeSynthesizedMethod(newVirtualMethod);
+    return newVirtualMethod;
+  }
+
+  // Returns the set of covariant return types for method.
+  //
+  // If the method is:
+  //   @dalvik.annotation.codegen.CovariantReturnType(returnType=SubOfFoo, presentAfter=25)
+  //   @dalvik.annotation.codegen.CovariantReturnType(returnType=SubOfSubOfFoo, presentAfter=28)
+  //   @Override
+  //   public Foo foo() { ... return new SubOfSubOfFoo(); }
+  // then this method returns the set { SubOfFoo, SubOfSubOfFoo }.
+  private Set<DexType> getCovariantReturnTypes(DexClass clazz, DexEncodedMethod method) {
+    Set<DexType> covariantReturnTypes = new HashSet<>();
+    for (DexAnnotation annotation : method.annotations.annotations) {
+      if (isCovariantReturnTypeAnnotation(annotation.annotation)) {
+        getCovariantReturnTypesFromAnnotation(
+            clazz, method, annotation.annotation, covariantReturnTypes);
+      }
+    }
+    return covariantReturnTypes;
+  }
+
+  private void getCovariantReturnTypesFromAnnotation(
+      DexClass clazz,
+      DexEncodedMethod method,
+      DexEncodedAnnotation annotation,
+      Set<DexType> covariantReturnTypes) {
+    assert isCovariantReturnTypeAnnotation(annotation);
+    boolean hasPresentAfterElement = false;
+    for (DexAnnotationElement element : annotation.elements) {
+      String name = element.name.toString();
+      if (annotation.type == factory.annotationCovariantReturnType) {
+        if (name.equals("returnType")) {
+          if (!(element.value instanceof DexValue.DexValueType)) {
+            throw new CompilationError(
+                String.format(
+                    "Expected element \"returnType\" of CovariantReturnType annotation to "
+                        + "reference a type (method: \"%s\", was: %s)",
+                    method.toSourceString(), element.value.getClass().getCanonicalName()));
+          }
+
+          DexValue.DexValueType dexValueType = (DexValue.DexValueType) element.value;
+          covariantReturnTypes.add(dexValueType.value);
+        } else if (name.equals("presentAfter")) {
+          hasPresentAfterElement = true;
+        }
+      } else {
+        if (name.equals("value")) {
+          if (!(element.value instanceof DexValue.DexValueArray)) {
+            throw new CompilationError(
+                String.format(
+                    "Expected element \"value\" of CovariantReturnTypes annotation to "
+                        + "be an array (method: \"%s\", was: %s)",
+                    method.toSourceString(), element.value.getClass().getCanonicalName()));
+          }
+
+          DexValue.DexValueArray array = (DexValue.DexValueArray) element.value;
+          // Handle the inner dalvik.annotation.codegen.CovariantReturnType annotations recursively.
+          for (DexValue value : array.getValues()) {
+            assert value instanceof DexValue.DexValueAnnotation;
+            DexValue.DexValueAnnotation innerAnnotation = (DexValue.DexValueAnnotation) value;
+            getCovariantReturnTypesFromAnnotation(
+                clazz, method, innerAnnotation.value, covariantReturnTypes);
+          }
+        }
+      }
+    }
+
+    if (annotation.type == factory.annotationCovariantReturnType && !hasPresentAfterElement) {
+      throw new CompilationError(
+          String.format(
+              "CovariantReturnType annotation for method \"%s\" is missing mandatory element "
+                  + "\"presentAfter\" (class %s)",
+              clazz.getType(), method.toSourceString()));
+    }
+  }
+
+  private boolean isCovariantReturnTypeAnnotation(DexEncodedAnnotation annotation) {
+    return annotation.type == factory.annotationCovariantReturnType
+        || annotation.type == factory.annotationCovariantReturnTypes;
+  }
+
+  private static boolean hasVirtualMethodWithSignature(DexClass clazz, DexEncodedMethod method) {
+    for (DexEncodedMethod existingMethod : clazz.virtualMethods()) {
+      if (existingMethod.method.equals(method.method)) {
+        return true;
+      }
+    }
+    return false;
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/synthetic/ForwardMethodSourceCode.java b/src/main/java/com/android/tools/r8/ir/synthetic/ForwardMethodSourceCode.java
index 9259485..0ba49d3 100644
--- a/src/main/java/com/android/tools/r8/ir/synthetic/ForwardMethodSourceCode.java
+++ b/src/main/java/com/android/tools/r8/ir/synthetic/ForwardMethodSourceCode.java
@@ -21,15 +21,27 @@
   private final DexType targetReceiver;
   private final DexMethod target;
   private final Invoke.Type invokeType;
+  private final boolean castResult;
 
   public ForwardMethodSourceCode(DexType receiver, DexProto proto,
       DexType targetReceiver, DexMethod target, Invoke.Type invokeType) {
+    this(receiver, proto, targetReceiver, target, invokeType, false);
+  }
+
+  public ForwardMethodSourceCode(
+      DexType receiver,
+      DexProto proto,
+      DexType targetReceiver,
+      DexMethod target,
+      Invoke.Type invokeType,
+      boolean castResult) {
     super(receiver, proto);
     assert (targetReceiver == null) == (invokeType == Invoke.Type.STATIC);
 
     this.target = target;
     this.targetReceiver = targetReceiver;
     this.invokeType = invokeType;
+    this.castResult = castResult;
     assert checkSignatures();
 
     switch (invokeType) {
@@ -67,7 +79,7 @@
       assert (source.isClassType() && target.isClassType()) || source == target;
     }
 
-    assert this.proto.returnType == target.proto.returnType;
+    assert this.proto.returnType == target.proto.returnType || castResult;
     return true;
   }
 
@@ -99,6 +111,9 @@
       ValueType valueType = ValueType.fromDexType(proto.returnType);
       int tempValue = nextRegister(valueType);
       add(builder -> builder.addMoveResult(tempValue));
+      if (this.proto.returnType != target.proto.returnType) {
+        add(builder -> builder.addCheckCast(tempValue, this.proto.returnType));
+      }
       add(builder -> builder.addReturn(valueType, tempValue));
     }
   }
diff --git a/src/test/java/com/android/tools/r8/ir/desugar/annotations/CovariantReturnTypeAnnotationTransformerTest.java b/src/test/java/com/android/tools/r8/ir/desugar/annotations/CovariantReturnTypeAnnotationTransformerTest.java
index 9c58d25..1f4d2ba 100644
--- a/src/test/java/com/android/tools/r8/ir/desugar/annotations/CovariantReturnTypeAnnotationTransformerTest.java
+++ b/src/test/java/com/android/tools/r8/ir/desugar/annotations/CovariantReturnTypeAnnotationTransformerTest.java
@@ -14,7 +14,6 @@
 import com.android.tools.r8.utils.DexInspector;
 import java.util.Collections;
 import org.junit.Assert;
-import org.junit.Ignore;
 import org.junit.Test;
 
 public class CovariantReturnTypeAnnotationTransformerTest extends AsmTestBase {
@@ -49,7 +48,6 @@
     failsIndependentOfFlag(input);
   }
 
-  @Ignore("b/78618808")
   @Test
   public void testVersion2WithClient1And2() throws Exception {
     AndroidApp input =
@@ -63,7 +61,6 @@
     succeedsIndependentOfFlag(input, true);
   }
 
-  @Ignore("b/78618808")
   @Test
   public void testVersion2WithClient3() throws Exception {
     AndroidApp input =
@@ -108,6 +105,23 @@
     succeedsIndependentOfFlag(input, false);
   }
 
+  @Test
+  public void testRepeatedCompilation() throws Exception {
+    AndroidApp input =
+        buildAndroidApp(
+            ToolHelper.getClassAsBytes(Client.class),
+            ToolHelper.getClassAsBytes(A.class),
+            com.android.tools.r8.ir.desugar.annotations.version2.BDump.dump(),
+            com.android.tools.r8.ir.desugar.annotations.version2.CDump.dump());
+
+    AndroidApp output =
+        compileWithD8(input, options -> options.processCovariantReturnTypeAnnotations = true);
+
+    // Compilation will fail with a compilation error the second time if the implementation does
+    // not remove the CovariantReturnType annotations properly during the first compilation.
+    compileWithD8(output, options -> options.processCovariantReturnTypeAnnotations = true);
+  }
+
   private void succeedsWithOption(
       AndroidApp input, boolean option, boolean checkPresenceOfSyntheticMethods) throws Exception {
     AndroidApp output =