Generate one method per type switch

Bug: b/395999911
Change-Id: Ia7e916bd31c85c7f86062de3b48ad8c8d1bbb9c6
diff --git a/src/main/java/com/android/tools/r8/graph/UseRegistry.java b/src/main/java/com/android/tools/r8/graph/UseRegistry.java
index 986267c..e3f2ba6 100644
--- a/src/main/java/com/android/tools/r8/graph/UseRegistry.java
+++ b/src/main/java/com/android/tools/r8/graph/UseRegistry.java
@@ -265,11 +265,16 @@
             throw new CompilationError(
                 "Unsupported const dynamic in call site " + arg, getContext().getOrigin());
           }
-          DexField enumField =
-              TypeSwitchDesugaringHelper.extractEnumField(
-                  arg.asDexValueConstDynamic(), getMethodContext(), appView);
-          if (enumField != null) {
-            registerStaticFieldRead(enumField);
+          if (arg.asDexValueConstDynamic()
+              .getValue()
+              .getType()
+              .isIdenticalTo(appView.dexItemFactory().enumDescType)) {
+            DexField enumField =
+                TypeSwitchDesugaringHelper.extractEnumField(
+                    arg.asDexValueConstDynamic(), getMethodContext(), appView);
+            if (enumField != null) {
+              registerStaticFieldRead(enumField);
+            }
           }
           break;
         default:
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/SwitchHelperGenerator.java b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/SwitchHelperGenerator.java
index 95a5ea5..7934c56 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/SwitchHelperGenerator.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/SwitchHelperGenerator.java
@@ -4,103 +4,235 @@
 
 package com.android.tools.r8.ir.desugar.typeswitch;
 
-import com.android.tools.r8.cf.code.CfFrame;
-import com.android.tools.r8.cf.code.CfIf;
+import static com.android.tools.r8.ir.synthetic.TypeSwitchSyntheticCfCodeProvider.allowsInlinedIntegerEquality;
+
+import com.android.tools.r8.cf.code.CfConstNumber;
 import com.android.tools.r8.cf.code.CfInstruction;
-import com.android.tools.r8.cf.code.CfLabel;
-import com.android.tools.r8.cf.code.CfReturn;
-import com.android.tools.r8.cf.code.CfStackInstruction;
-import com.android.tools.r8.cf.code.CfStackInstruction.Opcode;
-import com.android.tools.r8.cf.code.CfStaticFieldRead;
+import com.android.tools.r8.cf.code.CfNewArray;
+import com.android.tools.r8.cf.code.CfReturnVoid;
 import com.android.tools.r8.cf.code.CfStaticFieldWrite;
+import com.android.tools.r8.contexts.CompilationContext.MethodProcessingContext;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.CfCode;
+import com.android.tools.r8.graph.DexCallSite;
 import com.android.tools.r8.graph.DexEncodedField;
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexProto;
 import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.DexValue;
 import com.android.tools.r8.graph.FieldAccessFlags;
 import com.android.tools.r8.graph.MethodAccessFlags;
-import com.android.tools.r8.ir.code.IfType;
+import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.ir.code.ValueType;
+import com.android.tools.r8.ir.desugar.CfInstructionDesugaringEventConsumer;
+import com.android.tools.r8.ir.synthetic.TypeSwitchSyntheticCfCodeProvider;
+import com.android.tools.r8.ir.synthetic.TypeSwitchSyntheticCfCodeProvider.Dispatcher;
 import com.android.tools.r8.synthesis.SyntheticProgramClassBuilder;
 import com.google.common.collect.ImmutableList;
 import java.util.ArrayList;
 import java.util.List;
-import java.util.function.Consumer;
+import java.util.function.Function;
 
 public class SwitchHelperGenerator {
 
-  private final DexItemFactory factory;
-  private final DexType type;
-  private final DexField cacheField;
-  private final DexMethod getter;
+  private final AppView<?> appView;
+  private final DexCallSite dexCallSite;
+  private DexMethod dispatchMethod;
+  private DexMethod intEq;
+  private DexMethod enumEq;
+  private DexField enumCacheField;
+  private int enumCases = 0;
 
-  SwitchHelperGenerator(
+  SwitchHelperGenerator(AppView<?> appView, DexCallSite dexCallSite) {
+    this.appView = appView;
+    this.dexCallSite = dexCallSite;
+  }
+
+  public void build(
       SyntheticProgramClassBuilder builder,
-      AppView<?> appView,
-      Consumer<List<CfInstruction>> generator) {
-    this.factory = appView.dexItemFactory();
-    this.type = builder.getType();
-    this.cacheField =
-        factory.createField(type, factory.objectArrayType, factory.createString("switchCases"));
-    this.getter =
+      Scanner scanner,
+      Dispatcher dispatcher,
+      ProgramMethod context,
+      CfInstructionDesugaringEventConsumer eventConsumer,
+      MethodProcessingContext methodProcessingContext) {
+    DexItemFactory factory = appView.dexItemFactory();
+    scanArguments(
+        dexCallSite.bootstrapArgs, scanner, context, eventConsumer, methodProcessingContext);
+    DexEncodedMethod clinitMethod = null;
+    if (enumCases > 0) {
+      enumCacheField =
+          factory.createField(
+              builder.getType(), factory.objectArrayType, factory.createString("enumCache"));
+      synthesizeStaticField(builder);
+      clinitMethod = synthesizeClinit(enumCacheField);
+      enumEq = generateEnumEqMethod(context, eventConsumer, methodProcessingContext);
+    }
+    dispatchMethod =
         factory.createMethod(
-            type,
-            factory.createProto(factory.objectArrayType),
-            factory.createString("getSwitchCases"));
-    synthesizeStaticField(builder);
-    synthesizeStaticMethod(builder, generator);
+            builder.getType(), dexCallSite.methodProto, factory.createString("switchDispatch"));
+    DexEncodedMethod dispatchMethod =
+        synthesizeDispatchMethod(builder, dispatcher, dexCallSite, appView);
+    List<DexEncodedMethod> directMethods = new ArrayList<>();
+    directMethods.add(dispatchMethod);
+    if (clinitMethod != null) {
+      directMethods.add(clinitMethod);
+    }
+    builder.setDirectMethods(directMethods);
+  }
+
+  public DexMethod getDispatchMethod() {
+    return dispatchMethod;
+  }
+
+  @FunctionalInterface
+  public interface Scanner {
+
+    void scan(DexValue dexValue, Runnable intEqCheck, Runnable enumCase);
+  }
+
+  private void scanArguments(
+      List<DexValue> bootstrapArgs,
+      Scanner scanner,
+      ProgramMethod context,
+      CfInstructionDesugaringEventConsumer eventConsumer,
+      MethodProcessingContext methodProcessingContext) {
+    for (DexValue bootstrapArg : bootstrapArgs) {
+      scanner.scan(
+          bootstrapArg,
+          () -> {
+            DexItemFactory factory = appView.dexItemFactory();
+            DexType arg0Type = dexCallSite.methodProto.getParameter(0);
+            if (allowsInlinedIntegerEquality(arg0Type, factory)) {
+              return;
+            }
+            intEq = generateIntEqMethod(context, eventConsumer, methodProcessingContext);
+          },
+          () -> enumCases++);
+    }
+  }
+
+  private DexMethod generateEnumEqMethod(
+      ProgramMethod context,
+      TypeSwitchDesugaringEventConsumer eventConsumer,
+      MethodProcessingContext methodProcessingContext) {
+    DexItemFactory factory = appView.dexItemFactory();
+    DexProto proto =
+        factory.createProto(
+            factory.booleanType,
+            factory.objectType,
+            factory.objectArrayType,
+            factory.intType,
+            factory.stringType,
+            factory.stringType);
+    return generateMethod(
+        context,
+        eventConsumer,
+        methodProcessingContext,
+        proto,
+        methodSig -> TypeSwitchMethods.TypeSwitchMethods_switchEnumEq(factory, methodSig));
+  }
+
+  private DexMethod generateIntEqMethod(
+      ProgramMethod context,
+      TypeSwitchDesugaringEventConsumer eventConsumer,
+      MethodProcessingContext methodProcessingContext) {
+    DexItemFactory factory = appView.dexItemFactory();
+    DexProto proto = factory.createProto(factory.booleanType, factory.objectType, factory.intType);
+    return generateMethod(
+        context,
+        eventConsumer,
+        methodProcessingContext,
+        proto,
+        methodSig -> TypeSwitchMethods.TypeSwitchMethods_switchIntEq(factory, methodSig));
+  }
+
+  private DexMethod generateMethod(
+      ProgramMethod context,
+      TypeSwitchDesugaringEventConsumer eventConsumer,
+      MethodProcessingContext methodProcessingContext,
+      DexProto proto,
+      Function<DexMethod, CfCode> cfCodeGen) {
+    DexItemFactory factory = appView.dexItemFactory();
+    ProgramMethod method =
+        appView
+            .getSyntheticItems()
+            .createMethod(
+                kinds -> kinds.TYPE_SWITCH_HELPER,
+                methodProcessingContext.createUniqueContext(),
+                appView,
+                builder ->
+                    builder
+                        .disableAndroidApiLevelCheck()
+                        .setProto(proto)
+                        .setAccessFlags(MethodAccessFlags.createPublicStaticSynthetic())
+                        .setCode(
+                            methodSig -> {
+                              CfCode code = cfCodeGen.apply(methodSig);
+                              if (appView.options().hasMappingFileSupport()) {
+                                return code.getCodeAsInlining(
+                                    methodSig, true, context.getReference(), false, factory);
+                              }
+                              return code;
+                            }));
+    eventConsumer.acceptTypeSwitchMethod(method, context);
+    return method.getReference();
   }
 
   private void synthesizeStaticField(SyntheticProgramClassBuilder builder) {
     builder.setStaticFields(
         ImmutableList.of(
             DexEncodedField.syntheticBuilder()
-                .setField(cacheField)
+                .setField(enumCacheField)
                 .setAccessFlags(FieldAccessFlags.createPublicStaticSynthetic())
                 .disableAndroidApiLevelCheck()
                 .build()));
   }
 
-  /**
-   * Generates the following code:
-   *
-   * <pre>
-   *   if (switchCases != null) {
-   *     return switchCases;
-   *   }
-   *   switchCases = <generate array from bootstrap method>;
-   *   return switchCases;
-   * </pre>
-   *
-   * We don't lock since the array generated is always the same and is never used in identity
-   * checks.
-   */
-  private void synthesizeStaticMethod(
-      SyntheticProgramClassBuilder builder, Consumer<List<CfInstruction>> generator) {
-    List<CfInstruction> instructions = new ArrayList<>();
-    instructions.add(new CfStaticFieldRead(cacheField));
-    CfLabel target = new CfLabel();
-    instructions.add(new CfIf(IfType.EQ, ValueType.OBJECT, target));
-    instructions.add(new CfStaticFieldRead(cacheField));
-    instructions.add(new CfReturn(ValueType.OBJECT));
-    instructions.add(target);
-    instructions.add(new CfFrame());
-    generator.accept(instructions);
-    instructions.add(new CfStackInstruction(Opcode.Dup));
-    instructions.add(new CfStaticFieldWrite(cacheField));
-    instructions.add(new CfReturn(ValueType.OBJECT));
+  private DexEncodedMethod synthesizeClinit(DexField enumCacheField) {
+    DexItemFactory factory = appView.dexItemFactory();
+    DexMethod clinitMethod = factory.createClinitMethod(enumCacheField.getHolderType());
+    return DexEncodedMethod.syntheticBuilder()
+        .setMethod(clinitMethod)
+        .setAccessFlags(MethodAccessFlags.createForClassInitializer())
+        .setCode(
+            new CfCode(enumCacheField.getHolderType(), 2, 1, instructionsForClinit(enumCacheField)))
+        .disableAndroidApiLevelCheck()
+        .build();
+  }
 
-    builder.setDirectMethods(
-        ImmutableList.of(
-            DexEncodedMethod.syntheticBuilder()
-                .setMethod(getter)
-                .setAccessFlags(MethodAccessFlags.createPublicStaticSynthetic())
-                .setCode(new CfCode(type, 7, 3, instructions))
-                .disableAndroidApiLevelCheck()
-                .build()));
+  private List<CfInstruction> instructionsForClinit(DexField enumCacheField) {
+    DexItemFactory factory = appView.dexItemFactory();
+    List<CfInstruction> instructions = new ArrayList<>();
+    instructions.add(new CfConstNumber(enumCases, ValueType.INT));
+    instructions.add(new CfNewArray(factory.objectArrayType));
+    instructions.add(new CfStaticFieldWrite(enumCacheField));
+    instructions.add(new CfReturnVoid());
+    return instructions;
+  }
+
+  private DexEncodedMethod synthesizeDispatchMethod(
+      SyntheticProgramClassBuilder builder,
+      Dispatcher dispatcher,
+      DexCallSite dexCallSite,
+      AppView<?> appView) {
+    return DexEncodedMethod.syntheticBuilder()
+        .setMethod(dispatchMethod)
+        .setAccessFlags(MethodAccessFlags.createPublicStaticSynthetic())
+        .setCode(
+            new TypeSwitchSyntheticCfCodeProvider(
+                    appView,
+                    builder.getType(),
+                    dexCallSite.methodProto.getParameter(0),
+                    dexCallSite.bootstrapArgs,
+                    dispatcher,
+                    intEq,
+                    enumEq,
+                    enumCacheField)
+                .generateCfCode())
+        .disableAndroidApiLevelCheck()
+        .build();
   }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java
index 119ad06..776debb 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java
@@ -4,47 +4,31 @@
 
 package com.android.tools.r8.ir.desugar.typeswitch;
 
-import static com.android.tools.r8.ir.desugar.typeswitch.TypeSwitchDesugaringHelper.extractEnumField;
-import static com.android.tools.r8.ir.desugar.typeswitch.TypeSwitchDesugaringHelper.getEnumField;
+import static com.android.tools.r8.ir.desugar.typeswitch.TypeSwitchDesugaringHelper.dispatchEnumField;
 import static com.android.tools.r8.ir.desugar.typeswitch.TypeSwitchDesugaringHelper.isEnumSwitchCallSite;
 import static com.android.tools.r8.ir.desugar.typeswitch.TypeSwitchDesugaringHelper.isTypeSwitchCallSite;
+import static com.android.tools.r8.ir.desugar.typeswitch.TypeSwitchDesugaringHelper.methodHandleIsInvokeStaticTo;
 
-import com.android.tools.r8.cf.code.CfArrayStore;
-import com.android.tools.r8.cf.code.CfConstClass;
-import com.android.tools.r8.cf.code.CfConstNull;
-import com.android.tools.r8.cf.code.CfConstNumber;
-import com.android.tools.r8.cf.code.CfConstString;
 import com.android.tools.r8.cf.code.CfInstruction;
 import com.android.tools.r8.cf.code.CfInvoke;
 import com.android.tools.r8.cf.code.CfNew;
-import com.android.tools.r8.cf.code.CfNewArray;
-import com.android.tools.r8.cf.code.CfStackInstruction;
-import com.android.tools.r8.cf.code.CfStackInstruction.Opcode;
-import com.android.tools.r8.cf.code.CfStaticFieldRead;
 import com.android.tools.r8.contexts.CompilationContext.MethodProcessingContext;
 import com.android.tools.r8.errors.CompilationError;
 import com.android.tools.r8.graph.AppView;
-import com.android.tools.r8.graph.CfCode;
 import com.android.tools.r8.graph.DexCallSite;
-import com.android.tools.r8.graph.DexEncodedMethod;
-import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
-import com.android.tools.r8.graph.DexProto;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.DexValue;
-import com.android.tools.r8.graph.MethodAccessFlags;
 import com.android.tools.r8.graph.ProgramMethod;
-import com.android.tools.r8.ir.code.MemberType;
-import com.android.tools.r8.ir.code.ValueType;
 import com.android.tools.r8.ir.desugar.CfInstructionDesugaring;
 import com.android.tools.r8.ir.desugar.CfInstructionDesugaringEventConsumer;
 import com.android.tools.r8.ir.desugar.DesugarDescription;
-import com.android.tools.r8.ir.desugar.LocalStackAllocator;
+import com.android.tools.r8.ir.desugar.constantdynamic.ConstantDynamicReference;
+import com.android.tools.r8.ir.desugar.typeswitch.SwitchHelperGenerator.Scanner;
+import com.android.tools.r8.ir.synthetic.TypeSwitchSyntheticCfCodeProvider.Dispatcher;
 import com.google.common.collect.ImmutableList;
-import java.util.ArrayList;
-import java.util.Iterator;
 import java.util.List;
 import java.util.function.Consumer;
 import java.util.function.IntConsumer;
@@ -54,7 +38,7 @@
 
   private final AppView<?> appView;
 
-  private final DexProto switchHelperProto;
+  // private final DexProto switchHelperProto;
   private final DexType matchException;
   private final DexMethod matchExceptionInit;
   private final DexItemFactory factory;
@@ -62,9 +46,6 @@
   public TypeSwitchDesugaring(AppView<?> appView) {
     this.appView = appView;
     this.factory = appView.dexItemFactory();
-    switchHelperProto =
-        factory.createProto(
-            factory.intType, factory.objectType, factory.intType, factory.objectArrayType);
     matchException = factory.createType("Ljava/lang/MatchException;");
     matchExceptionInit =
         factory.createInstanceInitializer(
@@ -137,12 +118,12 @@
                   desugaringCollection,
                   dexItemFactory) ->
                   genSwitchMethod(
-                      localStackAllocator,
+                      callSite,
                       eventConsumer,
                       theContext,
                       methodProcessingContext,
-                      cfInstructions ->
-                          generateTypeSwitchLoadArguments(cfInstructions, callSite, context)))
+                      typeScanner(),
+                      typeDispatcher(context)))
           .build();
     }
     if (isEnumSwitchCallSite(callSite, factory)) {
@@ -159,25 +140,25 @@
                   desugaringCollection,
                   dexItemFactory) ->
                   genSwitchMethod(
-                      localStackAllocator,
+                      callSite,
                       eventConsumer,
                       theContext,
                       methodProcessingContext,
-                      cfInstructions ->
-                          generateEnumSwitchLoadArguments(
-                              cfInstructions, callSite, context, enumType)))
+                      enumScanner(),
+                      enumDispatcher(context, enumType)))
           .build();
     }
     return DesugarDescription.nothing();
   }
 
   private List<CfInstruction> genSwitchMethod(
-      LocalStackAllocator localStackAllocator,
+      DexCallSite dexCallSite,
       CfInstructionDesugaringEventConsumer eventConsumer,
-      ProgramMethod theContext,
+      ProgramMethod context,
       MethodProcessingContext methodProcessingContext,
-      Consumer<List<CfInstruction>> generator) {
-    localStackAllocator.allocateLocalStack(3);
+      Scanner scanner,
+      Dispatcher dispatcher) {
+    SwitchHelperGenerator gen = new SwitchHelperGenerator(appView, dexCallSite);
     DexProgramClass clazz =
         appView
             .getSyntheticItems()
@@ -185,119 +166,117 @@
                 kinds -> kinds.TYPE_SWITCH_CLASS,
                 methodProcessingContext.createUniqueContext(),
                 appView,
-                builder -> new SwitchHelperGenerator(builder, appView, generator));
-    eventConsumer.acceptTypeSwitchClass(clazz, theContext);
-    List<CfInstruction> cfInstructions = new ArrayList<>();
-    Iterator<DexEncodedMethod> iter = clazz.methods().iterator();
-    cfInstructions.add(new CfInvoke(Opcodes.INVOKESTATIC, iter.next().getReference(), false));
-    assert !iter.hasNext();
-    generateInvokeToDesugaredMethod(
-        cfInstructions, methodProcessingContext, theContext, eventConsumer);
-    return cfInstructions;
-  }
-
-  private void generateInvokeToDesugaredMethod(
-      List<CfInstruction> cfInstructions,
-      MethodProcessingContext methodProcessingContext,
-      ProgramMethod context,
-      CfInstructionDesugaringEventConsumer eventConsumer) {
-    ProgramMethod method =
-        appView
-            .getSyntheticItems()
-            .createMethod(
-                kinds -> kinds.TYPE_SWITCH_HELPER,
-                methodProcessingContext.createUniqueContext(),
-                appView,
                 builder ->
-                    builder
-                        .disableAndroidApiLevelCheck()
-                        .setProto(switchHelperProto)
-                        .setAccessFlags(MethodAccessFlags.createPublicStaticSynthetic())
-                        .setCode(
-                            methodSig -> {
-                              CfCode code =
-                                  TypeSwitchMethods.TypeSwitchMethods_typeSwitch(
-                                      factory, methodSig);
-                              if (appView.options().hasMappingFileSupport()) {
-                                return code.getCodeAsInlining(
-                                    methodSig, true, context.getReference(), false, factory);
-                              }
-                              return code;
-                            }));
-    eventConsumer.acceptTypeSwitchMethod(method, context);
-    cfInstructions.add(new CfInvoke(Opcodes.INVOKESTATIC, method.getReference(), false));
+                    gen.build(
+                        builder,
+                        scanner,
+                        dispatcher,
+                        context,
+                        eventConsumer,
+                        methodProcessingContext));
+    eventConsumer.acceptTypeSwitchClass(clazz, context);
+    assert gen.getDispatchMethod() != null;
+    return ImmutableList.of(new CfInvoke(Opcodes.INVOKESTATIC, gen.getDispatchMethod(), false));
   }
 
