Rewrite enumEq calls into identity checks in type switches

Bug: b/399808482
Change-Id: I918f98a8590d3ead97a1ec9de5b98bc2a2f9a774
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/fieldvalueanalysis/StaticFieldValues.java b/src/main/java/com/android/tools/r8/ir/analysis/fieldvalueanalysis/StaticFieldValues.java
index 0b9e480..b5e6590 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/fieldvalueanalysis/StaticFieldValues.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/fieldvalueanalysis/StaticFieldValues.java
@@ -9,11 +9,15 @@
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexString;
 import com.android.tools.r8.graph.lens.GraphLens;
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
+import com.android.tools.r8.ir.analysis.value.objectstate.EnumValuesObjectState;
 import com.android.tools.r8.ir.analysis.value.objectstate.ObjectState;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.Pair;
 import com.google.common.collect.ImmutableMap;
+import java.util.Map.Entry;
 
 public abstract class StaticFieldValues {
 
@@ -112,6 +116,36 @@
     public ObjectState getObjectStateForPossiblyPinnedField(DexField field) {
       return enumAbstractValues.get(field);
     }
+
+    public Pair<DexField, Integer> getFieldAccessorForName(DexString name, DexItemFactory factory) {
+      Entry<DexField, ObjectState> array = null;
+      for (Entry<DexField, ObjectState> fieldAndState : enumAbstractValues.entrySet()) {
+        DexField field = fieldAndState.getKey();
+        if (field.getType().isArrayType()) {
+          array = fieldAndState;
+        } else {
+          AbstractValue fieldValue =
+              fieldAndState.getValue().getAbstractFieldValue(factory.enumMembers.nameField);
+          if (fieldValue.isSingleStringValue()
+              && fieldValue.asSingleStringValue().getDexString().isIdenticalTo(name)) {
+            return new Pair<>(field, null);
+          }
+        }
+      }
+      if (array != null && array.getValue().isEnumValuesObjectState()) {
+        EnumValuesObjectState valuesState = array.getValue().asEnumValuesObjectState();
+        ObjectState[] objectStates = valuesState.getState();
+        for (int i = 0; i < objectStates.length; i++) {
+          AbstractValue fieldValue =
+              objectStates[i].getAbstractFieldValue(factory.enumMembers.nameField);
+          if (fieldValue.isSingleStringValue()
+              && fieldValue.asSingleStringValue().getDexString().isIdenticalTo(name)) {
+            return new Pair<>(array.getKey(), i);
+          }
+        }
+      }
+      return null;
+    }
   }
 
   public static class EmptyStaticValues extends StaticFieldValues {
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/value/objectstate/EnumValuesObjectState.java b/src/main/java/com/android/tools/r8/ir/analysis/value/objectstate/EnumValuesObjectState.java
index 25a9566..68c7172 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/value/objectstate/EnumValuesObjectState.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/value/objectstate/EnumValuesObjectState.java
@@ -43,6 +43,10 @@
     return UnknownValue.getInstance();
   }
 
