Add support for factory methods on lambda classes

Not enabled except in accompanying test.

Bug: b/225839019
Change-Id: Icb406959b8d6504155f12d72ac25376a9993d888
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/LambdaClass.java b/src/main/java/com/android/tools/r8/ir/desugar/LambdaClass.java
index 27298a9..d01f7fc 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/LambdaClass.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/LambdaClass.java
@@ -72,6 +72,7 @@
   public LambdaDescriptor descriptor;
   public final DexMethod constructor;
   final DexMethod classConstructor;
+  private final DexMethod factoryMethod;
   public final DexField lambdaField;
   public final Target target;
 
@@ -108,6 +109,13 @@
         statelessSingleton
             ? factory.createField(type, type, factory.lambdaInstanceFieldName)
             : null;
+    this.factoryMethod =
+        appView.options().testing.alwaysGenerateLambdaFactoryMethods
+            ? factory.createMethod(
+                type,
+                factory.createProto(type, descriptor.captures.values),
+                factory.createString("create"))
+            : null;
 
     // Synthesize the program class once all fields are set.
     synthesizeLambdaClass(builder, desugarInvoke);
@@ -151,6 +159,15 @@
     return appView.options().createSingletonsForStatelessLambdas && descriptor.isStateless();
   }
 
+  public boolean hasFactoryMethod() {
+    return factoryMethod != null;
+  }
+
+  public DexMethod getFactoryMethod() {
+    assert hasFactoryMethod();
+    return factoryMethod;
+  }
+
   // Synthesize virtual methods.
   private void synthesizeVirtualMethods(
       SyntheticProgramClassBuilder builder, DesugarInvoke desugarInvoke) {
@@ -227,6 +244,17 @@
               .build());
       feedback.classInitializerMayBePostponed(methods.get(1));
     }