-  private void generateEnumSwitchLoadArguments(
-      List<CfInstruction> cfInstructions,
-      DexCallSite callSite,
+  private Dispatcher enumDispatcher(ProgramMethod context, DexType enumType) {
+    return (dexValue,
+        dexTypeConsumer,
+        intValueConsumer,
+        dexStringConsumer,
+        enumConsumer,
+        booleanConsumer,
+        numberConsumer) -> {
+      if (dexValue.isDexValueType()) {
+        dexTypeConsumer.accept(dexValue.asDexValueType().getValue());
+      } else if (dexValue.isDexValueString()) {
+        enumConsumer.accept(
+            factory.createString(enumType.getTypeName()), dexValue.asDexValueString().getValue());
+      } else {
+        throw new CompilationError(
+            "Invalid bootstrap arg for enum switch " + dexValue, context.getOrigin());
+      }
+    };
+  }
+
+  private Scanner enumScanner() {
+    return (dexValue, intEqCheck, enumCase) -> {
+      if (dexValue.isDexValueString()) {
+        enumCase.run();
+      }
+    };
+  }
+
+  private Dispatcher typeDispatcher(ProgramMethod context) {
+    return (dexValue,
+        dexTypeConsumer,
+        intValueConsumer,
+        dexStringConsumer,
+        enumConsumer,
+        booleanConsumer,
+        numberConsumer) -> {
+      if (dexValue.isDexValueType()) {
+        dexTypeConsumer.accept(dexValue.asDexValueType().getValue());
+      } else if (dexValue.isDexValueInt()) {
+        intValueConsumer.accept(dexValue.asDexValueInt().getValue());
+      } else if (dexValue.isDexValueString()) {
+        dexStringConsumer.accept(dexValue.asDexValueString().getValue());
+      } else if (dexValue.isDexValueConstDynamic()) {
+        ConstantDynamicReference constDynamic = dexValue.asDexValueConstDynamic().getValue();
+        if (constDynamic.getType().isIdenticalTo(factory.boxedBooleanType)) {
+          dispatchBooleanField(context, dexValue, booleanConsumer, constDynamic);
+        } else {
+          assert constDynamic.getType().isIdenticalTo(factory.enumDescType);
+          dispatchEnumField(enumConsumer, constDynamic, context, appView);
+        }
+      } else if (dexValue.isDexValueNumber()) {
+        assert dexValue.isDexValueDouble()
+            || dexValue.isDexValueFloat()
+            || dexValue.isDexValueLong();
+        numberConsumer.accept(dexValue.asDexValueNumber());
+      } else {
+        throw new CompilationError(
+            "Invalid bootstrap arg for type switch " + dexValue, context.getOrigin());
+      }
+    };
+  }
+
+  private void dispatchBooleanField(
       ProgramMethod context,
-      DexType enumType) {
-    generateSwitchLoadArguments(
-        cfInstructions,
-        callSite,
-        bootstrapArg -> {
-          if (bootstrapArg.isDexValueType()) {
-            cfInstructions.add(new CfConstClass(bootstrapArg.asDexValueType().getValue()));
-          } else if (bootstrapArg.isDexValueString()) {
-            DexField enumField =
-                getEnumField(bootstrapArg.asDexValueString().getValue(), enumType, appView);
-            pushEnumField(cfInstructions, enumField);
-          } else {
-            throw new CompilationError(
-                "Invalid bootstrap arg for enum switch " + bootstrapArg, context.getOrigin());
-          }
-        });
-  }
-
-  private void generateTypeSwitchLoadArguments(
-      List<CfInstruction> cfInstructions, DexCallSite callSite, ProgramMethod context) {
-    generateSwitchLoadArguments(
-        cfInstructions,
-        callSite,
-        bootstrapArg -> {
-          if (bootstrapArg.isDexValueType()) {
-            cfInstructions.add(new CfConstClass(bootstrapArg.asDexValueType().getValue()));
-          } else if (bootstrapArg.isDexValueInt()) {
-            cfInstructions.add(
-                new CfConstNumber(bootstrapArg.asDexValueInt().getValue(), ValueType.INT));
-            cfInstructions.add(
-                new CfInvoke(Opcodes.INVOKESTATIC, factory.integerMembers.valueOf, false));
-          } else if (bootstrapArg.isDexValueString()) {
-            cfInstructions.add(new CfConstString(bootstrapArg.asDexValueString().getValue()));
-          } else if (bootstrapArg.isDexValueConstDynamic()) {
-            DexField enumField =
-                extractEnumField(bootstrapArg.asDexValueConstDynamic(), context, appView);
-            pushEnumField(cfInstructions, enumField);
-          } else {
-            throw new CompilationError(
-                "Invalid bootstrap arg for type switch " + bootstrapArg, context.getOrigin());
-          }
-        });
-  }
-
-  private void pushEnumField(List<CfInstruction> cfInstructions, DexField enumField) {
-    if (enumField == null) {
-      // Extremely rare case where the compilation is invalid, the case is unreachable.
-      cfInstructions.add(new CfConstNull());
-    } else {
-      cfInstructions.add(new CfStaticFieldRead(enumField));
+      DexValue dexValue,
+      Consumer<Boolean> booleanConsumer,
+      ConstantDynamicReference constDynamic) {
+    if (methodHandleIsInvokeStaticTo(
+        constDynamic.getBootstrapMethod(),
+        factory.createMethod(
+            factory.constantBootstrapsType,
+            factory.createProto(
+                factory.objectType,
+                factory.methodHandlesLookupType,
+                factory.stringType,
+                factory.classType),
+            "getStaticFinal"))) {
+      String name = constDynamic.getName().toString();
+      if (name.equals("TRUE")) {
+        booleanConsumer.accept(true);
+        return;
+      }
+      if (name.equals("FALSE")) {
+        booleanConsumer.accept(false);
+        return;
+      }
     }
+    throw new CompilationError(
+        "Invalid Boolean bootstrap arg for type switch " + dexValue, context.getOrigin());
   }
 
-  private void generateSwitchLoadArguments(
-      List<CfInstruction> cfInstructions, DexCallSite callSite, Consumer<DexValue> adder) {
-    // We need to call the method with the bootstrap args as parameters.
-    // We need to convert the bootstrap args into a list of cf instructions.
-    // The object and the int are already pushed on stack, we simply need to push the extra array.
-    cfInstructions.add(new CfConstNumber(callSite.bootstrapArgs.size(), ValueType.INT));
-    cfInstructions.add(new CfNewArray(factory.objectArrayType));
-    for (int i = 0; i < callSite.bootstrapArgs.size(); i++) {
-      DexValue bootstrapArg = callSite.bootstrapArgs.get(i);
-      cfInstructions.add(new CfStackInstruction(Opcode.Dup));
-      cfInstructions.add(new CfConstNumber(i, ValueType.INT));
-      adder.accept(bootstrapArg);
-      cfInstructions.add(new CfArrayStore(MemberType.OBJECT));
-    }
+  private Scanner typeScanner() {
+    return (dexValue, intEqCheck, enumCase) -> {
+      if (dexValue.isDexValueInt()) {
+        intEqCheck.run();
+      } else if (dexValue.isDexValueConstDynamic()) {
+        enumCase.run();
+      }
+    };
   }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaringHelper.java b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaringHelper.java
index 128bda9..da7fc02 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaringHelper.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaringHelper.java
@@ -21,6 +21,7 @@
 import com.android.tools.r8.graph.DexValue.DexValueConstDynamic;
 import com.android.tools.r8.ir.desugar.constantdynamic.ConstantDynamicReference;
 import com.android.tools.r8.utils.DescriptorUtils;
+import java.util.function.BiConsumer;
 
 public class TypeSwitchDesugaringHelper {
   private static CompilationError throwEnumFieldConstantDynamic(
@@ -54,7 +55,7 @@
     return methodHandleIsInvokeStaticTo(dexValue.asDexValueMethodHandle().getValue(), method);
   }
 
-  private static boolean methodHandleIsInvokeStaticTo(
+  public static boolean methodHandleIsInvokeStaticTo(
       DexMethodHandle methodHandle, DexMethod method) {
     return methodHandle.type.isInvokeStatic() && methodHandle.asMethod().isIdenticalTo(method);
   }
@@ -65,6 +66,49 @@
         && methodProto.getParameter(1).isIdenticalTo(intType);
   }
 
+  public static void dispatchEnumField(
+      BiConsumer<DexString, DexString> enumConsumer,
+      ConstantDynamicReference enumCstDynamic,
+      DexClassAndMethod context,
+      AppView<?> appView) {
+    DexItemFactory factory = appView.dexItemFactory();
+    DexMethod bootstrapMethod = factory.constantDynamicBootstrapMethod;
+    if (!(enumCstDynamic.getType().isIdenticalTo(factory.enumDescType)
+        && enumCstDynamic.getName().isIdenticalTo(bootstrapMethod.getName())
+        && enumCstDynamic.getBootstrapMethod().asMethod().isIdenticalTo(bootstrapMethod)
+        && enumCstDynamic.getBootstrapMethodArguments().size() == 3
+        && methodHandleIsInvokeStaticTo(
+            enumCstDynamic.getBootstrapMethodArguments().get(0), factory.enumDescMethod))) {
+      throw throwEnumFieldConstantDynamic("Invalid EnumDesc", context);
+    }
+    DexValue dexValueFieldName = enumCstDynamic.getBootstrapMethodArguments().get(2);
+    if (!dexValueFieldName.isDexValueString()) {
+      throw throwEnumFieldConstantDynamic("Field name " + dexValueFieldName, context);
+    }
+    DexString fieldName = dexValueFieldName.asDexValueString().getValue();
+
+    DexValue dexValueClassCstDynamic = enumCstDynamic.getBootstrapMethodArguments().get(1);
+    if (!dexValueClassCstDynamic.isDexValueConstDynamic()) {
+      throw throwEnumFieldConstantDynamic("Enum class " + dexValueClassCstDynamic, context);
+    }
+    ConstantDynamicReference classCstDynamic =
+        dexValueClassCstDynamic.asDexValueConstDynamic().getValue();
+    if (!(classCstDynamic.getType().isIdenticalTo(factory.classDescType)
+        && classCstDynamic.getName().isIdenticalTo(bootstrapMethod.getName())
+        && classCstDynamic.getBootstrapMethod().asMethod().isIdenticalTo(bootstrapMethod)
+        && classCstDynamic.getBootstrapMethodArguments().size() == 2
+        && methodHandleIsInvokeStaticTo(
+            classCstDynamic.getBootstrapMethodArguments().get(0), factory.classDescMethod))) {
+      throw throwEnumFieldConstantDynamic("Class descriptor " + classCstDynamic, context);
+    }
+    DexValue dexValueClassName = classCstDynamic.getBootstrapMethodArguments().get(1);
+    if (!dexValueClassName.isDexValueString()) {
+      throw throwEnumFieldConstantDynamic("Class name " + dexValueClassName, context);
+    }
+    DexString className = dexValueClassName.asDexValueString().getValue();
+    enumConsumer.accept(className, fieldName);
+  }
+
   public static DexField extractEnumField(
       DexValueConstDynamic dexValueConstDynamic, DexClassAndMethod context, AppView<?> appView) {
     DexItemFactory factory = appView.dexItemFactory();
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchMethods.java b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchMethods.java
index ed9cbdf..f48c2a0 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchMethods.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchMethods.java
@@ -8,21 +8,24 @@
 
 package com.android.tools.r8.ir.desugar.typeswitch;
 
-import com.android.tools.r8.cf.code.CfArrayLength;
 import com.android.tools.r8.cf.code.CfArrayLoad;
+import com.android.tools.r8.cf.code.CfArrayStore;
 import com.android.tools.r8.cf.code.CfCheckCast;
+import com.android.tools.r8.cf.code.CfConstNull;
 import com.android.tools.r8.cf.code.CfConstNumber;
 import com.android.tools.r8.cf.code.CfFrame;
 import com.android.tools.r8.cf.code.CfGoto;
 import com.android.tools.r8.cf.code.CfIf;
 import com.android.tools.r8.cf.code.CfIfCmp;
-import com.android.tools.r8.cf.code.CfIinc;
 import com.android.tools.r8.cf.code.CfInstanceOf;
 import com.android.tools.r8.cf.code.CfInvoke;
 import com.android.tools.r8.cf.code.CfLabel;
 import com.android.tools.r8.cf.code.CfLoad;
+import com.android.tools.r8.cf.code.CfNew;
 import com.android.tools.r8.cf.code.CfReturn;
+import com.android.tools.r8.cf.code.CfStackInstruction;
 import com.android.tools.r8.cf.code.CfStore;
+import com.android.tools.r8.cf.code.CfTryCatch;
 import com.android.tools.r8.cf.code.frame.FrameType;
 import com.android.tools.r8.graph.CfCode;
 import com.android.tools.r8.graph.DexItemFactory;
@@ -32,14 +35,18 @@
 import com.android.tools.r8.ir.code.ValueType;
 import com.google.common.collect.ImmutableList;
 import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
+import java.util.ArrayDeque;
+import java.util.Arrays;
 
 public final class TypeSwitchMethods {
 
   public static void registerSynthesizedCodeReferences(DexItemFactory factory) {
+    factory.createSynthesizedType("Ljava/lang/Enum;");
+    factory.createSynthesizedType("Ljava/lang/Number;");
     factory.createSynthesizedType("[Ljava/lang/Object;");
   }
 
-  public static CfCode TypeSwitchMethods_typeSwitch(DexItemFactory factory, DexMethod method) {
+  public static CfCode TypeSwitchMethods_switchEnumEq(DexItemFactory factory, DexMethod method) {
     CfLabel label0 = new CfLabel();
     CfLabel label1 = new CfLabel();
     CfLabel label2 = new CfLabel();
@@ -53,107 +60,247 @@
     CfLabel label10 = new CfLabel();
     CfLabel label11 = new CfLabel();
     CfLabel label12 = new CfLabel();
+    CfLabel label13 = new CfLabel();
+    CfLabel label14 = new CfLabel();
     return new CfCode(
         method.holder,
-        2,
-        5,
+        4,
+        8,
         ImmutableList.of(
             label0,
-            new CfLoad(ValueType.OBJECT, 0),
-            new CfIf(IfType.NE, ValueType.OBJECT, label2),
-            label1,
-            new CfConstNumber(-1, ValueType.INT),
-            new CfReturn(ValueType.INT),
-            label2,
-            new CfFrame(
-                new Int2ObjectAVLTreeMap<>(
-                    new int[] {0, 1, 2},
-                    new FrameType[] {
-                      FrameType.initializedNonNullReference(factory.objectType),
-                      FrameType.intType(),
-                      FrameType.initializedNonNullReference(
-                          factory.createType("[Ljava/lang/Object;"))
-                    })),
-            new CfLoad(ValueType.INT, 1),
-            new CfStore(ValueType.INT, 3),
-            label3,
-            new CfFrame(
-                new Int2ObjectAVLTreeMap<>(
-                    new int[] {0, 1, 2, 3},
-                    new FrameType[] {
-                      FrameType.initializedNonNullReference(factory.objectType),
-                      FrameType.intType(),
-                      FrameType.initializedNonNullReference(
-                          factory.createType("[Ljava/lang/Object;")),
-                      FrameType.intType()
-                    })),
-            new CfLoad(ValueType.INT, 3),
-            new CfLoad(ValueType.OBJECT, 2),
-            new CfArrayLength(),
-            new CfIfCmp(IfType.GE, ValueType.INT, label11),
-            label4,
-            new CfLoad(ValueType.OBJECT, 2),
-            new CfLoad(ValueType.INT, 3),
+            new CfLoad(ValueType.OBJECT, 1),
+            new CfLoad(ValueType.INT, 2),
             new CfArrayLoad(MemberType.OBJECT),
-            new CfStore(ValueType.OBJECT, 4),
-            label5,
-            new CfLoad(ValueType.OBJECT, 4),
-            new CfInstanceOf(factory.classType),
-            new CfIf(IfType.EQ, ValueType.INT, label8),
-            label6,
-            new CfLoad(ValueType.OBJECT, 4),
-            new CfCheckCast(factory.classType),
-            new CfLoad(ValueType.OBJECT, 0),
+            new CfIf(IfType.NE, ValueType.OBJECT, label11),
+            label1,
+            new CfConstNull(),
+            new CfStore(ValueType.OBJECT, 5),
+            label2,
+            new CfLoad(ValueType.OBJECT, 3),
+            new CfInvoke(
+                184,
+                factory.createMethod(
+                    factory.classType,
+                    factory.createProto(factory.classType, factory.stringType),
+                    factory.createString("forName")),
+                false),
+            new CfStore(ValueType.OBJECT, 6),
+            label3,
+            new CfLoad(ValueType.OBJECT, 6),
             new CfInvoke(
                 182,
                 factory.createMethod(
                     factory.classType,
-                    factory.createProto(factory.booleanType, factory.objectType),
-                    factory.createString("isInstance")),
+                    factory.createProto(factory.booleanType),
+                    factory.createString("isEnum")),
                 false),
-            new CfIf(IfType.EQ, ValueType.INT, label10),
+            new CfIf(IfType.EQ, ValueType.INT, label6),
+            label4,
+            new CfLoad(ValueType.OBJECT, 6),
+            new CfStore(ValueType.OBJECT, 7),
+            label5,
+            new CfLoad(ValueType.OBJECT, 7),
+            new CfLoad(ValueType.OBJECT, 4),
+            new CfInvoke(
+                184,
+                factory.createMethod(
+                    factory.createType("Ljava/lang/Enum;"),
+                    factory.createProto(
+                        factory.createType("Ljava/lang/Enum;"),
+                        factory.classType,
+                        factory.stringType),
+                    factory.createString("valueOf")),
+                false),
+            new CfStore(ValueType.OBJECT, 5),
+            label6,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1, 2, 3, 4, 5},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(factory.objectType),
+                      FrameType.initializedNonNullReference(
+                          factory.createType("[Ljava/lang/Object;")),
+                      FrameType.intType(),
+                      FrameType.initializedNonNullReference(factory.stringType),
+                      FrameType.initializedNonNullReference(factory.stringType),
+                      FrameType.initializedNonNullReference(factory.objectType)
+                    })),
+            new CfGoto(label8),
             label7,
-            new CfLoad(ValueType.INT, 3),
-            new CfReturn(ValueType.INT),
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1, 2, 3, 4, 5},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(factory.objectType),
+                      FrameType.initializedNonNullReference(
+                          factory.createType("[Ljava/lang/Object;")),
+                      FrameType.intType(),
+                      FrameType.initializedNonNullReference(factory.stringType),
+                      FrameType.initializedNonNullReference(factory.stringType),
+                      FrameType.initializedNonNullReference(factory.objectType)
+                    }),
+                new ArrayDeque<>(
+                    Arrays.asList(FrameType.initializedNonNullReference(factory.throwableType)))),
+            new CfStore(ValueType.OBJECT, 6),
             label8,
             new CfFrame(
                 new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1, 2, 3, 4, 5},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(factory.objectType),
+                      FrameType.initializedNonNullReference(
+                          factory.createType("[Ljava/lang/Object;")),
+                      FrameType.intType(),
+                      FrameType.initializedNonNullReference(factory.stringType),
+                      FrameType.initializedNonNullReference(factory.stringType),
+                      FrameType.initializedNonNullReference(factory.objectType)
+                    })),
+            new CfLoad(ValueType.OBJECT, 1),
+            new CfLoad(ValueType.INT, 2),
+            new CfLoad(ValueType.OBJECT, 5),
+            new CfIf(IfType.NE, ValueType.OBJECT, label9),
+            new CfNew(factory.objectType),
+            new CfStackInstruction(CfStackInstruction.Opcode.Dup),
+            new CfInvoke(
+                183,
+                factory.createMethod(
+                    factory.objectType,
+                    factory.createProto(factory.voidType),
+                    factory.createString("<init>")),
+                false),
+            new CfGoto(label10),
+            label9,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1, 2, 3, 4, 5},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(factory.objectType),
+                      FrameType.initializedNonNullReference(
+                          factory.createType("[Ljava/lang/Object;")),
+                      FrameType.intType(),
+                      FrameType.initializedNonNullReference(factory.stringType),
+                      FrameType.initializedNonNullReference(factory.stringType),
+                      FrameType.initializedNonNullReference(factory.objectType)
+                    }),
+                new ArrayDeque<>(
+                    Arrays.asList(
+                        FrameType.initializedNonNullReference(
+                            factory.createType("[Ljava/lang/Object;")),
+                        FrameType.intType()))),
+            new CfLoad(ValueType.OBJECT, 5),
+            label10,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1, 2, 3, 4, 5},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(factory.objectType),
+                      FrameType.initializedNonNullReference(
+                          factory.createType("[Ljava/lang/Object;")),
+                      FrameType.intType(),
+                      FrameType.initializedNonNullReference(factory.stringType),
+                      FrameType.initializedNonNullReference(factory.stringType),
+                      FrameType.initializedNonNullReference(factory.objectType)
+                    }),
+                new ArrayDeque<>(
+                    Arrays.asList(
+                        FrameType.initializedNonNullReference(
+                            factory.createType("[Ljava/lang/Object;")),
+                        FrameType.intType(),
+                        FrameType.initializedNonNullReference(factory.objectType)))),
+            new CfArrayStore(MemberType.OBJECT),
+            label11,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
                     new int[] {0, 1, 2, 3, 4},
                     new FrameType[] {
                       FrameType.initializedNonNullReference(factory.objectType),
-                      FrameType.intType(),
                       FrameType.initializedNonNullReference(
                           factory.createType("[Ljava/lang/Object;")),
                       FrameType.intType(),
-                      FrameType.initializedNonNullReference(factory.objectType)
+                      FrameType.initializedNonNullReference(factory.stringType),
+                      FrameType.initializedNonNullReference(factory.stringType)
                     })),
             new CfLoad(ValueType.OBJECT, 0),
