Reapply "Fix TypeSwitch for enums"

This reverts commit 700298b73999005a05be4c3190a4d8aaa6690b9e.

Bug: b/336510513
Change-Id: I00b04b80e36750be6fc6404ae62d80ba05452d6a
diff --git a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
index e3eaa24..77e3369 100644
--- a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
+++ b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
@@ -261,6 +261,8 @@
   public final DexString classLoaderDescriptor = createString("Ljava/lang/ClassLoader;");
   public final DexString autoCloseableDescriptor = createString("Ljava/lang/AutoCloseable;");
   public final DexString classArrayDescriptor = createString("[Ljava/lang/Class;");
+  public final DexString classDescDescriptor = createString("Ljava/lang/constant/ClassDesc;");
+  public final DexString enumDescDescriptor = createString("Ljava/lang/Enum$EnumDesc;");
   public final DexString constructorDescriptor = createString("Ljava/lang/reflect/Constructor;");
   public final DexString fieldDescriptor = createString("Ljava/lang/reflect/Field;");
   public final DexString methodDescriptor = createString("Ljava/lang/reflect/Method;");
@@ -457,6 +459,9 @@
   public final DexType stringBuilderType = createStaticallyKnownType(stringBuilderDescriptor);
   public final DexType stringBufferType = createStaticallyKnownType(stringBufferDescriptor);
 
+  public final DexType classDescType = createStaticallyKnownType(classDescDescriptor);
+  public final DexType enumDescType = createStaticallyKnownType(enumDescDescriptor);
+
   public final DexType javaLangAnnotationRetentionPolicyType =
       createStaticallyKnownType("Ljava/lang/annotation/RetentionPolicy;");
   public final DexType javaLangReflectArrayType =
@@ -817,6 +822,17 @@
   public final DexType callSiteType = createStaticallyKnownType("Ljava/lang/invoke/CallSite;");
   public final DexType lookupType =
       createStaticallyKnownType("Ljava/lang/invoke/MethodHandles$Lookup;");
+  public final DexMethod constantDynamicBootstrapMethod =
+      createMethod(
+          constantBootstrapsType,
+          createProto(
+              objectType,
+              methodHandlesLookupType,
+              stringType,
+              classType,
+              methodHandleType,
+              objectArrayType),
+          invokeMethodName);
   public final DexType objectMethodsType =
       createStaticallyKnownType("Ljava/lang/runtime/ObjectMethods;");
   public final DexType typeDescriptorType =
diff --git a/src/main/java/com/android/tools/r8/graph/DexValue.java b/src/main/java/com/android/tools/r8/graph/DexValue.java
index f4f7a60..fe532fc 100644
--- a/src/main/java/com/android/tools/r8/graph/DexValue.java
+++ b/src/main/java/com/android/tools/r8/graph/DexValue.java
@@ -2118,6 +2118,10 @@
       return DexValueKind.CONST_DYNAMIC;
     }
 
+    public ConstantDynamicReference getValue() {
+      return value;
+    }
+
     private CompilationError throwCannotConvertToDex() {
       throw new CompilationError("DexValueConstDynamic should be desugared");
     }
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java
index ce86ab6..e6e99e7 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java
@@ -10,20 +10,27 @@
 import com.android.tools.r8.cf.code.CfConstString;
 import com.android.tools.r8.cf.code.CfInstruction;
 import com.android.tools.r8.cf.code.CfInvoke;
+import com.android.tools.r8.cf.code.CfNew;
 import com.android.tools.r8.cf.code.CfNewArray;
 import com.android.tools.r8.cf.code.CfStackInstruction;
 import com.android.tools.r8.cf.code.CfStackInstruction.Opcode;
+import com.android.tools.r8.cf.code.CfStaticFieldRead;
 import com.android.tools.r8.contexts.CompilationContext.MethodProcessingContext;