+    if (hasFactoryMethod()) {
+      methods.add(
+          DexEncodedMethod.syntheticBuilder()
+              .setMethod(factoryMethod)
+              .setAccessFlags(
+                  MethodAccessFlags.fromSharedAccessFlags(
+                      Constants.ACC_STATIC | Constants.ACC_PUBLIC | Constants.ACC_SYNTHETIC, false))
+              .setCode(LambdaClassFactorySourceCode.build(this))
+              .disableAndroidApiLevelCheck()
+              .build());
+    }
     builder.setDirectMethods(methods);
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/LambdaClassFactorySourceCode.java b/src/main/java/com/android/tools/r8/ir/desugar/LambdaClassFactorySourceCode.java
new file mode 100644
index 0000000..0ca21cf
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/desugar/LambdaClassFactorySourceCode.java
@@ -0,0 +1,40 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.desugar;
+
+import com.android.tools.r8.cf.code.CfInstruction;
+import com.android.tools.r8.cf.code.CfInvoke;
+import com.android.tools.r8.cf.code.CfLoad;
+import com.android.tools.r8.cf.code.CfNew;
+import com.android.tools.r8.cf.code.CfReturn;
+import com.android.tools.r8.cf.code.CfStackInstruction;
+import com.android.tools.r8.cf.code.CfStackInstruction.Opcode;
+import com.android.tools.r8.graph.CfCode;
+import com.android.tools.r8.ir.code.ValueType;
+import com.google.common.collect.ImmutableList;
+import org.objectweb.asm.Opcodes;
+
+// Source code representing lambda factory method.
+final class LambdaClassFactorySourceCode {
+
+  public static CfCode build(LambdaClass lambda) {
+    int maxStack = 0;
+    int maxLocals = 0;
+    ImmutableList.Builder<CfInstruction> builder = ImmutableList.builder();
+    builder.add(new CfNew(lambda.type)).add(new CfStackInstruction(Opcode.Dup));
+    maxStack += 2;
+    int local = 0;
+    for (int i = 0; i < lambda.constructor.proto.getParameters().size(); i++) {
+      ValueType parameterType = ValueType.fromDexType(lambda.constructor.proto.getParameter(i));
+      builder.add(new CfLoad(parameterType, local));
+      maxStack += parameterType.requiredRegisters();
+      local += parameterType.requiredRegisters();
+      maxLocals = local;
+    }
+    builder
+        .add(new CfInvoke(Opcodes.INVOKESPECIAL, lambda.constructor, false))
+        .add(new CfReturn(ValueType.fromDexType(lambda.type)));
+    return new CfCode(lambda.type, maxStack, maxLocals, builder.build());
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/lambda/LambdaInstructionDesugaring.java b/src/main/java/com/android/tools/r8/ir/desugar/lambda/LambdaInstructionDesugaring.java
index d336a16..532909b 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/lambda/LambdaInstructionDesugaring.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/lambda/LambdaInstructionDesugaring.java
@@ -120,6 +120,11 @@
 
     eventConsumer.acceptLambdaClass(lambdaClass, context);
 
+    if (lambdaClass.hasFactoryMethod()) {
+      return ImmutableList.of(
+          new CfInvoke(Opcodes.INVOKESTATIC, lambdaClass.getFactoryMethod(), false));
+    }
+
     if (lambdaClass.isStatelessSingleton()) {
       return ImmutableList.of(
           new CfStaticFieldRead(lambdaClass.lambdaField, lambdaClass.lambdaField));
@@ -143,7 +148,6 @@
     // elements on the stack, we load all the N arguments back onto the stack. At this point, we
     // have the original N arguments on the stack plus the 2 new stack elements.
     localStackAllocator.allocateLocalStack(2);
-
     return replacement;
   }
 
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index 210ea92..e389b61 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -1872,6 +1872,8 @@
     public Predicate<DexMethod> cfByteCodePassThrough = null;
 
     public boolean enableExperimentalMapFileVersion = false;
+
+    public boolean alwaysGenerateLambdaFactoryMethods = false;
   }
 
   public MapVersion getMapFileVersion() {
diff --git a/src/test/java/com/android/tools/r8/desugar/lambdas/LambdaFactoryTest.java b/src/test/java/com/android/tools/r8/desugar/lambdas/LambdaFactoryTest.java
new file mode 100644
index 0000000..21ed11f
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/desugar/lambdas/LambdaFactoryTest.java
@@ -0,0 +1,158 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.desugar.lambdas;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import com.android.tools.r8.DesugarTestConfiguration;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.synthesis.SyntheticItemsTestUtils;
+import com.android.tools.r8.utils.StringUtils;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class LambdaFactoryTest extends TestBase {
+
+  @Parameter() public TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimes().withAllApiLevelsAlsoForCf().build();
+  }
+
+  private static final String EXPECTED_OUTPUT = StringUtils.lines("1", "2", "3", "4.0", "5");
+
+  private boolean isLambdaFactoryMethod(MethodSubject method) {
+    return method.isSynthetic() && method.isStatic() && method.getFinalName().equals("create");
+  }
+
+  private boolean isInvokingLambdaFactoryMethod(InstructionSubject instruction) {
+    return instruction.isInvokeStatic()
+        && SyntheticItemsTestUtils.isExternalSynthetic(
+            instruction.getMethod().getHolderType().asClassReference())
+        && instruction.getMethod().getName().toString().equals("create");
+  }
+
+  private void inspectDesugared(CodeInspector inspector) {
+    inspector.forAllClasses(
+        clazz -> {
+          if (SyntheticItemsTestUtils.isExternalSynthetic(clazz.getFinalReference())) {
+            assertTrue(clazz.allMethods().stream().anyMatch(this::isLambdaFactoryMethod));
+          }
+        });
+    assertEquals(
+        3,
+        inspector
+            .clazz(TestClass.class)
+            .mainMethod()
+            .streamInstructions()
+            .filter(this::isInvokingLambdaFactoryMethod)
+            .count());
+  }
+
+  private void inspectNotDesugared(CodeInspector inspector) {
+    inspector.forAllClasses(
+        clazz -> {
+          if (SyntheticItemsTestUtils.isExternalSynthetic(clazz.getFinalReference())) {
+            assertTrue(clazz.allMethods().stream().noneMatch(this::isLambdaFactoryMethod));
+          }
+        });
+    assertEquals(
+        0,
+        inspector
+            .clazz(TestClass.class)
+            .mainMethod()
+            .streamInstructions()
+            .filter(this::isInvokingLambdaFactoryMethod)
+            .count());
+  }
+
+  @Test
+  public void testDesugaring() throws Exception {
+    testForDesugaring(
+            parameters,
+            options -> {
+              options.testing.alwaysGenerateLambdaFactoryMethods = true;
+            })
+        .addInnerClasses(getClass())
+        .run(parameters.getRuntime(), TestClass.class)
+        .applyIf(
+            DesugarTestConfiguration::isDesugared,
+            r -> {
+              try {
+                r.inspect(this::inspectDesugared);
+              } catch (Exception e) {
+                fail();
+              }
+            },
+            r -> {
+              try {
+                r.inspect(this::inspectNotDesugared);
+              } catch (Exception e) {
+                fail();
+              }
+            })
+        .assertSuccessWithOutput(EXPECTED_OUTPUT);
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(TestClass.class)
+        .setMinApi(parameters.getApiLevel())
+        .run(parameters.getRuntime(), TestClass.class)
+        .inspect(
+            inspector -> {
+              if (parameters.isDexRuntime()) {
+                // Lambdas are fully inlined when desugaring.
+                assertEquals(1, inspector.allClasses().size());
+                assertEquals(1, inspector.clazz(TestClass.class).allMethods().size());
+              }
+            })
+        .assertSuccessWithOutput(EXPECTED_OUTPUT);
+  }
+
+  interface MyConsumer<T> {
+    void create(T o);
+  }
+
+  interface MyTriConsumer<T, U, V> {
+    void accept(T o1, U o2, V o3);
+  }
+
+  static class TestClass {
+
+    public static void greet() {
+      System.out.println("1");
+    }
+
+    public static void greet(MyConsumer<String> consumer) {
+      consumer.create("2");
+    }
+
+    public static void greetTri(long l, double d, String s) {
+      System.out.println(l);
+      System.out.println(d);
+      System.out.println(s);
+    }
+
+    public static void main(String[] args) throws Exception {
+      ((Runnable) TestClass::greet).run();
+      greet(System.out::println);
+      ((MyTriConsumer<Long, Double, String>) TestClass::greetTri).accept(3L, 4.0, "5");
+    }
+  }
+}