-            new CfLoad(ValueType.OBJECT, 4),
+            new CfLoad(ValueType.OBJECT, 1),
+            new CfLoad(ValueType.INT, 2),
+            new CfArrayLoad(MemberType.OBJECT),
+            new CfIfCmp(IfType.NE, ValueType.OBJECT, label12),
+            new CfConstNumber(1, ValueType.INT),
+            new CfGoto(label13),
+            label12,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1, 2, 3, 4},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(factory.objectType),
+                      FrameType.initializedNonNullReference(
+                          factory.createType("[Ljava/lang/Object;")),
+                      FrameType.intType(),
+                      FrameType.initializedNonNullReference(factory.stringType),
+                      FrameType.initializedNonNullReference(factory.stringType)
+                    })),
+            new CfConstNumber(0, ValueType.INT),
+            label13,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1, 2, 3, 4},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(factory.objectType),
+                      FrameType.initializedNonNullReference(
+                          factory.createType("[Ljava/lang/Object;")),
+                      FrameType.intType(),
+                      FrameType.initializedNonNullReference(factory.stringType),
+                      FrameType.initializedNonNullReference(factory.stringType)
+                    }),
+                new ArrayDeque<>(Arrays.asList(FrameType.intType()))),
+            new CfReturn(ValueType.INT),
+            label14),
+        ImmutableList.of(
+            new CfTryCatch(
+                label2, label6, ImmutableList.of(factory.throwableType), ImmutableList.of(label7))),
+        ImmutableList.of());
+  }
+
+  public static CfCode TypeSwitchMethods_switchIntEq(DexItemFactory factory, DexMethod method) {
+    CfLabel label0 = new CfLabel();
+    CfLabel label1 = new CfLabel();
+    CfLabel label2 = new CfLabel();
+    CfLabel label3 = new CfLabel();
+    CfLabel label4 = new CfLabel();
+    CfLabel label5 = new CfLabel();
+    CfLabel label6 = new CfLabel();
+    CfLabel label7 = new CfLabel();
+    CfLabel label8 = new CfLabel();
+    CfLabel label9 = new CfLabel();
+    CfLabel label10 = new CfLabel();
+    CfLabel label11 = new CfLabel();
+    return new CfCode(
+        method.holder,
+        2,
+        3,
+        ImmutableList.of(
+            label0,
+            new CfLoad(ValueType.OBJECT, 0),
+            new CfInstanceOf(factory.createType("Ljava/lang/Number;")),
+            new CfIf(IfType.EQ, ValueType.INT, label5),
+            label1,
+            new CfLoad(ValueType.OBJECT, 0),
+            new CfCheckCast(factory.createType("Ljava/lang/Number;")),
+            new CfStore(ValueType.OBJECT, 2),
+            label2,
+            new CfLoad(ValueType.INT, 1),
+            new CfLoad(ValueType.OBJECT, 2),
             new CfInvoke(
                 182,
                 factory.createMethod(
-                    factory.objectType,
-                    factory.createProto(factory.booleanType, factory.objectType),
-                    factory.createString("equals")),
+                    factory.createType("Ljava/lang/Number;"),
+                    factory.createProto(factory.intType),
+                    factory.createString("intValue")),
                 false),
-            new CfIf(IfType.EQ, ValueType.INT, label10),
-            label9,
-            new CfLoad(ValueType.INT, 3),
-            new CfReturn(ValueType.INT),
-            label10,
-            new CfFrame(
-                new Int2ObjectAVLTreeMap<>(
-                    new int[] {0, 1, 2, 3},
-                    new FrameType[] {
-                      FrameType.initializedNonNullReference(factory.objectType),
-                      FrameType.intType(),
-                      FrameType.initializedNonNullReference(
-                          factory.createType("[Ljava/lang/Object;")),
-                      FrameType.intType()
-                    })),
-            new CfIinc(3, 1),
-            new CfGoto(label3),
-            label11,
+            new CfIfCmp(IfType.NE, ValueType.INT, label3),
+            new CfConstNumber(1, ValueType.INT),
+            new CfGoto(label4),
+            label3,
             new CfFrame(
                 new Int2ObjectAVLTreeMap<>(
                     new int[] {0, 1, 2},
@@ -161,11 +308,79 @@
                       FrameType.initializedNonNullReference(factory.objectType),
                       FrameType.intType(),
                       FrameType.initializedNonNullReference(
-                          factory.createType("[Ljava/lang/Object;"))
+                          factory.createType("Ljava/lang/Number;"))
                     })),