+  public ObjectState[] getState() {
+    return state;
+  }
+
   public ObjectState getObjectStateForOrdinal(int ordinal) {
     if (ordinal < 0 || ordinal >= state.length) {
       return ObjectState.empty();
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 120ee81..8d39236 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -41,6 +41,7 @@
 import com.android.tools.r8.ir.conversion.passes.ThrowCatchOptimizer;
 import com.android.tools.r8.ir.conversion.passes.TrivialPhiSimplifier;
 import com.android.tools.r8.ir.desugar.CfInstructionDesugaringCollectionSupplier;
+import com.android.tools.r8.ir.desugar.typeswitch.TypeSwitchIRRewriter;
 import com.android.tools.r8.ir.optimize.AssertionErrorTwoArgsConstructorRewriter;
 import com.android.tools.r8.ir.optimize.AssertionsRewriter;
 import com.android.tools.r8.ir.optimize.AssumeInserter;
@@ -122,6 +123,7 @@
   private final Devirtualizer devirtualizer;
   private final TypeChecker typeChecker;
   protected EnumUnboxer enumUnboxer;
+  protected final TypeSwitchIRRewriter typeSwitchIRRewriter;
   protected final NumberUnboxer numberUnboxer;
   protected final RemoveVerificationErrorForUnknownReturnedValues
       removeVerificationErrorForUnknownReturnedValues;
@@ -201,6 +203,7 @@
       this.typeChecker = null;
       this.methodOptimizationInfoCollector = null;
       this.enumUnboxer = EnumUnboxer.empty();
+      this.typeSwitchIRRewriter = null;
       this.numberUnboxer = NumberUnboxer.empty();
       this.assumeInserter = null;
       this.removeVerificationErrorForUnknownReturnedValues = null;
@@ -234,6 +237,7 @@
               ? new LibraryMethodOverrideAnalysis(appViewWithLiveness)
               : null;
       this.enumUnboxer = EnumUnboxer.create(appViewWithLiveness);
+      this.typeSwitchIRRewriter = TypeSwitchIRRewriter.create(appViewWithLiveness);
       this.numberUnboxer = NumberUnboxer.create(appViewWithLiveness);
       this.outliner = Outliner.create(appViewWithLiveness);
       this.memberValuePropagation = new R8MemberValuePropagation(appViewWithLiveness);
@@ -269,6 +273,7 @@
       this.typeChecker = null;
       this.methodOptimizationInfoCollector = null;
       this.enumUnboxer = EnumUnboxer.empty();
+      this.typeSwitchIRRewriter = null;
       this.numberUnboxer = NumberUnboxer.empty();
     }
   }
@@ -687,6 +692,10 @@
     rewriterPassCollection.run(
         code, methodProcessor, methodProcessingContext, timing, previous, options);
 
+    if (typeSwitchIRRewriter != null) {
+      typeSwitchIRRewriter.run(code);
+    }
+
     timing.begin("Optimize class initializers");
     ClassInitializerDefaultsResult classInitializerDefaultsResult =
         classInitializerDefaultsOptimization.optimize(code, feedback);
@@ -892,6 +901,9 @@
       }
     }
     enumUnboxer.recordEnumState(method.getHolder(), staticFieldValues);
