Fix type switch support in R8 when compiling CF to CF

Bug: b/336510513
Change-Id: Ie3aa7af2094288f3c7602a3054eb7df2dc8cc5b7
diff --git a/src/main/java/com/android/tools/r8/cf/code/CfConstDynamic.java b/src/main/java/com/android/tools/r8/cf/code/CfConstDynamic.java
index 9936ef6..ed3a275 100644
--- a/src/main/java/com/android/tools/r8/cf/code/CfConstDynamic.java
+++ b/src/main/java/com/android/tools/r8/cf/code/CfConstDynamic.java
@@ -138,7 +138,9 @@
             reference.getBootstrapMethodArguments(), NOT_ARGUMENT_TO_LAMBDA_METAFACTORY, context);
     Object[] bsmArgs = new Object[rewrittenArguments.size()];
     for (int i = 0; i < rewrittenArguments.size(); i++) {
-      bsmArgs[i] = CfInvokeDynamic.decodeBootstrapArgument(rewrittenArguments.get(i), namingLens);
+      bsmArgs[i] =
+          CfInvokeDynamic.decodeBootstrapArgument(
+              rewrittenArguments.get(i), namingLens, dexItemFactory);
     }
     ConstantDynamic constantDynamic =
         new ConstantDynamic(
diff --git a/src/main/java/com/android/tools/r8/cf/code/CfInvokeDynamic.java b/src/main/java/com/android/tools/r8/cf/code/CfInvokeDynamic.java
index a36f61b..8e39bb9 100644
--- a/src/main/java/com/android/tools/r8/cf/code/CfInvokeDynamic.java
+++ b/src/main/java/com/android/tools/r8/cf/code/CfInvokeDynamic.java
@@ -23,6 +23,7 @@
 import com.android.tools.r8.ir.conversion.CfState;
 import com.android.tools.r8.ir.conversion.IRBuilder;
 import com.android.tools.r8.ir.conversion.LensCodeRewriterUtils;
+import com.android.tools.r8.ir.desugar.constantdynamic.ConstantDynamicReference;
 import com.android.tools.r8.naming.NamingLens;
 import com.android.tools.r8.optimize.interfaces.analysis.CfAnalysisConfig;
 import com.android.tools.r8.optimize.interfaces.analysis.CfFrameState;
@@ -31,6 +32,7 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.ListIterator;
+import org.objectweb.asm.ConstantDynamic;
 import org.objectweb.asm.Handle;
 import org.objectweb.asm.MethodVisitor;
 import org.objectweb.asm.Opcodes;
@@ -86,7 +88,7 @@
     List<DexValue> bootstrapArgs = rewrittenCallSite.bootstrapArgs;
     Object[] bsmArgs = new Object[bootstrapArgs.size()];
     for (int i = 0; i < bootstrapArgs.size(); i++) {
-      bsmArgs[i] = decodeBootstrapArgument(bootstrapArgs.get(i), namingLens);
+      bsmArgs[i] = decodeBootstrapArgument(bootstrapArgs.get(i), namingLens, dexItemFactory);
     }
     Handle bsmHandle = bootstrapMethod.toAsmHandle(namingLens);
     DexString methodName = namingLens.lookupMethodName(rewrittenCallSite, appView);
@@ -102,7 +104,8 @@
     return 5;
   }
 