-            new CfConstNumber(-2, ValueType.INT),
+            new CfConstNumber(0, ValueType.INT),
+            label4,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1, 2},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(factory.objectType),
+                      FrameType.intType(),
+                      FrameType.initializedNonNullReference(
+                          factory.createType("Ljava/lang/Number;"))
+                    }),
+                new ArrayDeque<>(Arrays.asList(FrameType.intType()))),
             new CfReturn(ValueType.INT),
-            label12),
+            label5,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(factory.objectType), FrameType.intType()
+                    })),
+            new CfLoad(ValueType.OBJECT, 0),
+            new CfInstanceOf(factory.boxedCharType),
+            new CfIf(IfType.EQ, ValueType.INT, label10),
+            label6,
+            new CfLoad(ValueType.OBJECT, 0),
+            new CfCheckCast(factory.boxedCharType),
+            new CfStore(ValueType.OBJECT, 2),
+            label7,
+            new CfLoad(ValueType.INT, 1),
+            new CfLoad(ValueType.OBJECT, 2),
+            new CfInvoke(
+                182,
+                factory.createMethod(
+                    factory.boxedCharType,
+                    factory.createProto(factory.charType),
+                    factory.createString("charValue")),
+                false),
+            new CfIfCmp(IfType.NE, ValueType.INT, label8),
+            new CfConstNumber(1, ValueType.INT),
+            new CfGoto(label9),
+            label8,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1, 2},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(factory.objectType),
+                      FrameType.intType(),
+                      FrameType.initializedNonNullReference(factory.boxedCharType)
+                    })),
+            new CfConstNumber(0, ValueType.INT),
+            label9,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1, 2},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(factory.objectType),
+                      FrameType.intType(),
+                      FrameType.initializedNonNullReference(factory.boxedCharType)
+                    }),
+                new ArrayDeque<>(Arrays.asList(FrameType.intType()))),
+            new CfReturn(ValueType.INT),
+            label10,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(factory.objectType), FrameType.intType()
+                    })),
+            new CfConstNumber(0, ValueType.INT),
+            new CfReturn(ValueType.INT),
+            label11),
         ImmutableList.of(),
         ImmutableList.of());
   }
