Add cast for covariant return types in synthetic lambda bridges

This CL inserts a cast if the main method's return type is different
from the bridge return type, unless it is casting to Object. Such
bootstrap methods seems to be generated by the Scala compiler for
automatic boxing.

Bug: 140267001
Change-Id: I8e4a7d5353b3ad46591c5380f51eec9b7df4f2b9
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/LambdaBridgeMethodSourceCode.java b/src/main/java/com/android/tools/r8/ir/desugar/LambdaBridgeMethodSourceCode.java
index 9507471..f0eb914 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/LambdaBridgeMethodSourceCode.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/LambdaBridgeMethodSourceCode.java
@@ -63,6 +63,11 @@
       ValueType valueType = ValueType.fromDexType(proto.returnType);
       int tempValue = nextRegister(valueType);
       add(builder -> builder.addMoveResult(tempValue));
+      // We lack precise sub-type information, but there should not be a need to cast to object.
+      if (proto.returnType != mainMethod.proto.returnType
+          && proto.returnType != factory().objectType) {
+        add(builder -> builder.addCheckCast(tempValue, proto.returnType));
+      }
       add(builder -> builder.addReturn(tempValue));
     }
   }
diff --git a/src/test/java/com/android/tools/r8/desugar/bridge/LambdaReturnTypeBridgeTest.java b/src/test/java/com/android/tools/r8/desugar/bridge/LambdaReturnTypeBridgeTest.java
new file mode 100644
index 0000000..2dbcc30
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/desugar/bridge/LambdaReturnTypeBridgeTest.java
@@ -0,0 +1,83 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.desugar.bridge;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import com.android.tools.r8.CompilationFailedException;
+import com.android.tools.r8.D8TestBuilder;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.FoundClassSubject;
+import com.android.tools.r8.utils.codeinspector.FoundMethodSubject;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import java.io.IOException;
+import java.util.concurrent.ExecutionException;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class LambdaReturnTypeBridgeTest extends TestBase {
+
+  private final TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withDexRuntimes().build();
+  }
+
+  public LambdaReturnTypeBridgeTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void testReturnTypeLambda()
+      throws IOException, CompilationFailedException, ExecutionException {
+    runTest(testForD8().addProgramClasses(LambdaWithMultipleImplementingInterfaces.class), false);
+  }
+
+  @Test
+  public void testCovariantReturnTypeLambda()
+      throws IOException, CompilationFailedException, ExecutionException {
+    runTest(
+        testForD8()
+            .addProgramClassFileData(LambdaWithMultipleImplementingInterfacesCovariantDump.dump()),
+        true);
+  }
+
+  private void runTest(D8TestBuilder builder, boolean shouldHaveCheckCast)
+      throws IOException, CompilationFailedException, ExecutionException {
+    builder
+        .addInnerClasses(LambdaWithMultipleImplementingInterfaces.class)
+        .setMinApi(parameters.getRuntime())
+        .run(parameters.getRuntime(), LambdaWithMultipleImplementingInterfaces.class)
+        .assertSuccessWithOutputLines("Hello World!", "Hello World!")
+        .inspect(
+            codeInspector -> {
+              boolean foundBridge = false;
+              for (FoundClassSubject clazz : codeInspector.allClasses()) {
+                if (clazz
+                    .getOriginalName()
+                    .contains(
+                        "-$$Lambda$"
+                            + LambdaWithMultipleImplementingInterfaces.class.getSimpleName()
+                            + "$")) {
+                  // Find bridge method and check whether or not it has a cast.
+                  for (FoundMethodSubject bridge : clazz.allMethods(FoundMethodSubject::isBridge)) {
+                    foundBridge = true;
+                    assertEquals(
+                        shouldHaveCheckCast,
+                        bridge.streamInstructions().anyMatch(InstructionSubject::isCheckCast));
+                  }
+                }
+              }
+              assertTrue(foundBridge);
+            });
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces.java b/src/test/java/com/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces.java
new file mode 100644
index 0000000..53f543a
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces.java
@@ -0,0 +1,32 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.desugar.bridge;
+
+public class LambdaWithMultipleImplementingInterfaces {
+
+  public interface I {
+    Object get();
+  }
+
+  public interface J {
+    String get();
+  }
+
+  public interface K extends I, J {}
+
+  public static void main(String[] args) {
+    K k = () -> "Hello World!";
+    testI(k);
+    testJ(k);
+  }
+
+  private static void testI(I i) {
+    System.out.println(i.get());
+  }
+
+  private static void testJ(J j) {
+    System.out.println(j.get());
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfacesCovariantDump.java b/src/test/java/com/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfacesCovariantDump.java
new file mode 100644
index 0000000..a00539e
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfacesCovariantDump.java
@@ -0,0 +1,263 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.desugar.bridge;
+
+import org.objectweb.asm.ClassWriter;
+import org.objectweb.asm.Handle;
+import org.objectweb.asm.Label;
+import org.objectweb.asm.MethodVisitor;
+import org.objectweb.asm.Opcodes;
+import org.objectweb.asm.Type;
+
+public class LambdaWithMultipleImplementingInterfacesCovariantDump implements Opcodes {
+
+  // Generated from LambdaWithMultipleImplementingInterfaces with the bootstrapmethod changed from:
+  //   0: #47 invokestatic java/lang/invoke/LambdaMetafactory.altMetafactory:...
+  //     Method arguments:
+  //       #48 ()Ljava/lang/String;
+  //       #49 invokestatic com/android/tools/r8/desugar/bridge/
+  //          LambdaWithMultipleImplementingInterfaces.lambda$main$0:()Ljava/lang/String;
+  //       #48 ()Ljava/lang/String;
+  //       #50 4
+  //       #51 1
+  //       #52 ()Ljava/lang/Object;
+  // to
+  //   0: #47 invokestatic java/lang/invoke/LambdaMetafactory.altMetafactory:...
+  //     Method arguments:
+  //       #48 ()Ljava/lang/Object;
+  //       #49 invokestatic com/android/tools/r8/desugar/bridge/
+  //          LambdaWithMultipleImplementingInterfaces.lambda$main$0:()Ljava/lang/String;
+  //       #48 ()Ljava/lang/String;
+  //       #50 4
+  //       #51 1
+  //       #52 ()Ljava/lang/String;
+  // This will need to create a bridge method that takes in an Object and returns a String, which
+  // is seen in class-files generated by the Scala-compiler.
+  public static byte[] dump() {
+
+    ClassWriter classWriter = new ClassWriter(0);
+    MethodVisitor methodVisitor;
+
+    classWriter.visit(
+        V1_8,
+        ACC_PUBLIC | ACC_SUPER,
+        "com/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces",
+        null,
+        "java/lang/Object",
+        null);
+
+    classWriter.visitSource("LambdaWithMultipleImplementingInterfaces.java", null);
+
+    classWriter.visitInnerClass(
+        "com/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces$K",
+        "com/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces",
+        "K",
+        ACC_PUBLIC | ACC_STATIC | ACC_ABSTRACT | ACC_INTERFACE);
+
+    classWriter.visitInnerClass(
+        "com/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces$J",
+        "com/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces",
+        "J",
+        ACC_PUBLIC | ACC_STATIC | ACC_ABSTRACT | ACC_INTERFACE);
+
+    classWriter.visitInnerClass(
+        "com/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces$I",
+        "com/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces",
+        "I",
+        ACC_PUBLIC | ACC_STATIC | ACC_ABSTRACT | ACC_INTERFACE);
+
+    classWriter.visitInnerClass(
+        "java/lang/invoke/MethodHandles$Lookup",
+        "java/lang/invoke/MethodHandles",
+        "Lookup",
+        ACC_PUBLIC | ACC_FINAL | ACC_STATIC);
+
+    {
+      methodVisitor = classWriter.visitMethod(ACC_PUBLIC, "<init>", "()V", null, null);
+      methodVisitor.visitCode();
+      Label label0 = new Label();
+      methodVisitor.visitLabel(label0);
+      methodVisitor.visitLineNumber(7, label0);
+      methodVisitor.visitVarInsn(ALOAD, 0);
+      methodVisitor.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V", false);
+      methodVisitor.visitInsn(RETURN);
+      Label label1 = new Label();
+      methodVisitor.visitLabel(label1);
+      methodVisitor.visitLocalVariable(
+          "this",
+          "Lcom/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces;",
+          null,
+          label0,
+          label1,
+          0);
+      methodVisitor.visitMaxs(1, 1);
+      methodVisitor.visitEnd();
+    }
+    {
+      methodVisitor =
+          classWriter.visitMethod(
+              ACC_PUBLIC | ACC_STATIC, "main", "([Ljava/lang/String;)V", null, null);
+      methodVisitor.visitCode();
+      Label label0 = new Label();
+      methodVisitor.visitLabel(label0);
+      methodVisitor.visitLineNumber(22, label0);
+      methodVisitor.visitInvokeDynamicInsn(
+          "get",
+          "()Lcom/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces$K;",
+          new Handle(
+              Opcodes.H_INVOKESTATIC,
+              "java/lang/invoke/LambdaMetafactory",
+              "altMetafactory",
+              "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;",
+              false),
+          new Object[] {
+            Type.getType("()Ljava/lang/Object;"),
+            new Handle(
+                Opcodes.H_INVOKESTATIC,
+                "com/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces",
+                "lambda$main$0",
+                "()Ljava/lang/String;",
+                false),
+            Type.getType("()Ljava/lang/String;"),
+            new Integer(4),
+            new Integer(1),
+            Type.getType("()Ljava/lang/String;")
+          });
+      methodVisitor.visitVarInsn(ASTORE, 1);
+      Label label1 = new Label();
+      methodVisitor.visitLabel(label1);
+      methodVisitor.visitLineNumber(23, label1);
+      methodVisitor.visitVarInsn(ALOAD, 1);
+      methodVisitor.visitMethodInsn(
+          INVOKESTATIC,
+          "com/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces",
+          "testI",
+          "(Lcom/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces$I;)V",
+          false);
+      Label label2 = new Label();
+      methodVisitor.visitLabel(label2);
+      methodVisitor.visitLineNumber(24, label2);
+      methodVisitor.visitVarInsn(ALOAD, 1);
+      methodVisitor.visitMethodInsn(
+          INVOKESTATIC,
+          "com/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces",
+          "testJ",
+          "(Lcom/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces$J;)V",
+          false);
+      Label label3 = new Label();
+      methodVisitor.visitLabel(label3);
+      methodVisitor.visitLineNumber(25, label3);
+      methodVisitor.visitInsn(RETURN);
+      Label label4 = new Label();
+      methodVisitor.visitLabel(label4);
+      methodVisitor.visitLocalVariable("args", "[Ljava/lang/String;", null, label0, label4, 0);
+      methodVisitor.visitLocalVariable(
+          "k",
+          "Lcom/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces$K;",
+          null,
+          label1,
+          label4,
+          1);
+      methodVisitor.visitMaxs(1, 2);
+      methodVisitor.visitEnd();
+    }
+    {
+      methodVisitor =
+          classWriter.visitMethod(
+              ACC_PRIVATE | ACC_STATIC,
+              "testI",
+              "(Lcom/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces$I;)V",
+              null,
+              null);
+      methodVisitor.visitCode();
+      Label label0 = new Label();
+      methodVisitor.visitLabel(label0);
+      methodVisitor.visitLineNumber(28, label0);
+      methodVisitor.visitFieldInsn(GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
+      methodVisitor.visitVarInsn(ALOAD, 0);
+      methodVisitor.visitMethodInsn(
+          INVOKEINTERFACE,
+          "com/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces$I",
+          "get",
+          "()Ljava/lang/Object;",
+          true);
+      methodVisitor.visitMethodInsn(
+          INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/Object;)V", false);
+      Label label1 = new Label();
+      methodVisitor.visitLabel(label1);
+      methodVisitor.visitLineNumber(29, label1);
+      methodVisitor.visitInsn(RETURN);
+      Label label2 = new Label();
+      methodVisitor.visitLabel(label2);
+      methodVisitor.visitLocalVariable(
+          "i",
+          "Lcom/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces$I;",
+          null,
+          label0,
+          label2,
+          0);
+      methodVisitor.visitMaxs(2, 1);
+      methodVisitor.visitEnd();
+    }
+    {
+      methodVisitor =
+          classWriter.visitMethod(
+              ACC_PRIVATE | ACC_STATIC,
+              "testJ",
+              "(Lcom/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces$J;)V",
+              null,
+              null);
+      methodVisitor.visitCode();
+      Label label0 = new Label();
+      methodVisitor.visitLabel(label0);
+      methodVisitor.visitLineNumber(32, label0);
+      methodVisitor.visitFieldInsn(GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
+      methodVisitor.visitVarInsn(ALOAD, 0);
+      methodVisitor.visitMethodInsn(
+          INVOKEINTERFACE,
+          "com/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces$J",
+          "get",
+          "()Ljava/lang/String;",
+          true);
+      methodVisitor.visitMethodInsn(
+          INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);
+      Label label1 = new Label();
+      methodVisitor.visitLabel(label1);
+      methodVisitor.visitLineNumber(33, label1);
+      methodVisitor.visitInsn(RETURN);
+      Label label2 = new Label();
+      methodVisitor.visitLabel(label2);
+      methodVisitor.visitLocalVariable(
+          "j",
+          "Lcom/android/tools/r8/desugar/bridge/LambdaWithMultipleImplementingInterfaces$J;",
+          null,
+          label0,
+          label2,
+          0);
+      methodVisitor.visitMaxs(2, 1);
+      methodVisitor.visitEnd();
+    }
+    {
+      methodVisitor =
+          classWriter.visitMethod(
+              ACC_PRIVATE | ACC_STATIC | ACC_SYNTHETIC,
+              "lambda$main$0",
+              "()Ljava/lang/String;",
+              null,
+              null);
+      methodVisitor.visitCode();
+      Label label0 = new Label();
+      methodVisitor.visitLabel(label0);
+      methodVisitor.visitLineNumber(22, label0);
+      methodVisitor.visitLdcInsn("Hello World!");
+      methodVisitor.visitInsn(ARETURN);
+      methodVisitor.visitMaxs(1, 0);
+      methodVisitor.visitEnd();
+    }
+    classWriter.visitEnd();
+
+    return classWriter.toByteArray();
+  }
+}