-import com.android.tools.r8.errors.Unreachable;
+import com.android.tools.r8.errors.CompilationError;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.CfCode;
 import com.android.tools.r8.graph.DexCallSite;
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexEncodedField;
+import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexMethodHandle;
 import com.android.tools.r8.graph.DexProto;
 import com.android.tools.r8.graph.DexString;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.DexValue;
+import com.android.tools.r8.graph.DexValue.DexValueConstDynamic;
 import com.android.tools.r8.graph.MethodAccessFlags;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.ir.code.MemberType;
@@ -31,6 +38,9 @@
 import com.android.tools.r8.ir.desugar.CfInstructionDesugaring;
 import com.android.tools.r8.ir.desugar.CfInstructionDesugaringEventConsumer;
 import com.android.tools.r8.ir.desugar.DesugarDescription;
+import com.android.tools.r8.ir.desugar.constantdynamic.ConstantDynamicReference;
+import com.android.tools.r8.utils.DescriptorUtils;
+import com.google.common.collect.ImmutableList;
 import java.util.ArrayList;
 import java.util.List;
 import org.objectweb.asm.Opcodes;
@@ -39,17 +49,19 @@
 
   private final AppView<?> appView;
 
-  private final DexString typeSwitch;
   private final DexMethod typeSwitchMethod;
   private final DexProto typeSwitchProto;
   private final DexProto typeSwitchHelperProto;
+  private final DexMethod enumDescMethod;
+  private final DexMethod classDescMethod;
+  private final DexType matchException;
+  private final DexMethod matchExceptionInit;
+  private final DexItemFactory factory;
 
   public TypeSwitchDesugaring(AppView<?> appView) {
     this.appView = appView;
-    DexItemFactory factory = appView.dexItemFactory();
-    typeSwitchProto = factory.createProto(factory.intType, factory.objectType, factory.intType);
+    this.factory = appView.dexItemFactory();
     DexType switchBootstrap = factory.createType("Ljava/lang/runtime/SwitchBootstraps;");
-    typeSwitch = factory.createString("typeSwitch");
     typeSwitchMethod =
         factory.createMethod(
             switchBootstrap,
@@ -59,22 +71,85 @@
                 factory.stringType,
                 factory.methodTypeType,
                 factory.objectArrayType),
-            typeSwitch);
+            factory.createString("typeSwitch"));
+    typeSwitchProto = factory.createProto(factory.intType, factory.objectType, factory.intType);
     typeSwitchHelperProto =
         factory.createProto(
             factory.intType, factory.objectType, factory.intType, factory.objectArrayType);