diff --git a/src/main/java/com/android/tools/r8/ir/synthetic/TypeSwitchSyntheticCfCodeProvider.java b/src/main/java/com/android/tools/r8/ir/synthetic/TypeSwitchSyntheticCfCodeProvider.java
new file mode 100644
index 0000000..0edb7aa
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/synthetic/TypeSwitchSyntheticCfCodeProvider.java
@@ -0,0 +1,274 @@
+// Copyright (c) 2025, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.synthetic;
+
+import static com.android.tools.r8.naming.dexitembasedstring.ClassNameComputationInfo.ClassNameMapping.NAME;
+
+import com.android.tools.r8.cf.code.CfConstNumber;
+import com.android.tools.r8.cf.code.CfConstString;
+import com.android.tools.r8.cf.code.CfDexItemBasedConstString;
+import com.android.tools.r8.cf.code.CfFrame;
+import com.android.tools.r8.cf.code.CfIf;
+import com.android.tools.r8.cf.code.CfIfCmp;
+import com.android.tools.r8.cf.code.CfInstanceOf;
+import com.android.tools.r8.cf.code.CfInstruction;
+import com.android.tools.r8.cf.code.CfInvoke;
+import com.android.tools.r8.cf.code.CfLabel;
+import com.android.tools.r8.cf.code.CfLoad;
+import com.android.tools.r8.cf.code.CfReturn;
+import com.android.tools.r8.cf.code.CfStackInstruction;
+import com.android.tools.r8.cf.code.CfStackInstruction.Opcode;
+import com.android.tools.r8.cf.code.CfStaticFieldRead;
+import com.android.tools.r8.cf.code.CfSwitch;
+import com.android.tools.r8.cf.code.CfSwitch.Kind;
+import com.android.tools.r8.cf.code.frame.FrameType;
+import com.android.tools.r8.errors.CompilationError;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.CfCode;
+import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexString;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.DexValue;
+import com.android.tools.r8.graph.DexValue.DexValueNumber;
+import com.android.tools.r8.ir.code.IfType;
+import com.android.tools.r8.ir.code.ValueType;
+import com.android.tools.r8.naming.dexitembasedstring.ClassNameComputationInfo;
+import com.android.tools.r8.utils.BooleanUtils;
+import com.android.tools.r8.utils.DescriptorUtils;
+import com.android.tools.r8.utils.IntBox;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+import java.util.function.IntConsumer;
+import org.objectweb.asm.Opcodes;
+
+public class TypeSwitchSyntheticCfCodeProvider extends SyntheticCfCodeProvider {
+
+  private final List<DexValue> bootstrapArgs;
+  private final DexType arg0Type;
+  private final Dispatcher dispatcher;
+  private final DexMethod intEq;
+  private final DexMethod enumEq;
+  private final DexField enumFieldCache;
+
+  @FunctionalInterface
+  public interface Dispatcher {
+    void generate(
+        DexValue dexValue,
+        Consumer<DexType> dexTypeConsumer,
+        IntConsumer intValueConsumer,
+        Consumer<DexString> dexStringConsumer,
+        BiConsumer<DexString, DexString> enumConsumer,
+        Consumer<Boolean> booleanConsumer,
+        Consumer<DexValueNumber> numberConsumer);
+  }
+
+  public TypeSwitchSyntheticCfCodeProvider(
+      AppView<?> appView,
+      DexType holder,
+      DexType arg0Type,
+      List<DexValue> bootstrapArgs,
+      Dispatcher dispatcher,
+      DexMethod intEq,
+      DexMethod enumEq,
+      DexField enumFieldCache) {
+    super(appView, holder);
+    this.arg0Type = arg0Type;
+    this.bootstrapArgs = bootstrapArgs;
+    this.dispatcher = dispatcher;
+    this.intEq = intEq;
+    this.enumEq = enumEq;
+    this.enumFieldCache = enumFieldCache;
+  }
+
+  @Override
+  public CfCode generateCfCode() {
+    // arg 0: Object obj
+    // arg 1: int restart
+    DexItemFactory factory = appView.dexItemFactory();
+    List<CfInstruction> instructions = new ArrayList<>();
+
+    CfFrame frame =
+        CfFrame.builder()
+            .appendLocal(FrameType.initialized(arg0Type))
+            .appendLocal(FrameType.intType())
+            .build();
+
+    // Objects.checkIndex(restart, length + 1);
+    instructions.add(new CfLoad(ValueType.INT, 1));
+    instructions.add(new CfConstNumber(bootstrapArgs.size() + 1, ValueType.INT));
+    DexMethod checkIndex =
+        factory.createMethod(
+            factory.objectsType,
+            factory.createProto(factory.intType, factory.intType, factory.intType),
+            "checkIndex");
+    instructions.add(new CfInvoke(Opcodes.INVOKESTATIC, checkIndex, false));
+    instructions.add(new CfStackInstruction(Opcode.Pop));
+
+    // if (obj == null) { return -1; }
+    instructions.add(new CfLoad(ValueType.OBJECT, 0));
+    CfLabel nonNull = new CfLabel();
+    instructions.add(new CfIf(IfType.NE, ValueType.OBJECT, nonNull));
+    instructions.add(new CfConstNumber(-1, ValueType.INT));
+    instructions.add(new CfReturn(ValueType.INT));
+    instructions.add(nonNull);
+    instructions.add(frame);
+
+    // If no cases, return 0;
+    if (bootstrapArgs.isEmpty()) {
+      instructions.add(new CfConstNumber(0, ValueType.INT));
+      instructions.add(new CfReturn(ValueType.INT));
+    }
+
+    // The tableSwitch for the restart dispatch.
+    CfLabel defaultLabel = new CfLabel();
+    List<CfLabel> cfLabels = new ArrayList<>();
+    for (int i = 0; i < bootstrapArgs.size(); i++) {
+      cfLabels.add(new CfLabel());
+    }
+    cfLabels.add(defaultLabel);
+    instructions.add(new CfLoad(ValueType.INT, 1));
+    instructions.add(new CfSwitch(Kind.TABLE, defaultLabel, new int[] {0}, cfLabels));
+
+    IntBox index = new IntBox(0);
+    IntBox enumIndex = new IntBox(0);
+    bootstrapArgs.forEach(
+        dexValue ->
+            dispatcher.generate(
+                dexValue,
+                dexType -> {
+                  instructions.add(cfLabels.get(index.get()));
+                  instructions.add(frame);
+                  instructions.add(new CfLoad(ValueType.OBJECT, 0));
+                  instructions.add(new CfInstanceOf(dexType));
+                  instructions.add(
+                      new CfIf(IfType.EQ, ValueType.INT, cfLabels.get(index.get() + 1)));
+                  instructions.add(new CfConstNumber(index.getAndIncrement(), ValueType.INT));
+                  instructions.add(new CfReturn(ValueType.INT));
+                },
+                intValue -> {
+                  instructions.add(cfLabels.get(index.get()));
+                  instructions.add(frame);
+                  instructions.add(new CfLoad(ValueType.OBJECT, 0));
+                  if (allowsInlinedIntegerEquality(arg0Type, factory)) {
+                    instructions.add(
+                        new CfInvoke(
+                            Opcodes.INVOKEVIRTUAL,
+                            factory.unboxPrimitiveMethod.get(arg0Type),
+                            false));
+                    instructions.add(new CfConstNumber(intValue, ValueType.INT));
+                    instructions.add(
+                        new CfIfCmp(IfType.NE, ValueType.INT, cfLabels.get(index.get() + 1)));
+                  } else {
+                    instructions.add(new CfConstNumber(intValue, ValueType.INT));
+                    assert intEq != null;
+                    instructions.add(new CfInvoke(Opcodes.INVOKESTATIC, intEq, false));
+                    instructions.add(
+                        new CfIf(IfType.NE, ValueType.INT, cfLabels.get(index.get() + 1)));
+                  }
+                  instructions.add(new CfConstNumber(index.getAndIncrement(), ValueType.INT));
+                  instructions.add(new CfReturn(ValueType.INT));
+                },
+                dexString -> {
+                  instructions.add(cfLabels.get(index.get()));
+                  instructions.add(frame);
+                  instructions.add(new CfLoad(ValueType.OBJECT, 0));
+                  instructions.add(new CfConstString(dexString));
+                  instructions.add(
+                      new CfInvoke(Opcodes.INVOKEVIRTUAL, factory.objectMembers.equals, false));
+                  instructions.add(
+                      new CfIf(IfType.EQ, ValueType.INT, cfLabels.get(index.get() + 1)));
+                  instructions.add(new CfConstNumber(index.getAndIncrement(), ValueType.INT));
+                  instructions.add(new CfReturn(ValueType.INT));
+                },
+                (enumClass, enumField) -> {
+                  instructions.add(cfLabels.get(index.get()));
+                  instructions.add(frame);
+                  // TODO(b/399808482): In R8 release, we can analyze at compile-time program enum
+                  //  and generate a fast check based on the field. But these information are not
+                  //  available in Cf instructions.
+                  instructions.add(new CfLoad(ValueType.OBJECT, 0));
+                  assert enumFieldCache != null;
+                  instructions.add(new CfStaticFieldRead(enumFieldCache));
+                  instructions.add(new CfConstNumber(enumIndex.getAndIncrement(), ValueType.INT));
+                  if (appView.enableWholeProgramOptimizations()) {
+                    DexType type =
+                        factory.createType(
+                            DescriptorUtils.javaTypeToDescriptor(enumClass.toString()));
+                    instructions.add(
+                        new CfDexItemBasedConstString(
+                            type,
+                            ClassNameComputationInfo.create(NAME, type.getArrayTypeDimensions())));
+                  } else {
+                    instructions.add(new CfConstString(enumClass));
+                  }
+                  instructions.add(new CfConstString(enumField));
+                  assert enumEq != null;
+                  instructions.add(new CfInvoke(Opcodes.INVOKESTATIC, enumEq, false));
+                  instructions.add(
+                      new CfIf(IfType.EQ, ValueType.INT, cfLabels.get(index.get() + 1)));
+                  instructions.add(new CfConstNumber(index.getAndIncrement(), ValueType.INT));
+                  instructions.add(new CfReturn(ValueType.INT));
+                },
+                bool -> {
+                  instructions.add(cfLabels.get(index.get()));
+                  instructions.add(frame);
+                  instructions.add(new CfLoad(ValueType.OBJECT, 0));
+                  instructions.add(new CfConstNumber(BooleanUtils.intValue(bool), ValueType.INT));
+                  instructions.add(
+                      new CfInvoke(Opcodes.INVOKESTATIC, factory.booleanMembers.valueOf, false));
+                  instructions.add(
+                      new CfInvoke(Opcodes.INVOKEVIRTUAL, factory.objectMembers.equals, false));
+                  instructions.add(
+                      new CfIf(IfType.EQ, ValueType.INT, cfLabels.get(index.get() + 1)));
+                  instructions.add(new CfConstNumber(index.getAndIncrement(), ValueType.INT));
+                  instructions.add(new CfReturn(ValueType.INT));
+                },
+                dexNumber -> {
+                  instructions.add(cfLabels.get(index.get()));
+                  instructions.add(frame);
+                  instructions.add(new CfLoad(ValueType.OBJECT, 0));
+                  if (dexNumber.isDexValueFloat()) {
+                    instructions.add(new CfConstNumber(dexNumber.getRawValue(), ValueType.FLOAT));
+                    instructions.add(
+                        new CfInvoke(Opcodes.INVOKESTATIC, factory.floatMembers.valueOf, false));
+                  } else if (dexNumber.isDexValueDouble()) {
+                    instructions.add(new CfConstNumber(dexNumber.getRawValue(), ValueType.DOUBLE));
+                    instructions.add(
+                        new CfInvoke(Opcodes.INVOKESTATIC, factory.doubleMembers.valueOf, false));
+                  } else if (dexNumber.isDexValueLong()) {
+                    instructions.add(new CfConstNumber(dexNumber.getRawValue(), ValueType.LONG));
+                    instructions.add(
+                        new CfInvoke(Opcodes.INVOKESTATIC, factory.longMembers.valueOf, false));
+                  } else {
+                    throw new CompilationError(
+                        "Unexpected dexNumber in type switch desugaring " + dexNumber);
+                  }
+                  instructions.add(
+                      new CfInvoke(Opcodes.INVOKEVIRTUAL, factory.objectMembers.equals, false));
+                  instructions.add(
+                      new CfIf(IfType.EQ, ValueType.INT, cfLabels.get(index.get() + 1)));
+                  instructions.add(new CfConstNumber(index.getAndIncrement(), ValueType.INT));
+                  instructions.add(new CfReturn(ValueType.INT));
+                }));
+
+    assert index.get() == bootstrapArgs.size();
+    instructions.add(defaultLabel);
+    instructions.add(frame);
+    instructions.add(new CfConstNumber(-2, ValueType.INT));
+    instructions.add(new CfReturn(ValueType.INT));
+    return standardCfCodeFromInstructions(instructions);
+  }
+
+  public static boolean allowsInlinedIntegerEquality(DexType arg0Type, DexItemFactory factory) {
+    return arg0Type.isIdenticalTo(factory.boxedByteType)
+        || arg0Type.isIdenticalTo(factory.boxedCharType)
+        || arg0Type.isIdenticalTo(factory.boxedShortType)
+        || arg0Type.isIdenticalTo(factory.boxedIntType);
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/shaking/DefaultEnqueuerUseRegistry.java b/src/main/java/com/android/tools/r8/shaking/DefaultEnqueuerUseRegistry.java
index 73429d6..14877d0 100644
--- a/src/main/java/com/android/tools/r8/shaking/DefaultEnqueuerUseRegistry.java
+++ b/src/main/java/com/android/tools/r8/shaking/DefaultEnqueuerUseRegistry.java
@@ -308,7 +308,12 @@
     for (DexValue bootstrapArg : callSite.bootstrapArgs) {
       if (bootstrapArg.isDexValueType()) {
         registerTypeReference(bootstrapArg.asDexValueType().value);
-      } else if (bootstrapArg.isDexValueConstDynamic()) {
+      } else if (bootstrapArg.isDexValueConstDynamic()
+          && bootstrapArg
+              .asDexValueConstDynamic()
+              .getValue()
+              .getType()
+              .isIdenticalTo(appView.dexItemFactory().enumDescType)) {
         DexField enumField =
             extractEnumField(bootstrapArg.asDexValueConstDynamic(), getContext(), appView);
         if (enumField != null) {
diff --git a/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java b/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java
index 1ef8edc..6881a1b 100644
--- a/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java
+++ b/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java
@@ -54,7 +54,7 @@
   // Locally generated synthetic classes.
   public final SyntheticKind LAMBDA = generator.forInstanceClass("Lambda");
   public final SyntheticKind THREAD_LOCAL = generator.forInstanceClass("ThreadLocal");
-  public final SyntheticKind TYPE_SWITCH_CLASS = generator.forInstanceClass("TypeSwitch");
+  public final SyntheticKind TYPE_SWITCH_CLASS = generator.forInstanceClass("TypeSwitchClass");
 
   // Merging not permitted since this could defeat the purpose of the synthetic class.
   public final SyntheticKind SHARED_SUPER_CLASS =
@@ -71,7 +71,7 @@
   public final SyntheticKind AUTOCLOSEABLE_FORWARDER =
       generator.forSingleMethodWithGlobalMerging("AutoCloseableForwarder");
   public final SyntheticKind TYPE_SWITCH_HELPER =
-      generator.forSingleMethodWithGlobalMerging("TypeSwitch");
+      generator.forSingleMethodWithGlobalMerging("TypeSwitchHelper");
   public final SyntheticKind ENUM_UNBOXING_CHECK_NOT_ZERO_METHOD =
       generator.forSingleMethodWithGlobalMerging("CheckNotZero");
   public final SyntheticKind RECORD_HELPER = generator.forSingleMethodWithGlobalMerging("Record");
diff --git a/src/test/examplesJava21/switchpatternmatching/EnumLessCasesAtRuntimeSwitchTest.java b/src/test/examplesJava21/switchpatternmatching/EnumLessCasesAtRuntimeSwitchTest.java
index 50baa2e..f7045a5 100644
--- a/src/test/examplesJava21/switchpatternmatching/EnumLessCasesAtRuntimeSwitchTest.java
+++ b/src/test/examplesJava21/switchpatternmatching/EnumLessCasesAtRuntimeSwitchTest.java
@@ -94,6 +94,7 @@
             b -> b.addLibraryProvider(JdkClassFileProvider.fromSystemJdk()))
         .setMinApi(parameters)
         .addKeepMainRule(Main.class)
+        .addKeepEnumsRule()
         .run(parameters.getRuntime(), Main.class)
         .assertSuccessWithOutput(EXPECTED_OUTPUT);
   }
diff --git a/src/test/examplesJava21/switchpatternmatching/EnumMoreCasesAtRuntimeSwitchTest.java b/src/test/examplesJava21/switchpatternmatching/EnumMoreCasesAtRuntimeSwitchTest.java
index 493d588..1a79f51 100644
--- a/src/test/examplesJava21/switchpatternmatching/EnumMoreCasesAtRuntimeSwitchTest.java
+++ b/src/test/examplesJava21/switchpatternmatching/EnumMoreCasesAtRuntimeSwitchTest.java
@@ -59,24 +59,6 @@
           "4",
           "0");
 
-  public static String UNEXPECTED_OUTPUT_R8_DEX =
-      StringUtils.lines(
-          "TYPE",
-          "null",
-          "E1",
-          "class %s",
-          "E3",
-          "class %s",
-          "E5",
-          "a C",
-          "ENUM",
-          "null",
-          "1",
-          "0",
-          "3",
-          "0", // This is the difference from the EXPECTED_OUTPUT.
-          "0");
-
   @Test
   public void testJvm() throws Exception {
     assumeTrue(parameters.isCfRuntime());
@@ -144,20 +126,10 @@
             b -> b.addLibraryProvider(JdkClassFileProvider.fromSystemJdk()))
         .setMinApi(parameters)
         .addKeepMainRule(Main.class)
+        .addKeepEnumsRule()
         .run(parameters.getRuntime(), Main.class)
-        .applyIf(
-            parameters.isDexRuntime(),
-            // TODO(b/381825147): Should same output.
-            r ->
-                r.assertSuccessWithOutput(
-                    String.format(
-                        UNEXPECTED_OUTPUT_R8_DEX,
-                        matchException(parameters),
-                        matchException(parameters))),
-            r ->
-                r.assertSuccessWithOutput(
-                    String.format(
-                        EXPECTED_OUTPUT, matchException(parameters), matchException(parameters))));
+        .assertSuccessWithOutput(
+            String.format(EXPECTED_OUTPUT, matchException(parameters), matchException(parameters)));
   }
 
   sealed interface I permits CompileTimeE, C {}
diff --git a/src/test/examplesJava21/switchpatternmatching/EnumSwitchTest.java b/src/test/examplesJava21/switchpatternmatching/EnumSwitchTest.java
index bd328bf..774a285 100644
--- a/src/test/examplesJava21/switchpatternmatching/EnumSwitchTest.java
+++ b/src/test/examplesJava21/switchpatternmatching/EnumSwitchTest.java
@@ -106,6 +106,7 @@
             b -> b.addLibraryProvider(JdkClassFileProvider.fromSystemJdk()))
         .setMinApi(parameters)
         .addKeepMainRule(Main.class)
