Desugar enumSwitch

Bug: b/336510513
Change-Id: Ief78d199973f95de7a5e3b5bf0f306670b12f4c1
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java
index e6e99e7..da03ae1 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java
@@ -43,6 +43,7 @@
 import com.google.common.collect.ImmutableList;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.function.BiConsumer;
 import org.objectweb.asm.Opcodes;
 
 public class TypeSwitchDesugaring implements CfInstructionDesugaring {
@@ -50,8 +51,9 @@
   private final AppView<?> appView;
 
   private final DexMethod typeSwitchMethod;
+  private final DexMethod enumSwitchMethod;
   private final DexProto typeSwitchProto;
-  private final DexProto typeSwitchHelperProto;
+  private final DexProto switchHelperProto;
   private final DexMethod enumDescMethod;
   private final DexMethod classDescMethod;
   private final DexType matchException;
@@ -62,18 +64,22 @@
     this.appView = appView;
     this.factory = appView.dexItemFactory();
     DexType switchBootstrap = factory.createType("Ljava/lang/runtime/SwitchBootstraps;");
+    DexProto methodHandleSwitchProto =
+        factory.createProto(
+            factory.callSiteType,
+            factory.methodHandlesLookupType,
+            factory.stringType,
+            factory.methodTypeType,
+            factory.objectArrayType);
     typeSwitchMethod =
         factory.createMethod(
-            switchBootstrap,
-            factory.createProto(
-                factory.callSiteType,
-                factory.methodHandlesLookupType,
-                factory.stringType,
-                factory.methodTypeType,
-                factory.objectArrayType),
-            factory.createString("typeSwitch"));
-    typeSwitchProto = factory.createProto(factory.intType, factory.objectType, factory.intType);
-    typeSwitchHelperProto =
+            switchBootstrap, methodHandleSwitchProto, factory.createString("typeSwitch"));
+    enumSwitchMethod =
+        factory.createMethod(
+            switchBootstrap, methodHandleSwitchProto, factory.createString("enumSwitch"));
+    this.typeSwitchProto =
+        factory.createProto(factory.intType, factory.objectType, factory.intType);
+    switchHelperProto =
         factory.createProto(
             factory.intType, factory.objectType, factory.intType, factory.objectArrayType);
     enumDescMethod =
@@ -147,31 +153,62 @@
       return DesugarDescription.nothing();
     }
     DexCallSite callSite = instruction.asInvokeDynamic().getCallSite();