+    enumDescMethod =
+        factory.createMethod(
+            factory.enumDescType,
+            factory.createProto(factory.enumDescType, factory.classDescType, factory.stringType),
+            "of");
+    classDescMethod =
+        factory.createMethod(
+            factory.classDescType,
+            factory.createProto(factory.classDescType, factory.stringType),
+            "of");
+    matchException = factory.createType("Ljava/lang/MatchException;");
+    matchExceptionInit =
+        factory.createInstanceInitializer(
+            matchException, factory.stringType, factory.throwableType);
+  }
+
+  private boolean methodHandleIsInvokeStaticTo(DexValue dexValue, DexMethod method) {
+    if (!dexValue.isDexValueMethodHandle()) {
+      return false;
+    }
+    return methodHandleIsInvokeStaticTo(dexValue.asDexValueMethodHandle().getValue(), method);
+  }
+
+  private boolean methodHandleIsInvokeStaticTo(DexMethodHandle methodHandle, DexMethod method) {
+    return methodHandle.type.isInvokeStatic() && methodHandle.asMethod().isIdenticalTo(method);
   }
 
   @Override
   public DesugarDescription compute(CfInstruction instruction, ProgramMethod context) {
     if (!instruction.isInvokeDynamic()) {
+      // We need to replace the new MatchException with RuntimeException.
+      if (instruction.isNew() && instruction.asNew().getType().isIdenticalTo(matchException)) {
+        return DesugarDescription.builder()
+            .setDesugarRewrite(
+                (position,
+                    freshLocalProvider,
+                    localStackAllocator,
+                    desugaringInfo,
+                    eventConsumer,
+                    theContext,
+                    methodProcessingContext,
+                    desugaringCollection,
+                    dexItemFactory) -> ImmutableList.of(new CfNew(factory.runtimeExceptionType)))
+            .build();
+      }
+      if (instruction.isInvokeSpecial()
+          && instruction.asInvoke().getMethod().isIdenticalTo(matchExceptionInit)) {
+        return DesugarDescription.builder()
+            .setDesugarRewrite(
+                (position,
+                    freshLocalProvider,
+                    localStackAllocator,
+                    desugaringInfo,
+                    eventConsumer,
+                    theContext,
+                    methodProcessingContext,
+                    desugaringCollection,
+                    dexItemFactory) ->
+                    ImmutableList.of(
+                        new CfInvoke(
+                            Opcodes.INVOKESPECIAL,
+                            factory.createInstanceInitializer(
+                                factory.runtimeExceptionType,
+                                factory.stringType,
+                                factory.throwableType),
+                            false)))
+            .build();
+      }
       return DesugarDescription.nothing();
     }
     DexCallSite callSite = instruction.asInvokeDynamic().getCallSite();
-    if (!(callSite.methodName.isIdenticalTo(typeSwitch)
+    if (!(callSite.methodName.isIdenticalTo(typeSwitchMethod.getName())
         && callSite.methodProto.isIdenticalTo(typeSwitchProto)
-        && callSite.bootstrapMethod.member.isDexMethod()
-        && callSite.bootstrapMethod.member.asDexMethod().isIdenticalTo(typeSwitchMethod))) {
+        && methodHandleIsInvokeStaticTo(callSite.bootstrapMethod, typeSwitchMethod))) {
       return DesugarDescription.nothing();
     }
     // Call the desugared method.
@@ -91,7 +166,7 @@
                 dexItemFactory) -> {
               // We add on stack (2) array, (3) dupped array, (4) index, (5) value.
               localStackAllocator.allocateLocalStack(4);
-              List<CfInstruction> cfInstructions = generateLoadArguments(callSite);
+              List<CfInstruction> cfInstructions = generateLoadArguments(callSite, context);
               generateInvokeToDesugaredMethod(
                   methodProcessingContext, cfInstructions, theContext, eventConsumer);
               return cfInstructions;
@@ -120,14 +195,10 @@
                             methodSig -> {
                               CfCode code =
                                   TypeSwitchMethods.TypeSwitchMethods_typeSwitch(
-                                      appView.dexItemFactory(), methodSig);
+                                      factory, methodSig);
                               if (appView.options().hasMappingFileSupport()) {
                                 return code.getCodeAsInlining(
-                                    methodSig,
-                                    true,
-                                    context.getReference(),
-                                    false,
-                                    appView.dexItemFactory());
+                                    methodSig, true, context.getReference(), false, factory);
                               }
                               return code;
                             }));
@@ -135,13 +206,13 @@
     cfInstructions.add(new CfInvoke(Opcodes.INVOKESTATIC, method.getReference(), false));
   }
 