+        .addKeepEnumsRule()
         .run(parameters.getRuntime(), Main.class)
         .assertSuccessWithOutput(String.format(EXPECTED_OUTPUT, matchException(parameters)));
   }
diff --git a/src/test/examplesJava21/switchpatternmatching/EnumSwitchUsingEnumSwitchBootstrapMethodTest.java b/src/test/examplesJava21/switchpatternmatching/EnumSwitchUsingEnumSwitchBootstrapMethodTest.java
index 605fd80..5954a9d 100644
--- a/src/test/examplesJava21/switchpatternmatching/EnumSwitchUsingEnumSwitchBootstrapMethodTest.java
+++ b/src/test/examplesJava21/switchpatternmatching/EnumSwitchUsingEnumSwitchBootstrapMethodTest.java
@@ -91,6 +91,7 @@
             b -> b.addLibraryProvider(JdkClassFileProvider.fromSystemJdk()))
         .setMinApi(parameters)
         .addKeepMainRule(Main.class)
+        .addKeepEnumsRule()
         .run(parameters.getRuntime(), Main.class)
         .assertSuccessWithOutput(EXPECTED_OUTPUT);
   }
diff --git a/src/test/java/com/android/tools/r8/cfmethodgeneration/TypeSwitchMethods.java b/src/test/java/com/android/tools/r8/cfmethodgeneration/TypeSwitchMethods.java
index 58bd939..53ae439 100644
--- a/src/test/java/com/android/tools/r8/cfmethodgeneration/TypeSwitchMethods.java
+++ b/src/test/java/com/android/tools/r8/cfmethodgeneration/TypeSwitchMethods.java
@@ -6,24 +6,33 @@
 
 public class TypeSwitchMethods {
 
-  public static int typeSwitch(Object obj, int restart, Object[] tests) {
-    if (obj == null) {
-      return -1;
-    }
-    for (int i = restart; i < tests.length; i++) {
-      Object test = tests[i];
-      if (test instanceof Class<?>) {
-        if (((Class<?>) test).isInstance(obj)) {
-          return i;
+  public static boolean switchEnumEq(
+      Object value, Object[] cache, int index, String enumClass, String name) {
+    if (cache[index] == null) {
+      Object resolved = null;
+      try {
+        Class<?> clazz = Class.forName(enumClass);
+        if (clazz.isEnum()) {
+          Class<? extends Enum> enumClazz = (Class<? extends Enum>) clazz;
+          resolved = Enum.valueOf(enumClazz, name);
         }
-      } else {
-        // This is an integer, a string or an enum instance.
-        if (obj.equals(test)) {
-          return i;
-        }
+      } catch (Throwable t) {
       }
+      // R8 sets a sentinel if resolution has failed.
+      cache[index] = resolved == null ? new Object() : resolved;
     }
-    // Default case.
-    return -2;
+    return value == cache[index];
+  }
+
+  public static boolean switchIntEq(Object value, int constant) {
+    if (value instanceof Number) {
+      Number num = (Number) value;
+      return constant == num.intValue();
+    }
+    if (value instanceof Character) {
+      Character ch = (Character) value;
+      return constant == ch.charValue();
+    }
+    return false;
   }
 }
diff --git a/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/DexNumberValueSwitchTest.java b/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/DexNumberValueSwitchTest.java
index 9e20a4b..4734d64 100644
--- a/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/DexNumberValueSwitchTest.java
+++ b/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/DexNumberValueSwitchTest.java
@@ -16,7 +16,6 @@
 import com.android.tools.r8.ToolHelper;
 import com.android.tools.r8.utils.StringUtils;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
-import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -37,7 +36,10 @@
         .build();
   }
 
-  public static String EXPECTED_OUTPUT = StringUtils.lines("TODO");
+  public static String EXPECTED_OUTPUT =
+      StringUtils.lines(
+          "null", "42", "positif", "negatif", "null", "42", "positif", "negatif", "null", "42",
+          "positif", "negatif", "null", "true", "false");
 
   @Test
   public void testJvm() throws Exception {
@@ -47,13 +49,12 @@
         hasJdk21TypeSwitch(inspector.clazz(Main.class).uniqueMethodWithOriginalName("longSwitch")));
     parameters.assumeJvmTestParameters();
     testForJvm(parameters)
+        .enablePreview()
         .addInnerClassesAndStrippedOuter(getClass())
         .run(parameters.getRuntime(), Main.class)
-        // This is successful with the jvm with --enable-preview flag only.
-        .assertFailureWithErrorThatThrows(BootstrapMethodError.class);
+        .assertSuccessWithOutput(EXPECTED_OUTPUT);
   }
 
-  @Ignore("Fixed in next CL")
   @Test
   public void testD8() throws Exception {
     testForD8(parameters.getBackend())
@@ -63,7 +64,6 @@
         .assertSuccessWithOutput(EXPECTED_OUTPUT);
   }
 
-  @Ignore("Fixed in next CL")
   @Test
   public void testR8() throws Exception {
     parameters.assumeR8TestParameters();
@@ -74,6 +74,8 @@
             b -> b.addLibraryProvider(JdkClassFileProvider.fromSystemJdk()))
         .setMinApi(parameters)
         .addKeepMainRule(Main.class)
+        .compile()
+        .applyIf(parameters.isCfRuntime(), b -> b.addVmArguments("--enable-preview"))
         .run(parameters.getRuntime(), Main.class)
         .assertSuccessWithOutput(EXPECTED_OUTPUT);
   }
diff --git a/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/DexIntValueSwitchTest.java b/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/DexValueSwitchTest.java
similarity index 68%
rename from src/test/java23/com/android/tools/r8/java23/switchpatternmatching/DexIntValueSwitchTest.java
rename to src/test/java23/com/android/tools/r8/java23/switchpatternmatching/DexValueSwitchTest.java
index 43475eb..9f63637 100644
--- a/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/DexIntValueSwitchTest.java
+++ b/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/DexValueSwitchTest.java
@@ -23,7 +23,7 @@
 import org.junit.runners.Parameterized.Parameters;
 
 @RunWith(Parameterized.class)
-public class DexIntValueSwitchTest extends TestBase {
+public class DexValueSwitchTest extends TestBase {
 
   @Parameter public TestParameters parameters;
 
@@ -41,16 +41,6 @@
           "null", "42", "positif", "negatif", "null", "c", "upper", "lower", "null", "42",
           "positif", "negatif", "null", "42", "positif", "negatif");
 
-  public static String FAILED_DESUGARED_OUTPUT =
-      StringUtils.lines(
-          "null", "42", "positif", "negatif", "null", "lower", "upper", "lower", "null", "positif",
-          "positif", "negatif", "null", "positif", "positif", "negatif");
-
-  public static String FAILED_DESUGARED_OUTPUT_R8 =
-      StringUtils.lines(
-          "null", "42", "positif", "negatif", "null", "lower", "lower", "lower", "null", "negatif",
-          "negatif", "negatif", "null", "negatif", "negatif", "negatif");
-
   @Test
   public void testJvm() throws Exception {
     assumeTrue(parameters.isCfRuntime());
@@ -71,7 +61,7 @@
         .addInnerClassesAndStrippedOuter(getClass())
         .setMinApi(parameters)
         .run(parameters.getRuntime(), Main.class)
-        .assertSuccessWithOutput(FAILED_DESUGARED_OUTPUT);
+        .assertSuccessWithOutput(EXPECTED_OUTPUT);
   }
 
   @Test
@@ -85,14 +75,76 @@
         .setMinApi(parameters)
         .addKeepMainRule(Main.class)
         .run(parameters.getRuntime(), Main.class)
-        .applyIf(
-            parameters.isCfRuntime(),
-            b -> b.assertSuccessWithOutput(EXPECTED_OUTPUT),
-            b -> b.assertSuccessWithOutput(FAILED_DESUGARED_OUTPUT_R8));
+        .assertSuccessWithOutput(EXPECTED_OUTPUT);
   }
 
   static class Main {
 
+    // static void booleanSwitch(Boolean b) {
+    //   switch (b) {
+    //     case null -> {
+    //       System.out.println("null");
+    //     }
+    //     case true -> {
+    //       System.out.println("true");
+    //     }
+    //     default -> {
+    //       System.out.println("false");
+    //     }
+    //   }
+    // }
+    //
+    // static void doubleSwitch(Double d) {
+    //   switch (d) {
+    //     case null -> {
+    //       System.out.println("null");
+    //     }
+    //     case 42.0 -> {
+    //       System.out.println("42");
+    //     }
+    //     case Double f2 when f2 > 0 -> {
+    //       System.out.println("positif");
+    //     }
+    //     default -> {
+    //       System.out.println("negatif");
+    //     }
+    //   }
+    // }
+    //
+    // static void floatSwitch(Float f) {
+    //   switch (f) {
+    //     case null -> {
+    //       System.out.println("null");
+    //     }
+    //     case 42.0f -> {
+    //       System.out.println("42");
+    //     }
+    //     case Float f2 when f2 > 0 -> {
+    //       System.out.println("positif");
+    //     }
+    //     default -> {
+    //       System.out.println("negatif");
+    //     }
+    //   }
+    // }
+    //
+    // static void longSwitch(Long l) {
+    //   switch (l) {
+    //     case null -> {
+    //       System.out.println("null");
+    //     }
+    //     case 42L -> {
+    //       System.out.println("42");
+    //     }
+    //     case Long i2 when i2 > 0 -> {
+    //       System.out.println("positif");
+    //     }
+    //     default -> {
+    //       System.out.println("negatif");
+    //     }
+    //   }
+    // }
+
     static void intSwitch(Integer i) {
       switch (i) {
         case null -> {
@@ -166,18 +218,40 @@
       intSwitch(42);
       intSwitch(12);
       intSwitch(-1);
+
       charSwitch(null);
       charSwitch('c');
       charSwitch('X');
       charSwitch('x');
+
       byteSwitch(null);
       byteSwitch((byte) 42);
       byteSwitch((byte) 12);
       byteSwitch((byte) -1);
+
       shortSwitch(null);
       shortSwitch((short) 42);
       shortSwitch((short) 12);
       shortSwitch((short) -1);
+
+      // longSwitch(null);
+      // longSwitch(42L);
+      // longSwitch(12L);
+      // longSwitch(-1L);
+      //
+      // floatSwitch(null);
+      // floatSwitch(42.0f);
+      // floatSwitch(12.0f);
+      // floatSwitch(-1.0f);
+      //
+      // doubleSwitch(null);
+      // doubleSwitch(42.0);
+      // doubleSwitch(12.0);
+      // doubleSwitch(-1.0);
+      //
+      // booleanSwitch(null);
+      // booleanSwitch(true);
+      // booleanSwitch(false);
     }
   }
 }
diff --git a/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/EnumLessCasesAtRuntimeSwitchTest.java b/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/EnumLessCasesAtRuntimeSwitchTest.java
index b87d40f..558da5b 100644
--- a/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/EnumLessCasesAtRuntimeSwitchTest.java
+++ b/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/EnumLessCasesAtRuntimeSwitchTest.java
@@ -95,6 +95,7 @@
             b -> b.addLibraryProvider(JdkClassFileProvider.fromSystemJdk()))
         .setMinApi(parameters)
         .addKeepMainRule(Main.class)
+        .addKeepEnumsRule()
         .run(parameters.getRuntime(), Main.class)
         .assertSuccessWithOutput(EXPECTED_OUTPUT);
   }