-    if (!(callSite.methodName.isIdenticalTo(typeSwitchMethod.getName())
+    if (callSite.methodName.isIdenticalTo(typeSwitchMethod.getName())
         && callSite.methodProto.isIdenticalTo(typeSwitchProto)
-        && methodHandleIsInvokeStaticTo(callSite.bootstrapMethod, typeSwitchMethod))) {
-      return DesugarDescription.nothing();
+        && methodHandleIsInvokeStaticTo(callSite.bootstrapMethod, typeSwitchMethod)) {
+      return DesugarDescription.builder()
+          .setDesugarRewrite(
+              (position,
+                  freshLocalProvider,
+                  localStackAllocator,
+                  desugaringInfo,
+                  eventConsumer,
+                  theContext,
+                  methodProcessingContext,
+                  desugaringCollection,
+                  dexItemFactory) -> {
+                // We add on stack (2) array, (3) dupped array, (4) index, (5) value.
+                localStackAllocator.allocateLocalStack(4);
+                List<CfInstruction> cfInstructions =
+                    generateTypeSwitchLoadArguments(callSite, context);
+                generateInvokeToDesugaredMethod(
+                    methodProcessingContext, cfInstructions, theContext, eventConsumer);
+                return cfInstructions;
+              })
+          .build();
     }
-    // Call the desugared method.
-    return DesugarDescription.builder()
-        .setDesugarRewrite(
-            (position,
-                freshLocalProvider,
-                localStackAllocator,
-                desugaringInfo,
-                eventConsumer,
-                theContext,
-                methodProcessingContext,
-                desugaringCollection,
-                dexItemFactory) -> {
-              // We add on stack (2) array, (3) dupped array, (4) index, (5) value.
-              localStackAllocator.allocateLocalStack(4);
-              List<CfInstruction> cfInstructions = generateLoadArguments(callSite, context);
-              generateInvokeToDesugaredMethod(
-                  methodProcessingContext, cfInstructions, theContext, eventConsumer);
-              return cfInstructions;
-            })
-        .build();
+    if (callSite.methodName.isIdenticalTo(enumSwitchMethod.getName())
+        && isEnumSwitchProto(callSite.methodProto)
+        && methodHandleIsInvokeStaticTo(callSite.bootstrapMethod, enumSwitchMethod)) {
+      DexType enumType = callSite.methodProto.getParameter(0);
+      return DesugarDescription.builder()
+          .setDesugarRewrite(
+              (position,
+                  freshLocalProvider,
+                  localStackAllocator,
+                  desugaringInfo,
+                  eventConsumer,
+                  theContext,
+                  methodProcessingContext,
+                  desugaringCollection,
+                  dexItemFactory) -> {
+                // We add on stack (2) array, (3) dupped array, (4) index, (5) value.
+                localStackAllocator.allocateLocalStack(4);
+                List<CfInstruction> cfInstructions =
+                    generateEnumSwitchLoadArguments(callSite, context, enumType);
+                generateInvokeToDesugaredMethod(
+                    methodProcessingContext, cfInstructions, theContext, eventConsumer);
+                return cfInstructions;
+              })
+          .build();
+    }
+    return DesugarDescription.nothing();
+  }
+
+  private boolean isEnumSwitchProto(DexProto methodProto) {
+    return methodProto.getReturnType().isIdenticalTo(factory.intType)
+        && methodProto.getArity() == 2
+        && methodProto.getParameter(1).isIdenticalTo(factory.intType);
   }
 
   private void generateInvokeToDesugaredMethod(
@@ -189,7 +226,7 @@
                 builder ->
                     builder
                         .disableAndroidApiLevelCheck()
-                        .setProto(typeSwitchHelperProto)
+                        .setProto(switchHelperProto)
                         .setAccessFlags(MethodAccessFlags.createPublicStaticSynthetic())
                         .setCode(
                             methodSig -> {
@@ -206,7 +243,50 @@
     cfInstructions.add(new CfInvoke(Opcodes.INVOKESTATIC, method.getReference(), false));
   }
 
-  private List<CfInstruction> generateLoadArguments(DexCallSite callSite, ProgramMethod context) {
+  private List<CfInstruction> generateEnumSwitchLoadArguments(
+      DexCallSite callSite, ProgramMethod context, DexType enumType) {
+    return generateSwitchLoadArguments(
+        callSite,
+        (bootstrapArg, cfInstructions) -> {
+          if (bootstrapArg.isDexValueType()) {
+            cfInstructions.add(new CfConstClass(bootstrapArg.asDexValueType().getValue()));
+          } else if (bootstrapArg.isDexValueString()) {
+            DexField enumField =
+                getEnumField(bootstrapArg.asDexValueString().getValue(), enumType, context);
+            cfInstructions.add(new CfStaticFieldRead(enumField));
+          } else {
+            throw new CompilationError(
+                "Invalid bootstrap arg for enum switch " + bootstrapArg, context.getOrigin());
+          }
+        });
+  }
+
+  private List<CfInstruction> generateTypeSwitchLoadArguments(
+      DexCallSite callSite, ProgramMethod context) {
+    return generateSwitchLoadArguments(
+        callSite,
+        (bootstrapArg, cfInstructions) -> {
+          if (bootstrapArg.isDexValueType()) {
+            cfInstructions.add(new CfConstClass(bootstrapArg.asDexValueType().getValue()));
+          } else if (bootstrapArg.isDexValueInt()) {
+            cfInstructions.add(
+                new CfConstNumber(bootstrapArg.asDexValueInt().getValue(), ValueType.INT));
+            cfInstructions.add(
+                new CfInvoke(Opcodes.INVOKESTATIC, factory.integerMembers.valueOf, false));
+          } else if (bootstrapArg.isDexValueString()) {
+            cfInstructions.add(new CfConstString(bootstrapArg.asDexValueString().getValue()));
+          } else if (bootstrapArg.isDexValueConstDynamic()) {
+            DexField enumField = extractEnumField(bootstrapArg.asDexValueConstDynamic(), context);
+            cfInstructions.add(new CfStaticFieldRead(enumField));
+          } else {
+            throw new CompilationError(
+                "Invalid bootstrap arg for type switch " + bootstrapArg, context.getOrigin());
+          }
+        });
+  }
+
+  private List<CfInstruction> generateSwitchLoadArguments(
+      DexCallSite callSite, BiConsumer<DexValue, List<CfInstruction>> adder) {
     // We need to call the method with the bootstrap args as parameters.
     // We need to convert the bootstrap args into a list of cf instructions.
     // The object and the int are already pushed on stack, we simply need to push the extra array.
@@ -217,20 +297,7 @@
       DexValue bootstrapArg = callSite.bootstrapArgs.get(i);
       cfInstructions.add(new CfStackInstruction(Opcode.Dup));
       cfInstructions.add(new CfConstNumber(i, ValueType.INT));
-      if (bootstrapArg.isDexValueType()) {
-        cfInstructions.add(new CfConstClass(bootstrapArg.asDexValueType().getValue()));
-      } else if (bootstrapArg.isDexValueInt()) {
-        cfInstructions.add(
-            new CfConstNumber(bootstrapArg.asDexValueInt().getValue(), ValueType.INT));
-        cfInstructions.add(
-            new CfInvoke(Opcodes.INVOKESTATIC, factory.integerMembers.valueOf, false));
-      } else if (bootstrapArg.isDexValueString()) {
-        cfInstructions.add(new CfConstString(bootstrapArg.asDexValueString().getValue()));
-      } else {
-        assert bootstrapArg.isDexValueConstDynamic();
-        DexField enumField = extractEnumField(bootstrapArg.asDexValueConstDynamic(), context);
-        cfInstructions.add(new CfStaticFieldRead(enumField));
-      }
+      adder.accept(bootstrapArg, cfInstructions);
       cfInstructions.add(new CfArrayStore(MemberType.OBJECT));
     }
     return cfInstructions;
@@ -280,6 +347,10 @@
     DexString className = dexValueClassName.asDexValueString().getValue();
     DexType enumType =
         factory.createType(DescriptorUtils.javaTypeToDescriptor(className.toString()));
+    return getEnumField(fieldName, enumType, context);
+  }
+
+  private DexField getEnumField(DexString fieldName, DexType enumType, ProgramMethod context) {
     DexClass enumClass = appView.definitionFor(enumType);
     if (enumClass == null) {
       throw throwEnumFieldConstantDynamic("Missing enum class " + enumType, context);
diff --git a/src/test/examplesJava21/switchpatternmatching/EnumSwitchUsingEnumSwitchBootstrapMethod.java b/src/test/examplesJava21/switchpatternmatching/EnumSwitchUsingEnumSwitchBootstrapMethod.java
new file mode 100644
index 0000000..32fad72
--- /dev/null
+++ b/src/test/examplesJava21/switchpatternmatching/EnumSwitchUsingEnumSwitchBootstrapMethod.java
@@ -0,0 +1,140 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package switchpatternmatching;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assume.assumeTrue;
+
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.TestRuntime.CfVm;
+import com.android.tools.r8.ToolHelper;
+import com.android.tools.r8.utils.StringUtils;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import org.junit.Assume;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class EnumSwitchUsingEnumSwitchBootstrapMethod extends TestBase {
+
+  @Parameter public TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  public static String EXPECTED_OUTPUT =
+      StringUtils.lines(
+          "null",
+          "Spades or Piques",
+          "Hearts or C\u0153ur",
+          "Diamonds or Carreaux",
+          "Clubs or Trefles",
+          "Trumps or Atouts",
+          "The Fool or L'Excuse");
+
+  @Test
+  public void testJvm() throws Exception {
+    assumeTrue(parameters.isCfRuntime());
+    CodeInspector inspector = new CodeInspector(ToolHelper.getClassFileForTestClass(Main.class));
+    // javac generated an invokedynamic using bootstrap method
+    // java.lang.runtime.SwitchBootstraps.enumSwitch.
+    assertEquals(
+        1,
+        inspector
+            .clazz(Main.class)
+            .uniqueMethodWithOriginalName("enumSwitch")
+            .streamInstructions()
+            .filter(InstructionSubject::isInvokeDynamic)
+            .map(
+                instruction ->
+                    instruction
+                        .asCfInstruction()
+                        .getInstruction()
+                        .asInvokeDynamic()
+                        .getCallSite()
+                        .getBootstrapMethod()
+                        .member
+                        .asDexMethod())
+            .filter(
+                method ->
+                    method
+                        .getHolderType()
+                        .toString()
+                        .contains("java.lang.runtime.SwitchBootstraps"))
+            .filter(method -> method.toString().contains("enumSwitch"))
+            .count());
+
+    parameters.assumeJvmTestParameters();
+    testForJvm(parameters)
+        .addInnerClassesAndStrippedOuter(getClass())
+        .run(parameters.getRuntime(), Main.class)
+        .applyIf(
+            parameters.getCfRuntime().isNewerThanOrEqual(CfVm.JDK21),
+            r -> r.assertSuccessWithOutput(EXPECTED_OUTPUT),
+            r -> r.assertFailureWithErrorThatThrows(UnsupportedClassVersionError.class));
+  }
+
+  @Test
+  public void testD8() throws Exception {
+    parameters.assumeDexRuntime();
+    testForD8()
+        .addInnerClassesAndStrippedOuter(getClass())
+        .setMinApi(parameters)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutput(EXPECTED_OUTPUT);
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    Assume.assumeTrue("For Cf we should compile with Jdk 21 library", parameters.isDexRuntime());
+    testForR8(parameters.getBackend())
+        .addInnerClassesAndStrippedOuter(getClass())
+        .setMinApi(parameters)
+        .addKeepMainRule(Main.class)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutput(EXPECTED_OUTPUT);
+  }
+
+  public enum Tarot {
+    SPADE,
+    HEART,
+    DIAMOND,
+    CLUB,
+    TRUMP,
+    EXCUSE
+  }
+
+  public static class Main {
+
+    public static void main(String[] args) {
+      enumSwitch(null);
+      enumSwitch(Tarot.SPADE);
+      enumSwitch(Tarot.HEART);
+      enumSwitch(Tarot.DIAMOND);
+      enumSwitch(Tarot.CLUB);
+      enumSwitch(Tarot.TRUMP);
+      enumSwitch(Tarot.EXCUSE);
+    }
+
+    static void enumSwitch(Tarot t1) {
+      switch (t1) {
+        case null -> System.out.println("null");
+        case SPADE -> System.out.println("Spades or Piques");
+        case HEART -> System.out.println("Hearts or C\u0153ur");
+        case Tarot t when t == Tarot.DIAMOND -> System.out.println("Diamonds or Carreaux");
+        case Tarot t when t == Tarot.CLUB -> System.out.println("Clubs or Trefles");
+        case Tarot t when t == Tarot.TRUMP -> System.out.println("Trumps or Atouts");
+        case Tarot t -> System.out.println("The Fool or L'Excuse");
+      }
+    }
+  }
+}