Add support for factory methods on lambda classes
Not enabled except in accompanying test.
Bug: b/225839019
Change-Id: Icb406959b8d6504155f12d72ac25376a9993d888
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/LambdaClass.java b/src/main/java/com/android/tools/r8/ir/desugar/LambdaClass.java
index 27298a9..d01f7fc 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/LambdaClass.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/LambdaClass.java
@@ -72,6 +72,7 @@
public LambdaDescriptor descriptor;
public final DexMethod constructor;
final DexMethod classConstructor;
+ private final DexMethod factoryMethod;
public final DexField lambdaField;
public final Target target;
@@ -108,6 +109,13 @@
statelessSingleton
? factory.createField(type, type, factory.lambdaInstanceFieldName)
: null;
+ this.factoryMethod =
+ appView.options().testing.alwaysGenerateLambdaFactoryMethods
+ ? factory.createMethod(
+ type,
+ factory.createProto(type, descriptor.captures.values),
+ factory.createString("create"))
+ : null;
// Synthesize the program class once all fields are set.
synthesizeLambdaClass(builder, desugarInvoke);
@@ -151,6 +159,15 @@
return appView.options().createSingletonsForStatelessLambdas && descriptor.isStateless();
}
+ public boolean hasFactoryMethod() {
+ return factoryMethod != null;
+ }
+
+ public DexMethod getFactoryMethod() {
+ assert hasFactoryMethod();
+ return factoryMethod;
+ }
+
// Synthesize virtual methods.
private void synthesizeVirtualMethods(
SyntheticProgramClassBuilder builder, DesugarInvoke desugarInvoke) {
@@ -227,6 +244,17 @@
.build());
feedback.classInitializerMayBePostponed(methods.get(1));
}
+ if (hasFactoryMethod()) {
+ methods.add(
+ DexEncodedMethod.syntheticBuilder()
+ .setMethod(factoryMethod)
+ .setAccessFlags(
+ MethodAccessFlags.fromSharedAccessFlags(
+ Constants.ACC_STATIC | Constants.ACC_PUBLIC | Constants.ACC_SYNTHETIC, false))
+ .setCode(LambdaClassFactorySourceCode.build(this))
+ .disableAndroidApiLevelCheck()
+ .build());
+ }
builder.setDirectMethods(methods);
}
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/LambdaClassFactorySourceCode.java b/src/main/java/com/android/tools/r8/ir/desugar/LambdaClassFactorySourceCode.java
new file mode 100644
index 0000000..0ca21cf
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/desugar/LambdaClassFactorySourceCode.java
@@ -0,0 +1,40 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.desugar;
+
+import com.android.tools.r8.cf.code.CfInstruction;
+import com.android.tools.r8.cf.code.CfInvoke;
+import com.android.tools.r8.cf.code.CfLoad;
+import com.android.tools.r8.cf.code.CfNew;
+import com.android.tools.r8.cf.code.CfReturn;
+import com.android.tools.r8.cf.code.CfStackInstruction;
+import com.android.tools.r8.cf.code.CfStackInstruction.Opcode;
+import com.android.tools.r8.graph.CfCode;
+import com.android.tools.r8.ir.code.ValueType;
+import com.google.common.collect.ImmutableList;
+import org.objectweb.asm.Opcodes;
+
+// Source code representing lambda factory method.
+final class LambdaClassFactorySourceCode {
+
+ public static CfCode build(LambdaClass lambda) {
+ int maxStack = 0;
+ int maxLocals = 0;
+ ImmutableList.Builder<CfInstruction> builder = ImmutableList.builder();
+ builder.add(new CfNew(lambda.type)).add(new CfStackInstruction(Opcode.Dup));
+ maxStack += 2;
+ int local = 0;
+ for (int i = 0; i < lambda.constructor.proto.getParameters().size(); i++) {
+ ValueType parameterType = ValueType.fromDexType(lambda.constructor.proto.getParameter(i));
+ builder.add(new CfLoad(parameterType, local));
+ maxStack += parameterType.requiredRegisters();
+ local += parameterType.requiredRegisters();
+ maxLocals = local;
+ }
+ builder
+ .add(new CfInvoke(Opcodes.INVOKESPECIAL, lambda.constructor, false))
+ .add(new CfReturn(ValueType.fromDexType(lambda.type)));
+ return new CfCode(lambda.type, maxStack, maxLocals, builder.build());
+ }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/lambda/LambdaInstructionDesugaring.java b/src/main/java/com/android/tools/r8/ir/desugar/lambda/LambdaInstructionDesugaring.java
index d336a16..532909b 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/lambda/LambdaInstructionDesugaring.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/lambda/LambdaInstructionDesugaring.java
@@ -120,6 +120,11 @@
eventConsumer.acceptLambdaClass(lambdaClass, context);
+ if (lambdaClass.hasFactoryMethod()) {
+ return ImmutableList.of(
+ new CfInvoke(Opcodes.INVOKESTATIC, lambdaClass.getFactoryMethod(), false));
+ }
+
if (lambdaClass.isStatelessSingleton()) {
return ImmutableList.of(
new CfStaticFieldRead(lambdaClass.lambdaField, lambdaClass.lambdaField));
@@ -143,7 +148,6 @@
// elements on the stack, we load all the N arguments back onto the stack. At this point, we
// have the original N arguments on the stack plus the 2 new stack elements.
localStackAllocator.allocateLocalStack(2);
-
return replacement;
}
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index 210ea92..e389b61 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -1872,6 +1872,8 @@
public Predicate<DexMethod> cfByteCodePassThrough = null;
public boolean enableExperimentalMapFileVersion = false;
+
+ public boolean alwaysGenerateLambdaFactoryMethods = false;
}
public MapVersion getMapFileVersion() {
diff --git a/src/test/java/com/android/tools/r8/desugar/lambdas/LambdaFactoryTest.java b/src/test/java/com/android/tools/r8/desugar/lambdas/LambdaFactoryTest.java
new file mode 100644
index 0000000..21ed11f
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/desugar/lambdas/LambdaFactoryTest.java
@@ -0,0 +1,158 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.desugar.lambdas;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import com.android.tools.r8.DesugarTestConfiguration;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.synthesis.SyntheticItemsTestUtils;
+import com.android.tools.r8.utils.StringUtils;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class LambdaFactoryTest extends TestBase {
+
+ @Parameter() public TestParameters parameters;
+
+ @Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withAllRuntimes().withAllApiLevelsAlsoForCf().build();
+ }
+
+ private static final String EXPECTED_OUTPUT = StringUtils.lines("1", "2", "3", "4.0", "5");
+
+ private boolean isLambdaFactoryMethod(MethodSubject method) {
+ return method.isSynthetic() && method.isStatic() && method.getFinalName().equals("create");
+ }
+
+ private boolean isInvokingLambdaFactoryMethod(InstructionSubject instruction) {
+ return instruction.isInvokeStatic()
+ && SyntheticItemsTestUtils.isExternalSynthetic(
+ instruction.getMethod().getHolderType().asClassReference())
+ && instruction.getMethod().getName().toString().equals("create");
+ }
+
+ private void inspectDesugared(CodeInspector inspector) {
+ inspector.forAllClasses(
+ clazz -> {
+ if (SyntheticItemsTestUtils.isExternalSynthetic(clazz.getFinalReference())) {
+ assertTrue(clazz.allMethods().stream().anyMatch(this::isLambdaFactoryMethod));
+ }
+ });
+ assertEquals(
+ 3,
+ inspector
+ .clazz(TestClass.class)
+ .mainMethod()
+ .streamInstructions()
+ .filter(this::isInvokingLambdaFactoryMethod)
+ .count());
+ }
+
+ private void inspectNotDesugared(CodeInspector inspector) {
+ inspector.forAllClasses(
+ clazz -> {
+ if (SyntheticItemsTestUtils.isExternalSynthetic(clazz.getFinalReference())) {
+ assertTrue(clazz.allMethods().stream().noneMatch(this::isLambdaFactoryMethod));
+ }
+ });
+ assertEquals(
+ 0,
+ inspector
+ .clazz(TestClass.class)
+ .mainMethod()
+ .streamInstructions()
+ .filter(this::isInvokingLambdaFactoryMethod)
+ .count());
+ }
+
+ @Test
+ public void testDesugaring() throws Exception {
+ testForDesugaring(
+ parameters,
+ options -> {
+ options.testing.alwaysGenerateLambdaFactoryMethods = true;
+ })
+ .addInnerClasses(getClass())
+ .run(parameters.getRuntime(), TestClass.class)
+ .applyIf(
+ DesugarTestConfiguration::isDesugared,
+ r -> {
+ try {
+ r.inspect(this::inspectDesugared);
+ } catch (Exception e) {
+ fail();
+ }
+ },
+ r -> {
+ try {
+ r.inspect(this::inspectNotDesugared);
+ } catch (Exception e) {
+ fail();
+ }
+ })
+ .assertSuccessWithOutput(EXPECTED_OUTPUT);
+ }
+
+ @Test
+ public void testR8() throws Exception {
+ testForR8(parameters.getBackend())
+ .addInnerClasses(getClass())
+ .addKeepMainRule(TestClass.class)
+ .setMinApi(parameters.getApiLevel())
+ .run(parameters.getRuntime(), TestClass.class)
+ .inspect(
+ inspector -> {
+ if (parameters.isDexRuntime()) {
+ // Lambdas are fully inlined when desugaring.
+ assertEquals(1, inspector.allClasses().size());
+ assertEquals(1, inspector.clazz(TestClass.class).allMethods().size());
+ }
+ })
+ .assertSuccessWithOutput(EXPECTED_OUTPUT);
+ }
+
+ interface MyConsumer<T> {
+ void create(T o);
+ }
+
+ interface MyTriConsumer<T, U, V> {
+ void accept(T o1, U o2, V o3);
+ }
+
+ static class TestClass {
+
+ public static void greet() {
+ System.out.println("1");
+ }
+
+ public static void greet(MyConsumer<String> consumer) {
+ consumer.create("2");
+ }
+
+ public static void greetTri(long l, double d, String s) {
+ System.out.println(l);
+ System.out.println(d);
+ System.out.println(s);
+ }
+
+ public static void main(String[] args) throws Exception {
+ ((Runnable) TestClass::greet).run();
+ greet(System.out::println);
+ ((MyTriConsumer<Long, Double, String>) TestClass::greetTri).accept(3L, 4.0, "5");
+ }
+ }
+}