diff --git a/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/EnumMoreCasesAtRuntimeSwitchTest.java b/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/EnumMoreCasesAtRuntimeSwitchTest.java
index 83af928..0903865 100644
--- a/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/EnumMoreCasesAtRuntimeSwitchTest.java
+++ b/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/EnumMoreCasesAtRuntimeSwitchTest.java
@@ -60,24 +60,6 @@
           "4",
           "0");
 
-  public static String UNEXPECTED_OUTPUT_R8_DEX =
-      StringUtils.lines(
-          "TYPE",
-          "null",
-          "E1",
-          "class %s",
-          "E3",
-          "class %s",
-          "E5",
-          "a C",
-          "ENUM",
-          "null",
-          "1",
-          "0",
-          "3",
-          "0", // This is the difference from the EXPECTED_OUTPUT.
-          "0");
-
   @Test
   public void testJvm() throws Exception {
     assumeTrue(parameters.isCfRuntime());
@@ -145,20 +127,10 @@
             b -> b.addLibraryProvider(JdkClassFileProvider.fromSystemJdk()))
         .setMinApi(parameters)
         .addKeepMainRule(Main.class)
+        .addKeepEnumsRule()
         .run(parameters.getRuntime(), Main.class)
-        .applyIf(
-            parameters.isDexRuntime(),
-            // TODO(b/381825147): Should same output.
-            r ->
-                r.assertSuccessWithOutput(
-                    String.format(
-                        UNEXPECTED_OUTPUT_R8_DEX,
-                        matchException(parameters),
-                        matchException(parameters))),
-            r ->
-                r.assertSuccessWithOutput(
-                    String.format(
-                        EXPECTED_OUTPUT, matchException(parameters), matchException(parameters))));
+        .assertSuccessWithOutput(
+            String.format(EXPECTED_OUTPUT, matchException(parameters), matchException(parameters)));
   }
 
   sealed interface I permits CompileTimeE, C {}