-  private List<CfInstruction> generateLoadArguments(DexCallSite callSite) {
+  private List<CfInstruction> generateLoadArguments(DexCallSite callSite, ProgramMethod context) {
     // We need to call the method with the bootstrap args as parameters.
     // We need to convert the bootstrap args into a list of cf instructions.
     // The object and the int are already pushed on stack, we simply need to push the extra array.
     List<CfInstruction> cfInstructions = new ArrayList<>();
     cfInstructions.add(new CfConstNumber(callSite.bootstrapArgs.size(), ValueType.INT));
-    cfInstructions.add(new CfNewArray(appView.dexItemFactory().objectArrayType));
+    cfInstructions.add(new CfNewArray(factory.objectArrayType));
     for (int i = 0; i < callSite.bootstrapArgs.size(); i++) {
       DexValue bootstrapArg = callSite.bootstrapArgs.get(i);
       cfInstructions.add(new CfStackInstruction(Opcode.Dup));
@@ -152,16 +223,72 @@
         cfInstructions.add(
             new CfConstNumber(bootstrapArg.asDexValueInt().getValue(), ValueType.INT));
         cfInstructions.add(
-            new CfInvoke(
-                Opcodes.INVOKESTATIC, appView.dexItemFactory().integerMembers.valueOf, false));
+            new CfInvoke(Opcodes.INVOKESTATIC, factory.integerMembers.valueOf, false));
       } else if (bootstrapArg.isDexValueString()) {
         cfInstructions.add(new CfConstString(bootstrapArg.asDexValueString().getValue()));
       } else {
         assert bootstrapArg.isDexValueConstDynamic();
-        throw new Unreachable("TODO(b/336510513): Enum descriptor should be implemented");
+        DexField enumField = extractEnumField(bootstrapArg.asDexValueConstDynamic(), context);
+        cfInstructions.add(new CfStaticFieldRead(enumField));
       }
       cfInstructions.add(new CfArrayStore(MemberType.OBJECT));
     }
     return cfInstructions;
   }
+
+  private CompilationError throwEnumFieldConstantDynamic(String msg, ProgramMethod context) {
+    throw new CompilationError(
+        "Unexpected ConstantDynamic in TypeSwitch: " + msg, context.getOrigin());
+  }
+
+  private DexField extractEnumField(
+      DexValueConstDynamic dexValueConstDynamic, ProgramMethod context) {
+    ConstantDynamicReference enumCstDynamic = dexValueConstDynamic.getValue();
+    DexMethod bootstrapMethod = factory.constantDynamicBootstrapMethod;
+    if (!(enumCstDynamic.getType().isIdenticalTo(factory.enumDescType)
+        && enumCstDynamic.getName().isIdenticalTo(bootstrapMethod.getName())
+        && enumCstDynamic.getBootstrapMethod().asMethod().isIdenticalTo(bootstrapMethod)
+        && enumCstDynamic.getBootstrapMethodArguments().size() == 3
+        && methodHandleIsInvokeStaticTo(
+            enumCstDynamic.getBootstrapMethodArguments().get(0), enumDescMethod))) {
+      throw throwEnumFieldConstantDynamic("Invalid EnumDesc", context);
+    }
+    DexValue dexValueFieldName = enumCstDynamic.getBootstrapMethodArguments().get(2);
+    if (!dexValueFieldName.isDexValueString()) {
+      throw throwEnumFieldConstantDynamic("Field name " + dexValueFieldName, context);
+    }
+    DexString fieldName = dexValueFieldName.asDexValueString().getValue();
+
+    DexValue dexValueClassCstDynamic = enumCstDynamic.getBootstrapMethodArguments().get(1);
+    if (!dexValueClassCstDynamic.isDexValueConstDynamic()) {
+      throw throwEnumFieldConstantDynamic("Enum class " + dexValueClassCstDynamic, context);
+    }
+    ConstantDynamicReference classCstDynamic =
+        dexValueClassCstDynamic.asDexValueConstDynamic().getValue();
+    if (!(classCstDynamic.getType().isIdenticalTo(factory.classDescType)
+        && classCstDynamic.getName().isIdenticalTo(bootstrapMethod.getName())
+        && classCstDynamic.getBootstrapMethod().asMethod().isIdenticalTo(bootstrapMethod)
+        && classCstDynamic.getBootstrapMethodArguments().size() == 2
+        && methodHandleIsInvokeStaticTo(
+            classCstDynamic.getBootstrapMethodArguments().get(0), classDescMethod))) {
+      throw throwEnumFieldConstantDynamic("Class descriptor " + classCstDynamic, context);
+    }
+    DexValue dexValueClassName = classCstDynamic.getBootstrapMethodArguments().get(1);
+    if (!dexValueClassName.isDexValueString()) {
+      throw throwEnumFieldConstantDynamic("Class name " + dexValueClassName, context);
+    }
+    DexString className = dexValueClassName.asDexValueString().getValue();
+    DexType enumType =
+        factory.createType(DescriptorUtils.javaTypeToDescriptor(className.toString()));
+    DexClass enumClass = appView.definitionFor(enumType);
+    if (enumClass == null) {
+      throw throwEnumFieldConstantDynamic("Missing enum class " + enumType, context);
+    }
+    DexEncodedField dexEncodedField = enumClass.lookupUniqueStaticFieldWithName(fieldName);
+    if (dexEncodedField == null) {
+      throw throwEnumFieldConstantDynamic(
+          "Missing enum field " + fieldName + " in " + enumType, context);
+    }
+    return dexEncodedField.getReference();
+  }
 }
