Desugar enumSwitch
Bug: b/336510513
Change-Id: Ief78d199973f95de7a5e3b5bf0f306670b12f4c1
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java
index e6e99e7..da03ae1 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java
@@ -43,6 +43,7 @@
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
+import java.util.function.BiConsumer;
import org.objectweb.asm.Opcodes;
public class TypeSwitchDesugaring implements CfInstructionDesugaring {
@@ -50,8 +51,9 @@
private final AppView<?> appView;
private final DexMethod typeSwitchMethod;
+ private final DexMethod enumSwitchMethod;
private final DexProto typeSwitchProto;
- private final DexProto typeSwitchHelperProto;
+ private final DexProto switchHelperProto;
private final DexMethod enumDescMethod;
private final DexMethod classDescMethod;
private final DexType matchException;
@@ -62,18 +64,22 @@
this.appView = appView;
this.factory = appView.dexItemFactory();
DexType switchBootstrap = factory.createType("Ljava/lang/runtime/SwitchBootstraps;");
+ DexProto methodHandleSwitchProto =
+ factory.createProto(
+ factory.callSiteType,
+ factory.methodHandlesLookupType,
+ factory.stringType,
+ factory.methodTypeType,
+ factory.objectArrayType);
typeSwitchMethod =
factory.createMethod(
- switchBootstrap,
- factory.createProto(
- factory.callSiteType,
- factory.methodHandlesLookupType,
- factory.stringType,
- factory.methodTypeType,
- factory.objectArrayType),
- factory.createString("typeSwitch"));
- typeSwitchProto = factory.createProto(factory.intType, factory.objectType, factory.intType);
- typeSwitchHelperProto =
+ switchBootstrap, methodHandleSwitchProto, factory.createString("typeSwitch"));
+ enumSwitchMethod =
+ factory.createMethod(
+ switchBootstrap, methodHandleSwitchProto, factory.createString("enumSwitch"));
+ this.typeSwitchProto =
+ factory.createProto(factory.intType, factory.objectType, factory.intType);
+ switchHelperProto =
factory.createProto(
factory.intType, factory.objectType, factory.intType, factory.objectArrayType);
enumDescMethod =
@@ -147,31 +153,62 @@
return DesugarDescription.nothing();
}
DexCallSite callSite = instruction.asInvokeDynamic().getCallSite();
- if (!(callSite.methodName.isIdenticalTo(typeSwitchMethod.getName())
+ if (callSite.methodName.isIdenticalTo(typeSwitchMethod.getName())
&& callSite.methodProto.isIdenticalTo(typeSwitchProto)
- && methodHandleIsInvokeStaticTo(callSite.bootstrapMethod, typeSwitchMethod))) {
- return DesugarDescription.nothing();
+ && methodHandleIsInvokeStaticTo(callSite.bootstrapMethod, typeSwitchMethod)) {
+ return DesugarDescription.builder()
+ .setDesugarRewrite(
+ (position,
+ freshLocalProvider,
+ localStackAllocator,
+ desugaringInfo,
+ eventConsumer,
+ theContext,
+ methodProcessingContext,
+ desugaringCollection,
+ dexItemFactory) -> {
+ // We add on stack (2) array, (3) dupped array, (4) index, (5) value.
+ localStackAllocator.allocateLocalStack(4);
+ List<CfInstruction> cfInstructions =
+ generateTypeSwitchLoadArguments(callSite, context);
+ generateInvokeToDesugaredMethod(
+ methodProcessingContext, cfInstructions, theContext, eventConsumer);
+ return cfInstructions;
+ })
+ .build();
}
- // Call the desugared method.
- return DesugarDescription.builder()
- .setDesugarRewrite(
- (position,
- freshLocalProvider,
- localStackAllocator,
- desugaringInfo,
- eventConsumer,
- theContext,
- methodProcessingContext,
- desugaringCollection,
- dexItemFactory) -> {
- // We add on stack (2) array, (3) dupped array, (4) index, (5) value.
- localStackAllocator.allocateLocalStack(4);
- List<CfInstruction> cfInstructions = generateLoadArguments(callSite, context);
- generateInvokeToDesugaredMethod(
- methodProcessingContext, cfInstructions, theContext, eventConsumer);
- return cfInstructions;
- })
- .build();
+ if (callSite.methodName.isIdenticalTo(enumSwitchMethod.getName())
+ && isEnumSwitchProto(callSite.methodProto)
+ && methodHandleIsInvokeStaticTo(callSite.bootstrapMethod, enumSwitchMethod)) {
+ DexType enumType = callSite.methodProto.getParameter(0);
+ return DesugarDescription.builder()
+ .setDesugarRewrite(
+ (position,
+ freshLocalProvider,
+ localStackAllocator,
+ desugaringInfo,
+ eventConsumer,
+ theContext,
+ methodProcessingContext,
+ desugaringCollection,
+ dexItemFactory) -> {
+ // We add on stack (2) array, (3) dupped array, (4) index, (5) value.
+ localStackAllocator.allocateLocalStack(4);
+ List<CfInstruction> cfInstructions =
+ generateEnumSwitchLoadArguments(callSite, context, enumType);
+ generateInvokeToDesugaredMethod(
+ methodProcessingContext, cfInstructions, theContext, eventConsumer);
+ return cfInstructions;
+ })
+ .build();
+ }
+ return DesugarDescription.nothing();
+ }
+
+ private boolean isEnumSwitchProto(DexProto methodProto) {
+ return methodProto.getReturnType().isIdenticalTo(factory.intType)
+ && methodProto.getArity() == 2
+ && methodProto.getParameter(1).isIdenticalTo(factory.intType);
}
private void generateInvokeToDesugaredMethod(
@@ -189,7 +226,7 @@
builder ->
builder
.disableAndroidApiLevelCheck()
- .setProto(typeSwitchHelperProto)
+ .setProto(switchHelperProto)
.setAccessFlags(MethodAccessFlags.createPublicStaticSynthetic())
.setCode(
methodSig -> {
@@ -206,7 +243,50 @@
cfInstructions.add(new CfInvoke(Opcodes.INVOKESTATIC, method.getReference(), false));
}
- private List<CfInstruction> generateLoadArguments(DexCallSite callSite, ProgramMethod context) {
+ private List<CfInstruction> generateEnumSwitchLoadArguments(
+ DexCallSite callSite, ProgramMethod context, DexType enumType) {
+ return generateSwitchLoadArguments(
+ callSite,
+ (bootstrapArg, cfInstructions) -> {
+ if (bootstrapArg.isDexValueType()) {
+ cfInstructions.add(new CfConstClass(bootstrapArg.asDexValueType().getValue()));
+ } else if (bootstrapArg.isDexValueString()) {
+ DexField enumField =
+ getEnumField(bootstrapArg.asDexValueString().getValue(), enumType, context);
+ cfInstructions.add(new CfStaticFieldRead(enumField));
+ } else {
+ throw new CompilationError(
+ "Invalid bootstrap arg for enum switch " + bootstrapArg, context.getOrigin());
+ }
+ });
+ }
+
+ private List<CfInstruction> generateTypeSwitchLoadArguments(
+ DexCallSite callSite, ProgramMethod context) {
+ return generateSwitchLoadArguments(
+ callSite,
+ (bootstrapArg, cfInstructions) -> {
+ if (bootstrapArg.isDexValueType()) {
+ cfInstructions.add(new CfConstClass(bootstrapArg.asDexValueType().getValue()));
+ } else if (bootstrapArg.isDexValueInt()) {
+ cfInstructions.add(
+ new CfConstNumber(bootstrapArg.asDexValueInt().getValue(), ValueType.INT));
+ cfInstructions.add(
+ new CfInvoke(Opcodes.INVOKESTATIC, factory.integerMembers.valueOf, false));
+ } else if (bootstrapArg.isDexValueString()) {
+ cfInstructions.add(new CfConstString(bootstrapArg.asDexValueString().getValue()));
+ } else if (bootstrapArg.isDexValueConstDynamic()) {
+ DexField enumField = extractEnumField(bootstrapArg.asDexValueConstDynamic(), context);
+ cfInstructions.add(new CfStaticFieldRead(enumField));
+ } else {
+ throw new CompilationError(
+ "Invalid bootstrap arg for type switch " + bootstrapArg, context.getOrigin());
+ }
+ });
+ }
+
+ private List<CfInstruction> generateSwitchLoadArguments(
+ DexCallSite callSite, BiConsumer<DexValue, List<CfInstruction>> adder) {
// We need to call the method with the bootstrap args as parameters.
// We need to convert the bootstrap args into a list of cf instructions.
// The object and the int are already pushed on stack, we simply need to push the extra array.
@@ -217,20 +297,7 @@
DexValue bootstrapArg = callSite.bootstrapArgs.get(i);
cfInstructions.add(new CfStackInstruction(Opcode.Dup));
cfInstructions.add(new CfConstNumber(i, ValueType.INT));
- if (bootstrapArg.isDexValueType()) {
- cfInstructions.add(new CfConstClass(bootstrapArg.asDexValueType().getValue()));
- } else if (bootstrapArg.isDexValueInt()) {
- cfInstructions.add(
- new CfConstNumber(bootstrapArg.asDexValueInt().getValue(), ValueType.INT));
- cfInstructions.add(
- new CfInvoke(Opcodes.INVOKESTATIC, factory.integerMembers.valueOf, false));
- } else if (bootstrapArg.isDexValueString()) {
- cfInstructions.add(new CfConstString(bootstrapArg.asDexValueString().getValue()));
- } else {
- assert bootstrapArg.isDexValueConstDynamic();
- DexField enumField = extractEnumField(bootstrapArg.asDexValueConstDynamic(), context);
- cfInstructions.add(new CfStaticFieldRead(enumField));
- }
+ adder.accept(bootstrapArg, cfInstructions);
cfInstructions.add(new CfArrayStore(MemberType.OBJECT));
}
return cfInstructions;
@@ -280,6 +347,10 @@
DexString className = dexValueClassName.asDexValueString().getValue();
DexType enumType =
factory.createType(DescriptorUtils.javaTypeToDescriptor(className.toString()));
+ return getEnumField(fieldName, enumType, context);
+ }
+
+ private DexField getEnumField(DexString fieldName, DexType enumType, ProgramMethod context) {
DexClass enumClass = appView.definitionFor(enumType);
if (enumClass == null) {
throw throwEnumFieldConstantDynamic("Missing enum class " + enumType, context);
diff --git a/src/test/examplesJava21/switchpatternmatching/EnumSwitchUsingEnumSwitchBootstrapMethod.java b/src/test/examplesJava21/switchpatternmatching/EnumSwitchUsingEnumSwitchBootstrapMethod.java
new file mode 100644
index 0000000..32fad72
--- /dev/null
+++ b/src/test/examplesJava21/switchpatternmatching/EnumSwitchUsingEnumSwitchBootstrapMethod.java
@@ -0,0 +1,140 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package switchpatternmatching;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assume.assumeTrue;
+
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.TestRuntime.CfVm;
+import com.android.tools.r8.ToolHelper;
+import com.android.tools.r8.utils.StringUtils;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import org.junit.Assume;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class EnumSwitchUsingEnumSwitchBootstrapMethod extends TestBase {
+
+ @Parameter public TestParameters parameters;
+
+ @Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withAllRuntimesAndApiLevels().build();
+ }
+
+ public static String EXPECTED_OUTPUT =
+ StringUtils.lines(
+ "null",
+ "Spades or Piques",
+ "Hearts or C\u0153ur",
+ "Diamonds or Carreaux",
+ "Clubs or Trefles",
+ "Trumps or Atouts",
+ "The Fool or L'Excuse");
+
+ @Test
+ public void testJvm() throws Exception {
+ assumeTrue(parameters.isCfRuntime());
+ CodeInspector inspector = new CodeInspector(ToolHelper.getClassFileForTestClass(Main.class));
+ // javac generated an invokedynamic using bootstrap method
+ // java.lang.runtime.SwitchBootstraps.enumSwitch.
+ assertEquals(
+ 1,
+ inspector
+ .clazz(Main.class)
+ .uniqueMethodWithOriginalName("enumSwitch")
+ .streamInstructions()
+ .filter(InstructionSubject::isInvokeDynamic)
+ .map(
+ instruction ->
+ instruction
+ .asCfInstruction()
+ .getInstruction()
+ .asInvokeDynamic()
+ .getCallSite()
+ .getBootstrapMethod()
+ .member
+ .asDexMethod())
+ .filter(
+ method ->
+ method
+ .getHolderType()
+ .toString()
+ .contains("java.lang.runtime.SwitchBootstraps"))
+ .filter(method -> method.toString().contains("enumSwitch"))
+ .count());
+
+ parameters.assumeJvmTestParameters();
+ testForJvm(parameters)
+ .addInnerClassesAndStrippedOuter(getClass())
+ .run(parameters.getRuntime(), Main.class)
+ .applyIf(
+ parameters.getCfRuntime().isNewerThanOrEqual(CfVm.JDK21),
+ r -> r.assertSuccessWithOutput(EXPECTED_OUTPUT),
+ r -> r.assertFailureWithErrorThatThrows(UnsupportedClassVersionError.class));
+ }
+
+ @Test
+ public void testD8() throws Exception {
+ parameters.assumeDexRuntime();
+ testForD8()
+ .addInnerClassesAndStrippedOuter(getClass())
+ .setMinApi(parameters)
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutput(EXPECTED_OUTPUT);
+ }
+
+ @Test
+ public void testR8() throws Exception {
+ Assume.assumeTrue("For Cf we should compile with Jdk 21 library", parameters.isDexRuntime());
+ testForR8(parameters.getBackend())
+ .addInnerClassesAndStrippedOuter(getClass())
+ .setMinApi(parameters)
+ .addKeepMainRule(Main.class)
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutput(EXPECTED_OUTPUT);
+ }
+
+ public enum Tarot {
+ SPADE,
+ HEART,
+ DIAMOND,
+ CLUB,
+ TRUMP,
+ EXCUSE
+ }
+
+ public static class Main {
+
+ public static void main(String[] args) {
+ enumSwitch(null);
+ enumSwitch(Tarot.SPADE);
+ enumSwitch(Tarot.HEART);
+ enumSwitch(Tarot.DIAMOND);
+ enumSwitch(Tarot.CLUB);
+ enumSwitch(Tarot.TRUMP);
+ enumSwitch(Tarot.EXCUSE);
+ }
+
+ static void enumSwitch(Tarot t1) {
+ switch (t1) {
+ case null -> System.out.println("null");
+ case SPADE -> System.out.println("Spades or Piques");
+ case HEART -> System.out.println("Hearts or C\u0153ur");
+ case Tarot t when t == Tarot.DIAMOND -> System.out.println("Diamonds or Carreaux");
+ case Tarot t when t == Tarot.CLUB -> System.out.println("Clubs or Trefles");
+ case Tarot t when t == Tarot.TRUMP -> System.out.println("Trumps or Atouts");
+ case Tarot t -> System.out.println("The Fool or L'Excuse");
+ }
+ }
+ }
+}