+    if (typeSwitchIRRewriter != null) {
+      typeSwitchIRRewriter.recordEnumState(method.getHolder(), staticFieldValues);
+    }
     if (appView.options().protoShrinking().enableRemoveProtoEnumSwitchMap()) {
       appView
           .protoShrinker()
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/SwitchHelperGenerator.java b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/SwitchHelperGenerator.java
index 2c901d2..72098d9 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/SwitchHelperGenerator.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/SwitchHelperGenerator.java
@@ -31,6 +31,7 @@
 import com.android.tools.r8.ir.desugar.CfInstructionDesugaringEventConsumer;
 import com.android.tools.r8.ir.synthetic.TypeSwitchSyntheticCfCodeProvider;
 import com.android.tools.r8.ir.synthetic.TypeSwitchSyntheticCfCodeProvider.Dispatcher;
+import com.android.tools.r8.synthesis.SyntheticItems.SyntheticKindSelector;
 import com.android.tools.r8.synthesis.SyntheticProgramClassBuilder;
 import com.android.tools.r8.utils.ListUtils;
 import com.google.common.collect.ImmutableList;
@@ -49,7 +50,7 @@
   private DexMethod intEq;
   private DexField enumCacheField;
   private int enumCases = 0;
-  private Map<DexType, DexMethod> enumEqMethods = new IdentityHashMap<>();
+  private final Map<DexType, DexMethod> enumEqMethods = new IdentityHashMap<>();
 
   SwitchHelperGenerator(AppView<?> appView, DexCallSite dexCallSite) {
     this.appView = appView;
@@ -148,7 +149,8 @@
                   cfCode.getInstructions(), i -> i.isConstClass() ? new CfConstClass(enumType) : i);
           cfCode.setInstructions(newInstructions);
           return cfCode;
-        });
+        },
+        kinds -> kinds.TYPE_SWITCH_HELPER_ENUM);
   }
 
   private DexMethod generateIntEqMethod(
@@ -162,7 +164,8 @@
         eventConsumer,
         methodProcessingContext,
         proto,
-        methodSig -> TypeSwitchMethods.TypeSwitchMethods_switchIntEq(factory, methodSig));
+        methodSig -> TypeSwitchMethods.TypeSwitchMethods_switchIntEq(factory, methodSig),
+        kinds -> kinds.TYPE_SWITCH_HELPER_INT);
   }
 
   private DexMethod generateMethod(
@@ -170,13 +173,14 @@
       TypeSwitchDesugaringEventConsumer eventConsumer,
       MethodProcessingContext methodProcessingContext,
       DexProto proto,
-      Function<DexMethod, CfCode> cfCodeGen) {
+      Function<DexMethod, CfCode> cfCodeGen,
+      SyntheticKindSelector kindSelector) {
     DexItemFactory factory = appView.dexItemFactory();
     ProgramMethod method =
         appView
             .getSyntheticItems()
             .createMethod(
-                kinds -> kinds.TYPE_SWITCH_HELPER,
+                kindSelector,
                 methodProcessingContext.createUniqueContext(),
                 appView,
                 builder ->
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchIRRewriter.java b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchIRRewriter.java
new file mode 100644
index 0000000..ea11bfe
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchIRRewriter.java
@@ -0,0 +1,221 @@
+// Copyright (c) 2025, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.desugar.typeswitch;
+
+import com.android.tools.r8.dex.code.CfOrDexInstruction;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DefaultUseRegistryWithResult;
+import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexString;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.graph.UseRegistryWithResult;
+import com.android.tools.r8.ir.analysis.fieldvalueanalysis.StaticFieldValues;
+import com.android.tools.r8.ir.analysis.fieldvalueanalysis.StaticFieldValues.EnumStaticFieldValues;
+import com.android.tools.r8.ir.analysis.type.Nullability;
+import com.android.tools.r8.ir.analysis.type.TypeElement;
+import com.android.tools.r8.ir.code.ArrayGet;
+import com.android.tools.r8.ir.code.BasicBlock;
+import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.If;
+import com.android.tools.r8.ir.code.Instruction;
+import com.android.tools.r8.ir.code.InstructionListIterator;
+import com.android.tools.r8.ir.code.InvokeStatic;
+import com.android.tools.r8.ir.code.MemberType;
+import com.android.tools.r8.ir.code.StaticGet;
+import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.ir.conversion.passes.result.CodeRewriterResult;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.Pair;
+import com.android.tools.r8.utils.collections.BidirectionalManyToOneHashMap;
+import com.android.tools.r8.utils.collections.BidirectionalManyToOneMap;
+import com.google.common.collect.ImmutableList;
+import java.util.IdentityHashMap;
+import java.util.ListIterator;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+public class TypeSwitchIRRewriter {
+  private final AppView<AppInfoWithLiveness> appView;
+  private final BidirectionalManyToOneMap<DexMethod, DexType> enumEqMethods;
+
+  private final Map<DexType, EnumStaticFieldValues> staticFieldValuesMap =
+      new ConcurrentHashMap<>();
+
+  public static TypeSwitchIRRewriter create(AppView<AppInfoWithLiveness> appView) {
+    BidirectionalManyToOneMap<DexMethod, DexType> enumEqMethods = buildEnumEqMethods(appView);
+    if (enumEqMethods.isEmpty()) {
+      return null;
+    }
+    return new TypeSwitchIRRewriter(appView, enumEqMethods);
+  }
+
+  private TypeSwitchIRRewriter(
+      AppView<AppInfoWithLiveness> appView,
+      BidirectionalManyToOneMap<DexMethod, DexType> enumEqMethods) {
+    this.appView = appView;
+    this.enumEqMethods = enumEqMethods;
+  }
+
+  private static BidirectionalManyToOneMap<DexMethod, DexType> buildEnumEqMethods(
+      AppView<AppInfoWithLiveness> appView) {
+    BidirectionalManyToOneHashMap<DexMethod, DexType> res =
+        BidirectionalManyToOneHashMap.newIdentityHashMap();
+    for (DexProgramClass clazz : appView.appInfo().classes()) {
+      if (appView
+          .getSyntheticItems()
+          .isSyntheticOfKind(clazz.getType(), kinds -> kinds.TYPE_SWITCH_HELPER_ENUM)) {
+        if (!clazz.hasMethods()) {
+          continue;
+        }
+        ProgramMethod uniqueStaticMethod = clazz.directProgramMethods().iterator().next();
+        DexMethod uniqueMethod = uniqueStaticMethod.getReference();
+        UseRegistryWithResult<DexType, ProgramMethod> registry =
+            new DefaultUseRegistryWithResult<>(appView, uniqueStaticMethod) {
+              @Override
+              public void registerConstClass(
+                  DexType type,
+                  ListIterator<? extends CfOrDexInstruction> iterator,
+                  boolean ignoreCompatRules) {
+                assert getResult() == null;
+                setResult(type);
+              }
+            };
+        DexType enumType = uniqueStaticMethod.registerCodeReferencesWithResult(registry);
+        if (enumType != null) {
+          res.put(uniqueMethod, enumType);
+        }
+      }
+    }
+    return res;
+  }
+
+  public void recordEnumState(DexProgramClass clazz, StaticFieldValues staticFieldValues) {
+    if (staticFieldValues == null || !staticFieldValues.isEnumStaticFieldValues()) {
+      return;
+    }
+    assert clazz.isEnum();
+    EnumStaticFieldValues enumStaticFieldValues = staticFieldValues.asEnumStaticFieldValues();
+    if (enumEqMethods.containsValue(clazz.type)) {
+      staticFieldValuesMap.put(clazz.type, enumStaticFieldValues);
+    }
+  }
+
+  public void run(IRCode code) {
+    if (shouldRewriteCode(code)) {
+      rewriteCode(code);
+    }
+  }
+
+  protected boolean shouldRewriteCode(IRCode code) {
+    return !code.context().getDefinition().isClassInitializer()
+        && appView
+            .getSyntheticItems()
+            .isSyntheticOfKind(code.context().getHolderType(), k -> k.TYPE_SWITCH_CLASS);
+  }
+
+  protected CodeRewriterResult rewriteCode(IRCode code) {
+    boolean change = false;
+    Map<If, If> replacement = new IdentityHashMap<>();
+    InstructionListIterator iterator = code.instructionListIterator();
+    while (iterator.hasNext()) {
+      Instruction next = iterator.next();
+      if (next.isIf()) {
+        If anIf = next.asIf();
+        If replace = replacement.get(anIf);
+        if (replace != null) {
+          replacement.remove(anIf);
+          BasicBlock fallThrough = anIf.fallthroughBlock();
+          BasicBlock jumpTarget = anIf.getTrueTarget();
+          iterator.replaceCurrentInstruction(replace);
+          replace.setFallthroughBlock(fallThrough);
+          replace.setTrueTarget(jumpTarget);
+        }
+      } else if (next.isInvokeStatic()
+          && appView
+              .getSyntheticItems()
+              .isSyntheticOfKind(
+                  next.asInvokeStatic().getInvokedMethod().getHolderType(),
+                  kinds -> kinds.TYPE_SWITCH_HELPER_ENUM)) {
+        if (!next.hasOutValue()) {
+          // This happens if the enum is missing from the compilation, the result is always false.
+          iterator.removeOrReplaceByDebugLocalRead();
+        } else if (next.outValue().hasSingleUniqueUserAndNoOtherUsers()) {
+          InvokeStatic invokeStatic = next.asInvokeStatic();
+          DexString fieldName = invokeStatic.getLastArgument().getConstStringOrNull();
+          if (fieldName == null) {
+            continue;
+          }
+          DexType enumType = enumEqMethods.get(invokeStatic.getInvokedMethod());
+          if (enumType == null) {
+            continue;
+          }
+          EnumStaticFieldValues enumStaticFieldValues = staticFieldValuesMap.get(enumType);
+          if (enumStaticFieldValues == null) {
+            continue;
+          }
+          Pair<DexField, Integer> fieldAccessor =
+              enumStaticFieldValues.getFieldAccessorForName(fieldName, appView.dexItemFactory());
+          if (fieldAccessor == null) {
+            continue;
+          }
+          if (!invokeStatic.outValue().hasSingleUniqueUserAndNoOtherUsers()) {
+            continue;
+          }
+          Instruction instruction = invokeStatic.outValue().singleUniqueUser();
+          if (!instruction.isIf()) {
+            continue;
+          }
+          // The enum instance has been resolved. We can replace
+          // if (enumEq(val, cache, cacheIndex, name))
+          // by
+          // if (val == fieldAccessor)
+          Value newValue;
+          if (fieldAccessor.getSecond() == null) {
+            newValue =
+                code.createValue(
+                    TypeElement.fromDexType(
+                        fieldAccessor.getFirst().getType(), Nullability.maybeNull(), appView));
+            iterator.replaceCurrentInstruction(new StaticGet(newValue, fieldAccessor.getFirst()));
+          } else {
+            newValue =
+                code.createValue(
+                    TypeElement.fromDexType(
+                        fieldAccessor.getFirst().getType().toBaseType(appView.dexItemFactory()),
+                        Nullability.maybeNull(),
+                        appView));
+            Value arrayValue =
+                code.createValue(
+                    TypeElement.fromDexType(
+                        fieldAccessor.getFirst().getType(), Nullability.maybeNull(), appView));
+            iterator.previous();
+            Value intValue =
+                iterator.insertConstIntInstruction(
+                    code, appView.options(), fieldAccessor.getSecond());
+            StaticGet fieldGet = new StaticGet(arrayValue, fieldAccessor.getFirst());
+            iterator.add(fieldGet);
+            iterator.next();
+            ArrayGet arrayGet = new ArrayGet(MemberType.OBJECT, newValue, arrayValue, intValue);
+            iterator.replaceCurrentInstruction(arrayGet);
+            fieldGet.setPosition(arrayGet.getPosition());
+          }
+          If anIf = instruction.asIf();
+          // The method enumEq answered true/false based on if the values are equal, which is
+          // inverted when one compares directly the values.
+          If newIf =
+              new If(
+                  anIf.getType().inverted(),
+                  ImmutableList.of(newValue, invokeStatic.getFirstArgument()));
+          replacement.put(anIf, newIf);
+          change = true;
+        }
+      }
+    }
+    assert replacement.isEmpty();
+    return CodeRewriterResult.hasChanged(change);
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/synthesis/SyntheticItems.java b/src/main/java/com/android/tools/r8/synthesis/SyntheticItems.java
index 683ad53..e60d77e 100644
--- a/src/main/java/com/android/tools/r8/synthesis/SyntheticItems.java
+++ b/src/main/java/com/android/tools/r8/synthesis/SyntheticItems.java
@@ -416,7 +416,6 @@
   }
 
   // Predicates and accessors.
-
   @Override
   public ClassResolutionResult definitionFor(
       DexType type, Function<DexType, ClassResolutionResult> baseDefinitionFor) {
diff --git a/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java b/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java
index 4ad6f93..7c95c3b 100644
--- a/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java
+++ b/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java
@@ -70,8 +70,10 @@
       generator.forSingleMethodWithGlobalMerging("AutoCloseableDispatcher");
   public final SyntheticKind AUTOCLOSEABLE_FORWARDER =
       generator.forSingleMethodWithGlobalMerging("AutoCloseableForwarder");
-  public final SyntheticKind TYPE_SWITCH_HELPER =
-      generator.forSingleMethodWithGlobalMerging("TypeSwitchHelper");
+  public final SyntheticKind TYPE_SWITCH_HELPER_INT =
+      generator.forSingleMethodWithGlobalMerging("TypeSwitchHelperInt");
+  public final SyntheticKind TYPE_SWITCH_HELPER_ENUM =
+      generator.forSingleMethodWithGlobalMerging("TypeSwitchHelperEnum");
   public final SyntheticKind ENUM_UNBOXING_CHECK_NOT_ZERO_METHOD =
       generator.forSingleMethodWithGlobalMerging("CheckNotZero");
   public final SyntheticKind RECORD_HELPER = generator.forSingleMethodWithGlobalMerging("Record");
diff --git a/src/test/java24/com/android/tools/r8/jdk24/switchpatternmatching/EnumSwitchOldSyntaxV2Test.java b/src/test/java24/com/android/tools/r8/jdk24/switchpatternmatching/EnumSwitchOldSyntaxV2Test.java
new file mode 100644
index 0000000..c1e3f1d
--- /dev/null
+++ b/src/test/java24/com/android/tools/r8/jdk24/switchpatternmatching/EnumSwitchOldSyntaxV2Test.java
@@ -0,0 +1,130 @@
+// Copyright (c) 2025 the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.jdk24.switchpatternmatching;
+
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.JdkClassFileProvider;
+import com.android.tools.r8.R8TestCompileResult;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.TestRuntime.CfVm;
+import com.android.tools.r8.utils.StringUtils;
+import org.junit.Assume;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class EnumSwitchOldSyntaxV2Test extends TestBase {
+
+  @Parameter public TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters()
+        .withCfRuntimesStartingFromIncluding(CfVm.JDK24)
+        .withDexRuntimes()
+        .withAllApiLevelsAlsoForCf()
+        .withPartialCompilation()
+        .build();
+  }
+
+  public static String EXPECTED_OUTPUT = StringUtils.lines("null", "e11", "e22", "e33");
+
+  @Test
+  public void testJvm() throws Exception {
+    parameters.assumeJvmTestParameters();
+    testForJvm(parameters)
+        .addInnerClassesAndStrippedOuter(getClass())
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutput(EXPECTED_OUTPUT);
+  }
+
+  @Test
+  public void testD8() throws Exception {
+    testForD8(parameters)
+        .addInnerClassesAndStrippedOuter(getClass())
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutput(EXPECTED_OUTPUT);
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    parameters.assumeR8TestParameters();
+    testForR8(parameters)
+        .addInnerClassesAndStrippedOuter(getClass())
+        .applyIf(
+            parameters.isCfRuntime(),
+            b -> b.addLibraryProvider(JdkClassFileProvider.fromSystemJdk()))
+        .addKeepMainRule(Main.class)
+        .addKeepEnumsRule()
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutput(EXPECTED_OUTPUT);
+  }
+
+  @Test
+  public void testR8Split() throws Exception {
+    Assume.assumeFalse("TODO(b/414335863)", parameters.isRandomPartialCompilation());
+    parameters.assumeR8TestParameters();
+    R8TestCompileResult compile =
+        testForR8(Backend.CF)
+            .addInnerClassesAndStrippedOuter(getClass())
+            .addLibraryProvider(JdkClassFileProvider.fromSystemJdk())
+            .addKeepMainRule(Main.class)
+            .addKeepEnumsRule()
+            .compile();
+    compile.inspect(i -> assertEquals(1, i.clazz(E.class).allFields().size()));
+    // The enum is there with the $VALUES field but not each enum field.
+    testForR8(parameters)
+        .addProgramFiles(compile.writeToZip())
+        .addKeepMainRule(Main.class)
+        .addKeepEnumsRule()
+        .applyIf(
+            parameters.isCfRuntime(),
+            b -> b.addLibraryProvider(JdkClassFileProvider.fromSystemJdk()))
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutput(EXPECTED_OUTPUT);
+  }
+
+  public enum E {
+    E1,
+    E2,
+    E3
+  }
+
+  static class Main {
+
+    static void enumSwitch(E e) {
+      switch (e) {
+        case E.E1:
+          System.out.println("e11");
+          break;
+        case E.E2:
+          System.out.println("e22");
+          break;
+        case E.E3:
+          System.out.println("e33");
+          break;
+        case null:
+          System.out.println("null");
+          break;
+      }
+    }
+
+    public static void main(String[] args) {
+      try {
+        enumSwitch(null);
+      } catch (NullPointerException e) {
+        System.out.println("caught npe");
+      }
+      for (E value : E.values()) {
+        enumSwitch(value);
+      }
+    }
+  }
+}