diff --git a/src/test/examplesJava21/switchpatternmatching/EnumSwitchTest.java b/src/test/examplesJava21/switchpatternmatching/EnumSwitchTest.java
index 90001a5..e57e31c 100644
--- a/src/test/examplesJava21/switchpatternmatching/EnumSwitchTest.java
+++ b/src/test/examplesJava21/switchpatternmatching/EnumSwitchTest.java
@@ -4,11 +4,10 @@
 package switchpatternmatching;
 
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertThrows;
 import static org.junit.Assume.assumeTrue;
 
-import com.android.tools.r8.CompilationFailedException;
 import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestBuilder;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.TestRuntime.CfVm;
@@ -16,11 +15,13 @@
 import com.android.tools.r8.utils.StringUtils;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
 import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import org.junit.Assume;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 import org.junit.runners.Parameterized.Parameter;
 import org.junit.runners.Parameterized.Parameters;
+import switchpatternmatching.StringSwitchTest.Main;
 
 @RunWith(Parameterized.class)
 public class EnumSwitchTest extends TestBase {
@@ -32,7 +33,8 @@
     return getTestParameters().withAllRuntimesAndApiLevels().build();
   }
 
-  public static String EXPECTED_OUTPUT = StringUtils.lines("null", "E1", "E2", "E3", "E4", "a C");
+  public static String EXPECTED_OUTPUT =
+      StringUtils.lines("null", "E1", "E2", "E3", "E4", "a C", "class %s");
 
   @Test
   public void testJvm() throws Exception {
@@ -68,36 +70,72 @@
 
     parameters.assumeJvmTestParameters();
     testForJvm(parameters)
-        .addInnerClassesAndStrippedOuter(getClass())
+        .apply(this::addModifiedProgramClasses)
         .run(parameters.getRuntime(), Main.class)
         .applyIf(
             parameters.getCfRuntime().isNewerThanOrEqual(CfVm.JDK21),
-            r -> r.assertSuccessWithOutput(EXPECTED_OUTPUT),
+            r ->
+                r.assertSuccessWithOutput(
+                    String.format(EXPECTED_OUTPUT, "java.lang.MatchException")),
             r -> r.assertFailureWithErrorThatThrows(UnsupportedClassVersionError.class));
   }
 