-  public static Object decodeBootstrapArgument(DexValue value, NamingLens lens) {
+  public static Object decodeBootstrapArgument(
+      DexValue value, NamingLens lens, DexItemFactory factory) {
     switch (value.getValueKind()) {
       case DOUBLE:
         return value.asDexValueDouble().getValue();
@@ -121,6 +124,18 @@
         return innerValue == null ? null : innerValue.toString();
       case TYPE:
         return Type.getType(lens.lookupDescriptor(value.asDexValueType().value).toString());
+      case CONST_DYNAMIC:
+        ConstantDynamicReference ref = value.asDexValueConstDynamic().getValue();
+        List<DexValue> bootstrapArgs = ref.getBootstrapMethodArguments();
+        Object[] bsmArgs = new Object[bootstrapArgs.size()];
+        for (int i = 0; i < bootstrapArgs.size(); i++) {
+          bsmArgs[i] = CfInvokeDynamic.decodeBootstrapArgument(bootstrapArgs.get(i), lens, factory);
+        }
+        return new ConstantDynamic(
+            ref.getName().toString(),
+            lens.lookupType(ref.getType(), factory).toDescriptorString(),
+            ref.getBootstrapMethod().toAsmHandle(lens),
+            bsmArgs);
       default:
         throw new Unreachable(
             "Unsupported bootstrap argument of type " + value.getClass().getSimpleName());
diff --git a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
index 77e3369..a8233fd 100644
--- a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
+++ b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
@@ -819,6 +819,7 @@
       createStaticallyKnownType("Ljava/lang/invoke/LambdaMetafactory;");
   public final DexType constantBootstrapsType =
       createStaticallyKnownType("Ljava/lang/invoke/ConstantBootstraps;");
+  public final DexType switchBootstrapType = createType("Ljava/lang/runtime/SwitchBootstraps;");
   public final DexType callSiteType = createStaticallyKnownType("Ljava/lang/invoke/CallSite;");
   public final DexType lookupType =
       createStaticallyKnownType("Ljava/lang/invoke/MethodHandles$Lookup;");
@@ -833,6 +834,18 @@
               methodHandleType,
               objectArrayType),
           invokeMethodName);
+  public final DexProto switchBootstrapMethodProto =
+      createProto(
+          callSiteType, methodHandlesLookupType, stringType, methodTypeType, objectArrayType);
+  public final DexMethod typeSwitchMethod =
+      createMethod(switchBootstrapType, switchBootstrapMethodProto, createString("typeSwitch"));
+  public final DexMethod enumSwitchMethod =
+      createMethod(switchBootstrapType, switchBootstrapMethodProto, createString("enumSwitch"));
+  public final DexProto typeSwitchProto = createProto(intType, objectType, intType);
+  public final DexMethod enumDescMethod =
+      createMethod(enumDescType, createProto(enumDescType, classDescType, stringType), "of");
+  public final DexMethod classDescMethod =
+      createMethod(classDescType, createProto(classDescType, stringType), "of");
   public final DexType objectMethodsType =
       createStaticallyKnownType("Ljava/lang/runtime/ObjectMethods;");
   public final DexType typeDescriptorType =
diff --git a/src/main/java/com/android/tools/r8/graph/UseRegistry.java b/src/main/java/com/android/tools/r8/graph/UseRegistry.java
index 5d92dc1..9d623da 100644
--- a/src/main/java/com/android/tools/r8/graph/UseRegistry.java
+++ b/src/main/java/com/android/tools/r8/graph/UseRegistry.java
@@ -3,13 +3,17 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.graph;
 
+import static com.android.tools.r8.ir.desugar.typeswitch.TypeSwitchDesugaringHelper.isTypeSwitchCallSite;
+
 import com.android.tools.r8.dex.code.CfOrDexInstanceFieldRead;
 import com.android.tools.r8.dex.code.CfOrDexInstruction;
 import com.android.tools.r8.dex.code.CfOrDexStaticFieldRead;
+import com.android.tools.r8.errors.CompilationError;
 import com.android.tools.r8.graph.bytecodemetadata.BytecodeInstructionMetadata;
 import com.android.tools.r8.graph.lens.GraphLens;
 import com.android.tools.r8.ir.code.InvokeType;
 import com.android.tools.r8.ir.code.Position;
+import com.android.tools.r8.ir.desugar.typeswitch.TypeSwitchDesugaringHelper;
 import com.android.tools.r8.utils.TraversalContinuation;
 import java.util.ListIterator;
 
@@ -256,6 +260,16 @@
         case TYPE:
           registerTypeReference(arg.asDexValueType().value);
           break;
+        case CONST_DYNAMIC:
+          if (!isTypeSwitchCallSite(callSite, appView.dexItemFactory())) {
+            throw new CompilationError(
+                "Unsupported const dynamic in call site " + arg, getContext().getOrigin());
+          }
+          DexField dexField =
+              TypeSwitchDesugaringHelper.extractEnumField(
+                  arg.asDexValueConstDynamic(), getMethodContext(), appView);
+          registerStaticFieldRead(dexField);
+          break;
         default:
           assert arg.isDexValueInt()
               || arg.isDexValueLong()
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java
index 7103887..2dc50c6 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java
@@ -4,6 +4,11 @@
 
 package com.android.tools.r8.ir.desugar.typeswitch;
 
+import static com.android.tools.r8.ir.desugar.typeswitch.TypeSwitchDesugaringHelper.extractEnumField;
+import static com.android.tools.r8.ir.desugar.typeswitch.TypeSwitchDesugaringHelper.getEnumField;
+import static com.android.tools.r8.ir.desugar.typeswitch.TypeSwitchDesugaringHelper.isEnumSwitchCallSite;
+import static com.android.tools.r8.ir.desugar.typeswitch.TypeSwitchDesugaringHelper.isTypeSwitchCallSite;
+
 import com.android.tools.r8.cf.code.CfArrayStore;
 import com.android.tools.r8.cf.code.CfConstClass;
 import com.android.tools.r8.cf.code.CfConstNumber;
@@ -20,19 +25,14 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.CfCode;
 import com.android.tools.r8.graph.DexCallSite;
-import com.android.tools.r8.graph.DexClass;
-import com.android.tools.r8.graph.DexEncodedField;
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
-import com.android.tools.r8.graph.DexMethodHandle;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexProto;
-import com.android.tools.r8.graph.DexString;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.DexValue;
-import com.android.tools.r8.graph.DexValue.DexValueConstDynamic;
 import com.android.tools.r8.graph.MethodAccessFlags;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.ir.code.MemberType;
@@ -41,8 +41,6 @@
 import com.android.tools.r8.ir.desugar.CfInstructionDesugaringEventConsumer;
 import com.android.tools.r8.ir.desugar.DesugarDescription;
 import com.android.tools.r8.ir.desugar.LocalStackAllocator;
-import com.android.tools.r8.ir.desugar.constantdynamic.ConstantDynamicReference;
-import com.android.tools.r8.utils.DescriptorUtils;
 import com.google.common.collect.ImmutableList;
 import java.util.ArrayList;
 import java.util.Iterator;
@@ -54,12 +52,7 @@
 
   private final AppView<?> appView;
 
-  private final DexMethod typeSwitchMethod;
-  private final DexMethod enumSwitchMethod;
-  private final DexProto typeSwitchProto;
   private final DexProto switchHelperProto;
-  private final DexMethod enumDescMethod;
-  private final DexMethod classDescMethod;
   private final DexType matchException;
   private final DexMethod matchExceptionInit;
   private final DexItemFactory factory;
@@ -67,52 +60,15 @@
   public TypeSwitchDesugaring(AppView<?> appView) {
     this.appView = appView;
     this.factory = appView.dexItemFactory();
-    DexType switchBootstrap = factory.createType("Ljava/lang/runtime/SwitchBootstraps;");
-    DexProto methodHandleSwitchProto =
-        factory.createProto(
-            factory.callSiteType,
-            factory.methodHandlesLookupType,
-            factory.stringType,
-            factory.methodTypeType,
-            factory.objectArrayType);
-    typeSwitchMethod =
-        factory.createMethod(
-            switchBootstrap, methodHandleSwitchProto, factory.createString("typeSwitch"));
-    enumSwitchMethod =
-        factory.createMethod(
-            switchBootstrap, methodHandleSwitchProto, factory.createString("enumSwitch"));
-    this.typeSwitchProto =
-        factory.createProto(factory.intType, factory.objectType, factory.intType);
     switchHelperProto =
         factory.createProto(
             factory.intType, factory.objectType, factory.intType, factory.objectArrayType);
-    enumDescMethod =
-        factory.createMethod(
-            factory.enumDescType,
-            factory.createProto(factory.enumDescType, factory.classDescType, factory.stringType),
-            "of");
-    classDescMethod =
-        factory.createMethod(
-            factory.classDescType,
-            factory.createProto(factory.classDescType, factory.stringType),
-            "of");
     matchException = factory.createType("Ljava/lang/MatchException;");
     matchExceptionInit =
         factory.createInstanceInitializer(
             matchException, factory.stringType, factory.throwableType);
   }
 
-  private boolean methodHandleIsInvokeStaticTo(DexValue dexValue, DexMethod method) {
-    if (!dexValue.isDexValueMethodHandle()) {
-      return false;
-    }
-    return methodHandleIsInvokeStaticTo(dexValue.asDexValueMethodHandle().getValue(), method);
-  }
-
-  private boolean methodHandleIsInvokeStaticTo(DexMethodHandle methodHandle, DexMethod method) {
-    return methodHandle.type.isInvokeStatic() && methodHandle.asMethod().isIdenticalTo(method);
-  }
-
   @Override
   public DesugarDescription compute(CfInstruction instruction, ProgramMethod context) {
     if (!instruction.isInvokeDynamic()) {
@@ -157,9 +113,7 @@
       return DesugarDescription.nothing();
     }
     DexCallSite callSite = instruction.asInvokeDynamic().getCallSite();
-    if (callSite.methodName.isIdenticalTo(typeSwitchMethod.getName())
-        && callSite.methodProto.isIdenticalTo(typeSwitchProto)
-        && methodHandleIsInvokeStaticTo(callSite.bootstrapMethod, typeSwitchMethod)) {
+    if (isTypeSwitchCallSite(callSite, factory)) {
       return DesugarDescription.builder()
           .setDesugarRewrite(
               (position,
@@ -180,9 +134,7 @@
                           generateTypeSwitchLoadArguments(cfInstructions, callSite, context)))
           .build();
     }
-    if (callSite.methodName.isIdenticalTo(enumSwitchMethod.getName())
-        && isEnumSwitchProto(callSite.methodProto)
-        && methodHandleIsInvokeStaticTo(callSite.bootstrapMethod, enumSwitchMethod)) {
+    if (isEnumSwitchCallSite(callSite, factory)) {
       DexType enumType = callSite.methodProto.getParameter(0);
       return DesugarDescription.builder()
           .setDesugarRewrite(
@@ -233,12 +185,6 @@
     return cfInstructions;
   }
 
-  private boolean isEnumSwitchProto(DexProto methodProto) {
-    return methodProto.getReturnType().isIdenticalTo(factory.intType)
-        && methodProto.getArity() == 2
-        && methodProto.getParameter(1).isIdenticalTo(factory.intType);
-  }
-
   private void generateInvokeToDesugaredMethod(
       List<CfInstruction> cfInstructions,
       MethodProcessingContext methodProcessingContext,
@@ -271,12 +217,12 @@
     cfInstructions.add(new CfInvoke(Opcodes.INVOKESTATIC, method.getReference(), false));
   }
 
-  private List<CfInstruction> generateEnumSwitchLoadArguments(
+  private void generateEnumSwitchLoadArguments(
       List<CfInstruction> cfInstructions,
       DexCallSite callSite,
       ProgramMethod context,
       DexType enumType) {
-    return generateSwitchLoadArguments(
+    generateSwitchLoadArguments(
         cfInstructions,
         callSite,
         bootstrapArg -> {
@@ -284,7 +230,8 @@
             cfInstructions.add(new CfConstClass(bootstrapArg.asDexValueType().getValue()));
           } else if (bootstrapArg.isDexValueString()) {
             DexField enumField =
-                getEnumField(bootstrapArg.asDexValueString().getValue(), enumType, context);
+                getEnumField(
+                    bootstrapArg.asDexValueString().getValue(), enumType, context, appView);
             cfInstructions.add(new CfStaticFieldRead(enumField));
           } else {
             throw new CompilationError(
@@ -293,9 +240,9 @@
         });
   }
 
-  private List<CfInstruction> generateTypeSwitchLoadArguments(
+  private void generateTypeSwitchLoadArguments(
       List<CfInstruction> cfInstructions, DexCallSite callSite, ProgramMethod context) {
-    return generateSwitchLoadArguments(
+    generateSwitchLoadArguments(
         cfInstructions,
         callSite,
         bootstrapArg -> {
@@ -309,7 +256,8 @@
           } else if (bootstrapArg.isDexValueString()) {
             cfInstructions.add(new CfConstString(bootstrapArg.asDexValueString().getValue()));
           } else if (bootstrapArg.isDexValueConstDynamic()) {
-            DexField enumField = extractEnumField(bootstrapArg.asDexValueConstDynamic(), context);
+            DexField enumField =
+                extractEnumField(bootstrapArg.asDexValueConstDynamic(), context, appView);
             cfInstructions.add(new CfStaticFieldRead(enumField));
           } else {
             throw new CompilationError(
@@ -318,7 +266,7 @@
         });
   }
 
-  private List<CfInstruction> generateSwitchLoadArguments(
+  private void generateSwitchLoadArguments(
       List<CfInstruction> cfInstructions, DexCallSite callSite, Consumer<DexValue> adder) {
     // We need to call the method with the bootstrap args as parameters.
     // We need to convert the bootstrap args into a list of cf instructions.
@@ -332,66 +280,5 @@
       adder.accept(bootstrapArg);
       cfInstructions.add(new CfArrayStore(MemberType.OBJECT));
     }
-    return cfInstructions;
   }
-
-  private CompilationError throwEnumFieldConstantDynamic(String msg, ProgramMethod context) {
-    throw new CompilationError(
-        "Unexpected ConstantDynamic in TypeSwitch: " + msg, context.getOrigin());
-  }
-
-  private DexField extractEnumField(
-      DexValueConstDynamic dexValueConstDynamic, ProgramMethod context) {
-    ConstantDynamicReference enumCstDynamic = dexValueConstDynamic.getValue();
-    DexMethod bootstrapMethod = factory.constantDynamicBootstrapMethod;
-    if (!(enumCstDynamic.getType().isIdenticalTo(factory.enumDescType)
-        && enumCstDynamic.getName().isIdenticalTo(bootstrapMethod.getName())
-        && enumCstDynamic.getBootstrapMethod().asMethod().isIdenticalTo(bootstrapMethod)
-        && enumCstDynamic.getBootstrapMethodArguments().size() == 3
-        && methodHandleIsInvokeStaticTo(
-            enumCstDynamic.getBootstrapMethodArguments().get(0), enumDescMethod))) {
-      throw throwEnumFieldConstantDynamic("Invalid EnumDesc", context);
-    }
-    DexValue dexValueFieldName = enumCstDynamic.getBootstrapMethodArguments().get(2);
-    if (!dexValueFieldName.isDexValueString()) {
-      throw throwEnumFieldConstantDynamic("Field name " + dexValueFieldName, context);
-    }
-    DexString fieldName = dexValueFieldName.asDexValueString().getValue();
-
-    DexValue dexValueClassCstDynamic = enumCstDynamic.getBootstrapMethodArguments().get(1);
-    if (!dexValueClassCstDynamic.isDexValueConstDynamic()) {
-      throw throwEnumFieldConstantDynamic("Enum class " + dexValueClassCstDynamic, context);
-    }
-    ConstantDynamicReference classCstDynamic =
-        dexValueClassCstDynamic.asDexValueConstDynamic().getValue();
-    if (!(classCstDynamic.getType().isIdenticalTo(factory.classDescType)
-        && classCstDynamic.getName().isIdenticalTo(bootstrapMethod.getName())
-        && classCstDynamic.getBootstrapMethod().asMethod().isIdenticalTo(bootstrapMethod)
-        && classCstDynamic.getBootstrapMethodArguments().size() == 2
-        && methodHandleIsInvokeStaticTo(
-            classCstDynamic.getBootstrapMethodArguments().get(0), classDescMethod))) {
-      throw throwEnumFieldConstantDynamic("Class descriptor " + classCstDynamic, context);
-    }
-    DexValue dexValueClassName = classCstDynamic.getBootstrapMethodArguments().get(1);
-    if (!dexValueClassName.isDexValueString()) {
-      throw throwEnumFieldConstantDynamic("Class name " + dexValueClassName, context);
-    }
-    DexString className = dexValueClassName.asDexValueString().getValue();
-    DexType enumType =
-        factory.createType(DescriptorUtils.javaTypeToDescriptor(className.toString()));
-    return getEnumField(fieldName, enumType, context);
-  }
-
-  private DexField getEnumField(DexString fieldName, DexType enumType, ProgramMethod context) {
-    DexClass enumClass = appView.definitionFor(enumType);
-    if (enumClass == null) {
-      throw throwEnumFieldConstantDynamic("Missing enum class " + enumType, context);
-    }
-    DexEncodedField dexEncodedField = enumClass.lookupUniqueStaticFieldWithName(fieldName);
-    if (dexEncodedField == null) {
-      throw throwEnumFieldConstantDynamic(
-          "Missing enum field " + fieldName + " in " + enumType, context);
-    }
-    return dexEncodedField.getReference();
-  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaringHelper.java b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaringHelper.java
new file mode 100644
index 0000000..fbb5e8f
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaringHelper.java
@@ -0,0 +1,158 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.desugar.typeswitch;
+
+import com.android.tools.r8.errors.CompilationError;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexCallSite;
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexClassAndMethod;
+import com.android.tools.r8.graph.DexEncodedField;
+import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexMethodHandle;
+import com.android.tools.r8.graph.DexProto;
+import com.android.tools.r8.graph.DexString;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.DexValue;
+import com.android.tools.r8.graph.DexValue.DexValueConstDynamic;
+import com.android.tools.r8.ir.desugar.constantdynamic.ConstantDynamicReference;
+import com.android.tools.r8.utils.DescriptorUtils;
+
+/**
+ * Static helpers for recognizing {@code SwitchBootstraps} call sites and for resolving the enum
+ * fields referenced by their bootstrap arguments.
+ */
+public final class TypeSwitchDesugaringHelper {
+
+  private TypeSwitchDesugaringHelper() {}
+
+  /** Creates the error reported when a switch label cannot be resolved to an enum field. */
+  private static CompilationError enumFieldConstantDynamicError(
+      String msg, DexClassAndMethod context) {
+    return new CompilationError(
+        "Unexpected ConstantDynamic in TypeSwitch: " + msg, context.getOrigin());
+  }
+
+  /**
+   * Returns true if {@code callSite} is bootstrapped by an invokestatic handle to {@code
+   * SwitchBootstraps.typeSwitch} and has the {@code (Object, int)int} signature.
+   */
+  public static boolean isTypeSwitchCallSite(DexCallSite callSite, DexItemFactory factory) {
+    return callSite.methodName.isIdenticalTo(factory.typeSwitchMethod.getName())
+        && callSite.methodProto.isIdenticalTo(factory.typeSwitchProto)
+        && methodHandleIsInvokeStaticTo(callSite.bootstrapMethod, factory.typeSwitchMethod);
+  }
+
+  /**
+   * Returns true if {@code callSite} is bootstrapped by an invokestatic handle to {@code
+   * SwitchBootstraps.enumSwitch} and has an {@code (E, int)int} signature, where {@code E} is the
+   * type being switched on.
+   */
+  public static boolean isEnumSwitchCallSite(DexCallSite callSite, DexItemFactory factory) {
+    return callSite.methodName.isIdenticalTo(factory.enumSwitchMethod.getName())
+        && isEnumSwitchProto(callSite.methodProto, factory.intType)
+        && methodHandleIsInvokeStaticTo(callSite.bootstrapMethod, factory.enumSwitchMethod);
+  }
+
+  /** Returns true if {@code dexValue} is an invokestatic handle to {@code method}. */
+  private static boolean methodHandleIsInvokeStaticTo(DexValue dexValue, DexMethod method) {
+    if (!dexValue.isDexValueMethodHandle()) {
+      return false;
+    }
+    return methodHandleIsInvokeStaticTo(dexValue.asDexValueMethodHandle().getValue(), method);
+  }
+
+  // Checks the handle kind first, so asMethod() is only called on invokestatic handles.
+  private static boolean methodHandleIsInvokeStaticTo(
+      DexMethodHandle methodHandle, DexMethod method) {
+    return methodHandle.type.isInvokeStatic() && methodHandle.asMethod().isIdenticalTo(method);
+  }
+
+  /** Returns true if {@code methodProto} has the shape {@code (?, int)int}. */
+  private static boolean isEnumSwitchProto(DexProto methodProto, DexType intType) {
+    return methodProto.getReturnType().isIdenticalTo(intType)
+        && methodProto.getArity() == 2
+        && methodProto.getParameter(1).isIdenticalTo(intType);
+  }
+
+  /**
+   * Resolves the enum field denoted by a type switch label encoded as a constant dynamic.
+   *
+   * <p>The label must be an {@code EnumDesc} constant produced by {@code
+   * constantDynamicBootstrapMethod} from the arguments {@code [EnumDesc.of handle, ClassDesc
+   * constant, field name]}, where the {@code ClassDesc} constant is produced by the same bootstrap
+   * method from {@code [ClassDesc.of handle, class name]}.
+   *
+   * @throws CompilationError if the label does not have this shape, or if the enum class or field
+   *     cannot be found
+   */
+  public static DexField extractEnumField(
+      DexValueConstDynamic dexValueConstDynamic, DexClassAndMethod context, AppView<?> appView) {
+    DexItemFactory factory = appView.dexItemFactory();
+    ConstantDynamicReference enumCstDynamic = dexValueConstDynamic.getValue();
+    DexMethod bootstrapMethod = factory.constantDynamicBootstrapMethod;
+    // Compare bootstrap method handles with methodHandleIsInvokeStaticTo, which checks the
+    // handle kind before calling asMethod(), so that a malformed handle (e.g. a field handle)
+    // is reported as a CompilationError.
+    if (!(enumCstDynamic.getType().isIdenticalTo(factory.enumDescType)
+        && enumCstDynamic.getName().isIdenticalTo(bootstrapMethod.getName())
+        && methodHandleIsInvokeStaticTo(enumCstDynamic.getBootstrapMethod(), bootstrapMethod)
+        && enumCstDynamic.getBootstrapMethodArguments().size() == 3
+        && methodHandleIsInvokeStaticTo(
+            enumCstDynamic.getBootstrapMethodArguments().get(0), factory.enumDescMethod))) {
+      throw enumFieldConstantDynamicError("Invalid EnumDesc", context);
+    }
+    DexValue dexValueFieldName = enumCstDynamic.getBootstrapMethodArguments().get(2);
+    if (!dexValueFieldName.isDexValueString()) {
+      throw enumFieldConstantDynamicError("Field name " + dexValueFieldName, context);
+    }
+    DexString fieldName = dexValueFieldName.asDexValueString().getValue();
+
+    DexValue dexValueClassCstDynamic = enumCstDynamic.getBootstrapMethodArguments().get(1);
+    if (!dexValueClassCstDynamic.isDexValueConstDynamic()) {
+      throw enumFieldConstantDynamicError("Enum class " + dexValueClassCstDynamic, context);
+    }
+    ConstantDynamicReference classCstDynamic =
+        dexValueClassCstDynamic.asDexValueConstDynamic().getValue();
+    if (!(classCstDynamic.getType().isIdenticalTo(factory.classDescType)
+        && classCstDynamic.getName().isIdenticalTo(bootstrapMethod.getName())
+        && methodHandleIsInvokeStaticTo(classCstDynamic.getBootstrapMethod(), bootstrapMethod)
+        && classCstDynamic.getBootstrapMethodArguments().size() == 2
+        && methodHandleIsInvokeStaticTo(
+            classCstDynamic.getBootstrapMethodArguments().get(0), factory.classDescMethod))) {
+      throw enumFieldConstantDynamicError("Class descriptor " + classCstDynamic, context);
+    }
+    DexValue dexValueClassName = classCstDynamic.getBootstrapMethodArguments().get(1);
+    if (!dexValueClassName.isDexValueString()) {
+      throw enumFieldConstantDynamicError("Class name " + dexValueClassName, context);
+    }
+    DexString className = dexValueClassName.asDexValueString().getValue();
+    DexType enumType =
+        factory.createType(DescriptorUtils.javaTypeToDescriptor(className.toString()));
+    return getEnumField(fieldName, enumType, context, appView);
+  }
+
+  /**
+   * Looks up the static field named {@code fieldName} in {@code enumType}.
+   *
+   * @throws CompilationError if {@code enumType} has no definition or no matching static field
+   *     is found
+   */
+  public static DexField getEnumField(
+      DexString fieldName, DexType enumType, DexClassAndMethod context, AppView<?> appView) {
+    DexClass enumClass = appView.definitionFor(enumType);
+    if (enumClass == null) {
+      throw enumFieldConstantDynamicError("Missing enum class " + enumType, context);
+    }
+    DexEncodedField dexEncodedField = enumClass.lookupUniqueStaticFieldWithName(fieldName);
+    if (dexEncodedField == null) {
+      throw enumFieldConstantDynamicError(
+          "Missing enum field " + fieldName + " in " + enumType, context);
+    }
+    return dexEncodedField.getReference();
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/shaking/DefaultEnqueuerUseRegistry.java b/src/main/java/com/android/tools/r8/shaking/DefaultEnqueuerUseRegistry.java
index 667479b..a25a3e3 100644
--- a/src/main/java/com/android/tools/r8/shaking/DefaultEnqueuerUseRegistry.java
+++ b/src/main/java/com/android/tools/r8/shaking/DefaultEnqueuerUseRegistry.java
@@ -5,6 +5,10 @@
 package com.android.tools.r8.shaking;
 
 import static com.android.tools.r8.ir.desugar.records.RecordRewriterHelper.isInvokeDynamicOnRecord;
+import static com.android.tools.r8.ir.desugar.typeswitch.TypeSwitchDesugaringHelper.extractEnumField;
+import static com.android.tools.r8.ir.desugar.typeswitch.TypeSwitchDesugaringHelper.getEnumField;
+import static com.android.tools.r8.ir.desugar.typeswitch.TypeSwitchDesugaringHelper.isEnumSwitchCallSite;
+import static com.android.tools.r8.ir.desugar.typeswitch.TypeSwitchDesugaringHelper.isTypeSwitchCallSite;
 import static com.android.tools.r8.utils.MapUtils.ignoreKey;
 
 import com.android.tools.r8.androidapi.AndroidApiLevelCompute;
@@ -15,10 +19,13 @@
 import com.android.tools.r8.graph.DexEncodedField;
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexMethodHandle;
 import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexString;
 import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.DexValue;
 import com.android.tools.r8.graph.OriginalFieldWitness;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.ir.code.InvokeType;
@@ -172,7 +179,7 @@
   }
 
   private void registerInstanceFieldReadFromRecordMethodHandle(DexField field) {
-    super.registerInstanceFieldWriteFromMethodHandle(field);
+    super.registerInstanceFieldReadFromMethodHandle(field);
     enqueuer.traceInstanceFieldReadFromRecordMethodHandle(field, getContext());
   }
 
@@ -206,6 +213,11 @@
     enqueuer.traceStaticFieldReadFromMethodHandle(field, getContext());
   }
 
+  private void registerStaticFieldReadFromSwitchMethodHandle(DexField field) {
+    super.registerStaticFieldReadFromMethodHandle(field);
+    enqueuer.traceStaticFieldReadFromSwitchMethodHandle(field, getContext());
+  }
+
   @Override
   public void registerStaticFieldWrite(DexField field) {
     super.registerStaticFieldWrite(field);
@@ -268,12 +280,81 @@
     super.registerCallSiteExceptBootstrapArgs(callSite);
     if (isInvokeDynamicOnRecord(callSite, appViewWithClassHierarchy, getContext())) {
       registerRecordCallSiteBootstrapArgs(callSite);
+    } else if (isTypeSwitchCallSite(callSite, appView.dexItemFactory())) {
+      registerTypeSwitchCallSiteBootstrapArgs(callSite);
+    } else if (isEnumSwitchCallSite(callSite, appView.dexItemFactory())) {
+      registerEnumSwitchCallSiteBootstrapArgs(callSite);
     } else {
       super.registerCallSiteBootstrapArgs(callSite, 0, callSite.bootstrapArgs.size());
     }
     enqueuer.traceCallSite(callSite, getContext(), this);
   }
 
+  /**
+   * Registers invokes of the {@code values()} and {@code valueOf(String)} methods of {@code
+   * enumType}.
+   *
+   * <p>NOTE(review): presumably required because the switch bootstraps resolve enum constants at
+   * runtime through reflection (Enum.valueOf relies on values()); confirm.
+   */
+  private void registerEnumMethods(DexType enumType) {
+    DexItemFactory factory = dexItemFactory();
+    DexMethod values =
+        factory.createMethod(
+            enumType,
+            factory.createProto(factory.createArrayType(1, enumType)),
+            factory.valuesMethodName);
+    registerInvokeStatic(values);
+    DexMethod valueOf =
+        factory.createMethod(
+            enumType, factory.createProto(enumType, factory.stringType), factory.valueOfMethodName);
+    registerInvokeStatic(valueOf);
+  }
+
+  /**
+   * Registers the bootstrap arguments of a {@code SwitchBootstraps.typeSwitch} call site. Class
+   * labels become type references; enum labels (constant dynamics) become static reads of the
+   * referenced enum field plus the enum's values/valueOf methods. Other arguments are ignored.
+   */
+  private void registerTypeSwitchCallSiteBootstrapArgs(DexCallSite callSite) {
+    for (DexValue bootstrapArg : callSite.bootstrapArgs) {
+      if (bootstrapArg.isDexValueType()) {
+        registerTypeReference(bootstrapArg.asDexValueType().value);
+      } else if (bootstrapArg.isDexValueConstDynamic()) {
+        DexField enumField =
+            extractEnumField(bootstrapArg.asDexValueConstDynamic(), getContext(), appView);
+        registerStaticFieldReadFromSwitchMethodHandle(enumField);
+        registerEnumMethods(enumField.getHolderType());
+      }
+    }
+  }
+
+  /**
+   * Registers the bootstrap arguments of a {@code SwitchBootstraps.enumSwitch} call site. Class
+   * labels become type references; string labels name constants of the switched-on enum type and
+   * become static reads of those fields. If there is at least one such label, the enum's
+   * values/valueOf methods are registered as well. Other arguments are ignored.
+   */
+  private void registerEnumSwitchCallSiteBootstrapArgs(DexCallSite callSite) {
+    // The enum type is the first parameter of the call site's (E, int)int proto.
+    DexType enumType = callSite.getMethodProto().getParameter(0);
+    boolean hasEnumConstantLabel = false;
+    for (DexValue bootstrapArg : callSite.bootstrapArgs) {
+      if (bootstrapArg.isDexValueType()) {
+        registerTypeReference(bootstrapArg.asDexValueType().value);
+      } else if (bootstrapArg.isDexValueString()) {
+        DexString fieldName = bootstrapArg.asDexValueString().value;
+        DexField enumField = getEnumField(fieldName, enumType, getContext(), appView);
+        registerStaticFieldReadFromSwitchMethodHandle(enumField);
+        hasEnumConstantLabel = true;
+      }
+    }
+    // enumType is loop-invariant: register its methods once rather than once per label.
+    if (hasEnumConstantLabel) {
+      registerEnumMethods(enumType);
+    }
+  }
+
   @SuppressWarnings("HidingField")
   private void registerRecordCallSiteBootstrapArgs(DexCallSite callSite) {
     // The Instance Get method handle in invokeDynamicOnRecord are considered:
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index a543c0a..54dc4c2 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -1804,12 +1804,16 @@
     private static final int DEFERRED_MASK = 1;
     private static final int FROM_METHOD_HANDLE_MASK = 2;
     private static final int FROM_RECORD_METHOD_HANDLE_MASK = 4;
+    // Reads from switch bootstrap method handles; traceStaticFieldRead may pin such fields.
+    private static final int FROM_SWITCH_METHOD_HANDLE_MASK = 8;
 
-    static FieldAccessMetadata DEFAULT = new FieldAccessMetadata(0);
-    static FieldAccessMetadata FROM_METHOD_HANDLE =
+    static final FieldAccessMetadata DEFAULT = new FieldAccessMetadata(0);
+    static final FieldAccessMetadata FROM_METHOD_HANDLE =
         new FieldAccessMetadata(FROM_METHOD_HANDLE_MASK);
-    static FieldAccessMetadata FROM_RECORD_METHOD_HANDLE =
+    static final FieldAccessMetadata FROM_RECORD_METHOD_HANDLE =
         new FieldAccessMetadata(FROM_RECORD_METHOD_HANDLE_MASK);
+    static final FieldAccessMetadata FROM_SWITCH_METHOD_HANDLE =
+        new FieldAccessMetadata(FROM_SWITCH_METHOD_HANDLE_MASK);
 
     private final FieldAccessMetadata deferred;
     private final int flags;
@@ -1831,6 +1835,11 @@
       return (flags & FROM_RECORD_METHOD_HANDLE_MASK) != 0;
     }
 
+    /** Whether the access was traced via {@code traceStaticFieldReadFromSwitchMethodHandle}. */
+    boolean isFromSwitchMethodHandle() {
+      return (flags & FROM_SWITCH_METHOD_HANDLE_MASK) != 0;
+    }
+
     public FieldAccessMetadata toDeferred() {
       return deferred;
     }
@@ -1980,6 +1989,11 @@
     traceStaticFieldRead(field, currentMethod, FieldAccessMetadata.FROM_METHOD_HANDLE);
   }
 
+  /** Traces a static field read that originates from a switch bootstrap method handle. */
+  void traceStaticFieldReadFromSwitchMethodHandle(DexField field, ProgramMethod currentMethod) {
+    traceStaticFieldRead(field, currentMethod, FieldAccessMetadata.FROM_SWITCH_METHOD_HANDLE);
+  }
+
   @SuppressWarnings("ReferenceEquality")
   void traceStaticFieldRead(
       DexField fieldReference, ProgramMethod currentMethod, FieldAccessMetadata metadata) {
@@ -2030,6 +2041,15 @@
 
           if (metadata.isFromMethodHandle()) {
             fieldAccessInfoCollection.get(field.getReference()).setReadFromMethodHandle();
+          } else if (metadata.isFromSwitchMethodHandle()) {
+            // TODO(b/340187630): This disables any optimization on such enum fields. We could
+            //  support rewriting fields in switch method handles instead.
+            keepInfo.joinClass(
+                field.getHolder(),
+                joiner -> joiner.disallowMinification().disallowOptimization().disallowShrinking());
+            keepInfo.joinField(
+                field,
+                joiner -> joiner.disallowMinification().disallowOptimization().disallowShrinking());
           }
 
           if (field.getReference() != fieldReference) {
diff --git a/src/test/examplesJava21/switchpatternmatching/EnumSwitchTest.java b/src/test/examplesJava21/switchpatternmatching/EnumSwitchTest.java
index e57e31c..b2c99cb 100644
--- a/src/test/examplesJava21/switchpatternmatching/EnumSwitchTest.java
+++ b/src/test/examplesJava21/switchpatternmatching/EnumSwitchTest.java
@@ -6,6 +6,7 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assume.assumeTrue;
 
+import com.android.tools.r8.JdkClassFileProvider;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestBuilder;
 import com.android.tools.r8.TestParameters;
@@ -122,13 +123,26 @@
 
   @Test
   public void testR8() throws Exception {
-    Assume.assumeTrue("For Cf we should compile with Jdk 21 library", parameters.isDexRuntime());
+    Assume.assumeTrue(
+        "For Cf the runtime must be Jdk 21 or newer",
+        parameters.isDexRuntime()
+            || (parameters.isCfRuntime()
+                && parameters.getCfRuntime().isNewerThanOrEqual(CfVm.JDK21)));
     testForR8(parameters.getBackend())
         .apply(this::addModifiedProgramClasses)
+        .applyIf(
+            parameters.isCfRuntime(),
+            b -> b.addLibraryProvider(JdkClassFileProvider.fromSystemJdk()))
         .setMinApi(parameters)
         .addKeepMainRule(Main.class)
         .run(parameters.getRuntime(), Main.class)
-        .assertSuccessWithOutput(String.format(EXPECTED_OUTPUT, "java.lang.RuntimeException"));
+        // MatchException on the Cf (JDK 21+) runtime, RuntimeException on dex.
+        .assertSuccessWithOutput(
+            String.format(
+                EXPECTED_OUTPUT,
+                parameters.isCfRuntime()
+                    ? "java.lang.MatchException"
+                    : "java.lang.RuntimeException"));
   }
 
   // D is added to the list of permitted subclasses to reproduce the MatchException.
diff --git a/src/test/examplesJava21/switchpatternmatching/EnumSwitchUsingEnumSwitchBootstrapMethod.java b/src/test/examplesJava21/switchpatternmatching/EnumSwitchUsingEnumSwitchBootstrapMethod.java
index 32fad72..ce82a66 100644
--- a/src/test/examplesJava21/switchpatternmatching/EnumSwitchUsingEnumSwitchBootstrapMethod.java
+++ b/src/test/examplesJava21/switchpatternmatching/EnumSwitchUsingEnumSwitchBootstrapMethod.java
@@ -6,6 +6,7 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assume.assumeTrue;
 
+import com.android.tools.r8.JdkClassFileProvider;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
@@ -95,9 +96,15 @@
 
   @Test
   public void testR8() throws Exception {
-    Assume.assumeTrue("For Cf we should compile with Jdk 21 library", parameters.isDexRuntime());
+    Assume.assumeTrue(
+        parameters.isDexRuntime()
+            || (parameters.isCfRuntime()
+                && parameters.getCfRuntime().isNewerThanOrEqual(CfVm.JDK21)));
     testForR8(parameters.getBackend())
         .addInnerClassesAndStrippedOuter(getClass())
+        .applyIf(
+            parameters.isCfRuntime(),
+            b -> b.addLibraryProvider(JdkClassFileProvider.fromSystemJdk()))
         .setMinApi(parameters)
         .addKeepMainRule(Main.class)
         .run(parameters.getRuntime(), Main.class)
diff --git a/src/test/examplesJava21/switchpatternmatching/StringSwitchTest.java b/src/test/examplesJava21/switchpatternmatching/StringSwitchTest.java
index 0b5bfa4..7b1fb6f 100644
--- a/src/test/examplesJava21/switchpatternmatching/StringSwitchTest.java
+++ b/src/test/examplesJava21/switchpatternmatching/StringSwitchTest.java
@@ -6,6 +6,7 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assume.assumeTrue;
 
+import com.android.tools.r8.JdkClassFileProvider;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
@@ -90,9 +91,15 @@
 
   @Test
   public void testR8() throws Exception {
-    Assume.assumeTrue("For Cf we should compile with Jdk 21 library", parameters.isDexRuntime());
+    Assume.assumeTrue(
+        parameters.isDexRuntime()
+            || (parameters.isCfRuntime()
+                && parameters.getCfRuntime().isNewerThanOrEqual(CfVm.JDK21)));
     testForR8(parameters.getBackend())
         .addInnerClassesAndStrippedOuter(getClass())
+        .applyIf(
+            parameters.isCfRuntime(),
+            b -> b.addLibraryProvider(JdkClassFileProvider.fromSystemJdk()))
         .setMinApi(parameters)
         .addKeepMainRule(Main.class)
         .run(parameters.getRuntime(), Main.class)
diff --git a/src/test/examplesJava21/switchpatternmatching/TypeSwitchTest.java b/src/test/examplesJava21/switchpatternmatching/TypeSwitchTest.java
index 12e5fad..fa474c0 100644
--- a/src/test/examplesJava21/switchpatternmatching/TypeSwitchTest.java
+++ b/src/test/examplesJava21/switchpatternmatching/TypeSwitchTest.java
@@ -6,6 +6,7 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assume.assumeTrue;
 
+import com.android.tools.r8.JdkClassFileProvider;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
@@ -97,9 +98,15 @@
 
   @Test
   public void testR8() throws Exception {
-    Assume.assumeTrue("For Cf we should compile with Jdk 21 library", parameters.isDexRuntime());
+    Assume.assumeTrue(
+        parameters.isDexRuntime()
+            || (parameters.isCfRuntime()
+                && parameters.getCfRuntime().isNewerThanOrEqual(CfVm.JDK21)));
     testForR8(parameters.getBackend())
         .addInnerClassesAndStrippedOuter(getClass())
+        .applyIf(
+            parameters.isCfRuntime(),
+            b -> b.addLibraryProvider(JdkClassFileProvider.fromSystemJdk()))
         .addKeepMainRule(Main.class)
         .setMinApi(parameters)
         .run(parameters.getRuntime(), Main.class)
diff --git a/src/test/testbase/java/com/android/tools/r8/desugar/LibraryFilesHelper.java b/src/test/testbase/java/com/android/tools/r8/desugar/LibraryFilesHelper.java
index 0e07bf4..0a605cb 100644
--- a/src/test/testbase/java/com/android/tools/r8/desugar/LibraryFilesHelper.java
+++ b/src/test/testbase/java/com/android/tools/r8/desugar/LibraryFilesHelper.java
@@ -148,6 +148,7 @@
     return createClass(
         ACC_PUBLIC | ACC_INTERFACE | ACC_ABSTRACT,
         "java/util/function/Supplier",
+        "java/lang/Object",
         methodAdder ->
             methodAdder.add(
                 ACC_PUBLIC | ACC_ABSTRACT, "get", methodDescriptor(false, Object.class)));
@@ -156,7 +157,19 @@
   private static void addClassToZipBuilder(
       ZipBuilder builder, int access, String binaryName, Consumer<MethodAdder> consumer)
       throws Exception {
-    builder.addBytes(binaryName + ".class", createClass(access, binaryName, consumer));
+    addClassToZipBuilder(builder, access, binaryName, "java/lang/Object", consumer);
+  }
+
+  /** Adds class {@code binaryName}, extending {@code binarySuperName}, as a .class entry. */
+  private static void addClassToZipBuilder(
+      ZipBuilder builder,
+      int access,
+      String binaryName,
+      String binarySuperName,
+      Consumer<MethodAdder> consumer)
+      throws Exception {
+    byte[] classBytes = createClass(access, binaryName, binarySuperName, consumer);
+    builder.addBytes(binaryName + ".class", classBytes);
   }
 
   @FunctionalInterface
@@ -181,9 +193,10 @@
     return sb.toString();
   }
 
-  public static byte[] createClass(int access, String binaryName, Consumer<MethodAdder> consumer) {
+  public static byte[] createClass(
+      int access, String binaryName, String binarySuperName, Consumer<MethodAdder> consumer) {
     ClassWriter cw = new ClassWriter(0);
-    cw.visit(V1_8, access, binaryName, null, "java/lang/Object", null);
+    cw.visit(V1_8, access, binaryName, null, binarySuperName, null);
 
     consumer.accept(
         (access1, name, descriptor1) -> cw.visitMethod(access1, name, descriptor1, null, null));