diff --git a/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/EnumSwitchTest.java b/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/EnumSwitchTest.java
index 4a27c29..e9c5ae5 100644
--- a/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/EnumSwitchTest.java
+++ b/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/EnumSwitchTest.java
@@ -101,6 +101,7 @@
             b -> b.addLibraryProvider(JdkClassFileProvider.fromSystemJdk()))
         .setMinApi(parameters)
         .addKeepMainRule(Main.class)
+        .addKeepEnumsRule()
         .run(parameters.getRuntime(), Main.class)
         .assertSuccessWithOutput(String.format(EXPECTED_OUTPUT, matchException(parameters)));
   }
diff --git a/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/EnumSwitchUsingEnumSwitchBootstrapMethodTest.java b/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/EnumSwitchUsingEnumSwitchBootstrapMethodTest.java
index ed408fc..9028df6 100644
--- a/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/EnumSwitchUsingEnumSwitchBootstrapMethodTest.java
+++ b/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/EnumSwitchUsingEnumSwitchBootstrapMethodTest.java
@@ -92,6 +92,7 @@
             b -> b.addLibraryProvider(JdkClassFileProvider.fromSystemJdk()))
         .setMinApi(parameters)
         .addKeepMainRule(Main.class)
+        .addKeepEnumsRule()
         .run(parameters.getRuntime(), Main.class)
         .assertSuccessWithOutput(EXPECTED_OUTPUT);
   }
diff --git a/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/StringSwitchTest.java b/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/StringSwitchTest.java
index cd4b157..7786d21 100644
--- a/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/StringSwitchTest.java
+++ b/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/StringSwitchTest.java
@@ -81,6 +81,7 @@
   }
 
   static class Main {
+
     static void stringSwitch(String string) {
       switch (string) {
         case null -> {
diff --git a/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/TypeSwitchTest.java b/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/TypeSwitchTest.java
index 686408d..e1eea16 100644
--- a/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/TypeSwitchTest.java
+++ b/src/test/java23/com/android/tools/r8/java23/switchpatternmatching/TypeSwitchTest.java
@@ -62,11 +62,7 @@
         .addInnerClassesAndStrippedOuter(getClass())
         .setMinApi(parameters)
         .run(parameters.getRuntime(), Main.class)
-        .applyIf(
-            isRecordsFullyDesugaredForD8(parameters)
-                || runtimeWithRecordsSupport(parameters.getRuntime()),
-            r -> r.assertSuccessWithOutput(EXPECTED_OUTPUT),
-            r -> r.assertFailureWithErrorThatThrows(NoClassDefFoundError.class));
+        .assertSuccessWithOutput(EXPECTED_OUTPUT);
   }
 
   @Test