+  private <T extends TestBuilder<?, T>> void addModifiedProgramClasses(
+      TestBuilder<?, T> testBuilder) throws Exception {
+    testBuilder
+        .addStrippedOuter(getClass())
+        .addProgramClasses(FakeI.class, E.class, C.class)
+        .addProgramClassFileData(
+            transformer(I.class)
+                .setPermittedSubclasses(I.class, E.class, C.class, D.class)
+                .transform())
+        .addProgramClassFileData(transformer(D.class).setImplements(I.class).transform())
+        .addProgramClassFileData(
+            transformer(Main.class)
+                .transformTypeInsnInMethod(
+                    "getD",
+                    (opcode, type, visitor) ->
+                        visitor.visitTypeInsn(opcode, "switchpatternmatching/EnumSwitchTest$D"))
+                .transformMethodInsnInMethod(
+                    "getD",
+                    (opcode, owner, name, descriptor, isInterface, visitor) -> {
+                      assert name.equals("<init>");
+                      visitor.visitMethodInsn(
+                          opcode,
+                          "switchpatternmatching/EnumSwitchTest$D",
+                          name,
+                          descriptor,
+                          isInterface);
+                    })
+                .transform());
+  }
+
   @Test
   public void testD8() throws Exception {
     parameters.assumeDexRuntime();
-    assertThrows(
-        CompilationFailedException.class,
-        () -> testForD8().addInnerClasses(getClass()).setMinApi(parameters).compile());
+    testForD8()
+        .apply(this::addModifiedProgramClasses)
+        .setMinApi(parameters)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutput(String.format(EXPECTED_OUTPUT, "java.lang.RuntimeException"));
   }
 
   @Test
   public void testR8() throws Exception {
-    assertThrows(
-        CompilationFailedException.class,
-        () ->
-            testForR8(parameters.getBackend())
-                .addInnerClasses(getClass())
-                .setMinApi(parameters)
-                .addKeepMainRule(Main.class)
-                .compile());
+    Assume.assumeTrue("For Cf we should compile with Jdk 21 library", parameters.isDexRuntime());
+    testForR8(parameters.getBackend())
+        .apply(this::addModifiedProgramClasses)
+        .setMinApi(parameters)
+        .addKeepMainRule(Main.class)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutput(String.format(EXPECTED_OUTPUT, "java.lang.RuntimeException"));
   }
 
+  // D is added to the list of permitted subclasses to reproduce the MatchException.
   sealed interface I permits E, C {}
 
+  interface FakeI {}
+
   public enum E implements I {
     E1,
     E2,
@@ -105,6 +143,9 @@
     E4
   }
 
+  // Replaced with I.
+  static final class D implements FakeI {}
+
   static final class C implements I {}
 
   static class Main {
@@ -140,6 +181,16 @@
       enumSwitch(E.E3);
       enumSwitch(E.E4);
       enumSwitch(new C());
+      try {
+        enumSwitch(getD());
+      } catch (Throwable t) {
+        System.out.println(t.getClass());
+      }
+    }
+
+    public static I getD() {
+      // Replaced by new D();
+      return new C();
     }
   }
 }
diff --git a/src/test/testbase/java/com/android/tools/r8/TestBuilder.java b/src/test/testbase/java/com/android/tools/r8/TestBuilder.java
index 02d0e06..32a40c0 100644
--- a/src/test/testbase/java/com/android/tools/r8/TestBuilder.java
+++ b/src/test/testbase/java/com/android/tools/r8/TestBuilder.java
@@ -147,16 +147,19 @@
     return addProgramFiles(getFilesForInnerClasses(classes));
   }
 
-  public T addInnerClassesAndStrippedOuter(Class<?> clazz) throws IOException {
+  public T addStrippedOuter(Class<?> clazz) throws IOException {
     return addProgramClassFileData(
-            TestBase.transformer(clazz)
-                .removeFields(FieldPredicate.all())
-                .removeMethods(MethodPredicate.all())
-                .removeAllAnnotations()
-                .setSuper(descriptor(Object.class))
-                .setImplements()
-                .transform())
-        .addInnerClasses(ImmutableList.of(clazz));
+        TestBase.transformer(clazz)
+            .removeFields(FieldPredicate.all())
+            .removeMethods(MethodPredicate.all())
+            .removeAllAnnotations()
+            .setSuper(descriptor(Object.class))
+            .setImplements()
+            .transform());
+  }
+
+  public T addInnerClassesAndStrippedOuter(Class<?> clazz) throws IOException {
+    return addStrippedOuter(clazz).addInnerClasses(ImmutableList.of(clazz));
   }
 
   public abstract T addLibraryFiles(Collection<Path> files);