Fix subtype assumptions when unboxing enum with subtypes
Bug: b/271385332
Change-Id: Ic7db15a268a8cca25cf1dfa0147251d532368b1f
diff --git a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
index cfa8347..ae3f0ac 100644
--- a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
+++ b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
@@ -1991,8 +1991,9 @@
       return field == nameField || field == ordinalField;
     }
 
-    public boolean isEnumField(DexEncodedField staticField, DexType enumType) {
-      return isEnumField(staticField, enumType, ImmutableSet.of());
+    public boolean isEnumFieldCandidate(DexEncodedField staticField) {
+      assert staticField.isStatic();
+      return staticField.isEnum() && staticField.isFinal();
     }
 
     // In some case, the enum field may be respecialized to an enum subtype. In this case, one
@@ -2001,8 +2002,7 @@
         DexEncodedField staticField, DexType enumType, Set<DexType> subtypes) {
       assert staticField.isStatic();
       return (staticField.getType() == enumType || subtypes.contains(staticField.getType()))
-          && staticField.isEnum()
-          && staticField.isFinal();
+          && isEnumFieldCandidate(staticField);
     }
 
     public boolean isValuesFieldCandidate(DexEncodedField staticField, DexType enumType) {
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/fieldvalueanalysis/StaticFieldValues.java b/src/main/java/com/android/tools/r8/ir/analysis/fieldvalueanalysis/StaticFieldValues.java
index 0f40f03..8e5b4ca 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/fieldvalueanalysis/StaticFieldValues.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/fieldvalueanalysis/StaticFieldValues.java
@@ -79,7 +79,7 @@
             enumObjectStateBuilder.put(
                 staticField.getReference(), value.asSingleFieldValue().getObjectState());
           }
-        } else if (factory.enumMembers.isEnumField(staticField, staticField.getHolderType())) {
+        } else if (factory.enumMembers.isEnumFieldCandidate(staticField)) {
           if (value.isSingleFieldValue()
               && !value.asSingleFieldValue().getObjectState().isEmpty()) {
             enumObjectStateBuilder.put(
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/value/SingleFieldValue.java b/src/main/java/com/android/tools/r8/ir/analysis/value/SingleFieldValue.java
index 9acf109..b78f18b 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/value/SingleFieldValue.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/value/SingleFieldValue.java
@@ -145,11 +145,9 @@
   public SingleValue rewrittenWithLens(
       AppView<AppInfoWithLiveness> appView, GraphLens lens, GraphLens codeLens) {
     AbstractValueFactory factory = appView.abstractValueFactory();
-    if (field.holder == field.type) {
-      EnumDataMap enumDataMap = appView.unboxedEnums();
-      if (enumDataMap.hasUnboxedValueFor(field)) {
-        return factory.createSingleNumberValue(enumDataMap.getUnboxedValue(field));
-      }
+    EnumDataMap enumDataMap = appView.unboxedEnums();
+    if (enumDataMap.hasUnboxedValueFor(field)) {
+      return factory.createSingleNumberValue(enumDataMap.getUnboxedValue(field));
     }
     DexField rewrittenField = lens.lookupField(field, codeLens);
     ObjectState rewrittenObjectState = getObjectState().rewrittenWithLens(appView, lens, codeLens);
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumDataMap.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumDataMap.java
index 6672c87..06cb686 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumDataMap.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumDataMap.java
@@ -4,18 +4,23 @@
 
 package com.android.tools.r8.ir.optimize.enums;
 
+import static com.android.tools.r8.ir.optimize.enums.EnumUnboxerImpl.ordinalToUnboxedInt;
+
 import com.android.tools.r8.errors.CheckEnumUnboxedDiagnostic;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ProgramField;
+import com.android.tools.r8.ir.analysis.value.AbstractValueFactory;
+import com.android.tools.r8.ir.analysis.value.SingleNumberValue;
 import com.android.tools.r8.ir.optimize.enums.EnumInstanceFieldData.EnumInstanceFieldKnownData;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Sets;
 import it.unimi.dsi.fastutil.ints.Int2ReferenceMap;
+import it.unimi.dsi.fastutil.ints.Int2ReferenceMap.Entry;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Set;
@@ -76,6 +81,15 @@
     return map.containsKey(type);
   }
 
+  public SingleNumberValue getSingleNumberValueFromEnumType(
+      AbstractValueFactory factory, DexType enumType) {
+    assert isUnboxedEnum(enumType);
+    EnumData enumData = get(enumType);
+    return isSuperUnboxedEnum(enumType)
+        ? enumData.superEnumTypeSingleValue(factory, enumType)
+        : enumData.subEnumTypeSingleValue(factory, enumType);
+  }
+
   public boolean isUnboxedEnum(DexProgramClass clazz) {
     return isUnboxedEnum(clazz.getType());
   }
@@ -84,7 +98,7 @@
     return map.containsKey(representativeType(type));
   }
 
-  private EnumData get(DexType type) {
+  public EnumData get(DexType type) {
     EnumData enumData = map.get(representativeType(type));
     assert enumData != null;
     return enumData;
@@ -112,17 +126,19 @@
   }
 
   public boolean hasUnboxedValueFor(DexField enumStaticField) {
-    return isUnboxedEnum(enumStaticField.getHolderType())
-        && get(enumStaticField.getHolderType()).hasUnboxedValueFor(enumStaticField);
+    DexType representative = representativeType(enumStaticField.getHolderType());
+    return isSuperUnboxedEnum(representative)
+        && get(representative).hasUnboxedValueFor(enumStaticField);
   }
 
   public int getUnboxedValue(DexField enumStaticField) {
     assert isUnboxedEnum(enumStaticField.getHolderType());
-    return get(enumStaticField.getHolderType()).getUnboxedValue(enumStaticField);
+    return get(representativeType(enumStaticField.getHolderType()))
+        .getUnboxedValue(enumStaticField);
   }
 
   public int getValuesSize(DexType enumType) {
-    assert isUnboxedEnum(enumType);
+    assert isSuperUnboxedEnum(enumType);
     return get(enumType).getValuesSize();
   }
 
@@ -137,7 +153,7 @@
   }
 
   public boolean matchesValuesField(DexField staticField) {
-    assert isUnboxedEnum(staticField.getHolderType());
+    assert isSuperUnboxedEnum(staticField.getHolderType());
     return get(staticField.getHolderType()).matchesValuesField(staticField);
   }
 
@@ -202,5 +218,28 @@
       assert hasValues();
       return valuesSize;
     }
+
+    public SingleNumberValue superEnumTypeSingleValue(AbstractValueFactory factory, DexType type) {
+      // If there is a single live enum instance, then return the unboxed value for this one.
+      if (hasValues()) {
+        if (valuesSize == 1) {
+          return factory.createSingleNumberValue(ordinalToUnboxedInt(0));
+        }
+      } else if (unboxedValues.size() == 1) {
+        Integer next = unboxedValues.values().iterator().next();
+        return factory.createSingleNumberValue(ordinalToUnboxedInt(next));
+      }
+      return null;
+    }
+
+    public SingleNumberValue subEnumTypeSingleValue(AbstractValueFactory factory, DexType type) {
+      assert valuesTypes.values().stream().filter(t -> t == type).count() <= 1;
+      for (Entry<DexType> entry : valuesTypes.int2ReferenceEntrySet()) {
+        if (entry.getValue() == type) {
+          return factory.createSingleNumberValue(ordinalToUnboxedInt(entry.getIntKey()));
+        }
+      }
+      return null;
+    }
   }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxerImpl.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxerImpl.java
index 12e1f1e..df5d6b1 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxerImpl.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxerImpl.java
@@ -181,6 +181,10 @@
         appView.appInfo().resolveField(factory.enumMembers.ordinalField).getResolutionPair();
   }
 
+  public static int unboxedIntToOrdinal(int unboxedInt) {
+    return unboxedInt - 1;
+  }
+
   public static int ordinalToUnboxedInt(int ordinal) {
     return ordinal + 1;
   }
@@ -1280,7 +1284,8 @@
           : Reason.ASSIGNMENT_OUTSIDE_INIT;
     }
     // The put value has to be of the field type.
-    if (field.getReference().type.toBaseType(factory) != enumClass.type) {
+    if (!enumUnboxingCandidatesInfo.isAssignableTo(
+        field.getReference().type.toBaseType(factory), enumClass.type)) {
       return Reason.TYPE_MISMATCH_FIELD_PUT;
     }
     return Reason.ELIGIBLE;
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingLens.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingLens.java
index 3aaff49..53f7f63 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingLens.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingLens.java
@@ -4,6 +4,8 @@
 
 package com.android.tools.r8.ir.optimize.enums;
 
+import static com.android.tools.r8.ir.optimize.enums.EnumUnboxerImpl.unboxedIntToOrdinal;
+
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexItemFactory;
@@ -16,8 +18,10 @@
 import com.android.tools.r8.graph.proto.ArgumentInfoCollection;
 import com.android.tools.r8.graph.proto.RewrittenPrototypeDescription;
 import com.android.tools.r8.graph.proto.RewrittenTypeInfo;
+import com.android.tools.r8.ir.analysis.value.AbstractValue;
 import com.android.tools.r8.ir.analysis.value.AbstractValueFactory;
 import com.android.tools.r8.ir.analysis.value.SingleFieldValue;
+import com.android.tools.r8.ir.analysis.value.SingleNumberValue;
 import com.android.tools.r8.ir.analysis.value.SingleValue;
 import com.android.tools.r8.ir.code.InvokeType;
 import com.android.tools.r8.ir.conversion.ExtraUnusedNullParameter;
@@ -41,6 +45,7 @@
   private final AbstractValueFactory abstractValueFactory;
   private final Map<DexMethod, RewrittenPrototypeDescription> prototypeChangesPerMethod;
   private final EnumDataMap unboxedEnums;
+  private final Set<DexMethod> dispatchMethods;
 
   EnumUnboxingLens(
       AppView<?> appView,
@@ -48,12 +53,14 @@
       BidirectionalOneToManyRepresentativeMap<DexMethod, DexMethod> renamedSignatures,
       Map<DexType, DexType> typeMap,
       Map<DexMethod, DexMethod> methodMap,
-      Map<DexMethod, RewrittenPrototypeDescription> prototypeChangesPerMethod) {
+      Map<DexMethod, RewrittenPrototypeDescription> prototypeChangesPerMethod,
+      Set<DexMethod> dispatchMethods) {
     super(appView, fieldMap, methodMap, typeMap, renamedSignatures);
     assert !appView.unboxedEnums().isEmpty();
     this.abstractValueFactory = appView.abstractValueFactory();
     this.prototypeChangesPerMethod = prototypeChangesPerMethod;
     this.unboxedEnums = appView.unboxedEnums();
+    this.dispatchMethods = dispatchMethods;
   }
 
   @Override
@@ -94,6 +101,35 @@
     return true;
   }
 
+  public DexMethod lookupRefinedDispatchMethod(
+      DexMethod method,
+      DexMethod context,
+      InvokeType type,
+      GraphLens codeLens,
+      AbstractValue unboxedEnumValue,
+      DexType enumType) {
+    assert codeLens == getPrevious();
+    DexMethod reference = lookupMethod(method, context, type, codeLens).getReference();
+    if (!dispatchMethods.contains(reference) || !unboxedEnumValue.isSingleNumberValue()) {
+      return null;
+    }
+    // We know the exact type of enum, so there is no need to go for the dispatch method. Instead,
+    // we compute the exact target from the enum instance.
+    int unboxedEnum = unboxedEnumValue.asSingleNumberValue().getIntValue();
+    DexType instanceType =
+        unboxedEnums
+            .get(enumType)
+            .valuesTypes
+            .getOrDefault(unboxedIntToOrdinal(unboxedEnum), enumType);
+    DexMethod specializedMethod = method.withHolder(instanceType, dexItemFactory());
+    DexMethod superEnumMethod = method.withHolder(enumType, dexItemFactory());
+    DexMethod refined =
+        newMethodSignatures.getRepresentativeValueOrDefault(
+            specializedMethod, newMethodSignatures.getRepresentativeValue(superEnumMethod));
+    assert refined != null;
+    return refined;
+  }
+
   @Override
   public MethodLookupResult internalDescribeLookupMethod(
       MethodLookupResult previous, DexMethod context, GraphLens codeLens) {
@@ -187,13 +223,15 @@
     return type;
   }
 
-  public static Builder enumUnboxingLensBuilder(AppView<AppInfoWithLiveness> appView) {
-    return new Builder(appView);
+  public static Builder enumUnboxingLensBuilder(
+      AppView<AppInfoWithLiveness> appView, EnumDataMap enumDataMap) {
+    return new Builder(appView, enumDataMap);
   }
 
   static class Builder {
 
     private final DexItemFactory dexItemFactory;
+    private final AbstractValueFactory abstractValueFactory;
     private final Map<DexType, DexType> typeMap = new IdentityHashMap<>();
     private final MutableBidirectionalOneToOneMap<DexField, DexField> newFieldSignatures =
         new BidirectionalOneToOneHashMap<>();
@@ -204,8 +242,12 @@
     private final Map<DexMethod, RewrittenPrototypeDescription> prototypeChangesPerMethod =
         new IdentityHashMap<>();
 
-    Builder(AppView<AppInfoWithLiveness> appView) {
+    private final EnumDataMap enumDataMap;
+
+    Builder(AppView<AppInfoWithLiveness> appView, EnumDataMap enumDataMap) {
       this.dexItemFactory = appView.dexItemFactory();
+      this.abstractValueFactory = appView.abstractValueFactory();
+      this.enumDataMap = enumDataMap;
     }
 
     public Builder mapUnboxedEnums(Set<DexType> enumsToUnbox) {
@@ -279,14 +321,17 @@
         assert toStatic;
         offsetDiff = 1;
         if (!virtualReceiverAlreadyRemapped) {
-          builder
-              .addArgumentInfo(
-                  0,
-                  RewrittenTypeInfo.builder()
-                      .setOldType(from.getHolderType())
-                      .setNewType(to.getParameter(0))
-                      .build())
-              .setIsConvertedToStaticMethod();
+          RewrittenTypeInfo.Builder typeInfoBuilder =
+              RewrittenTypeInfo.builder()
+                  .setOldType(from.getHolderType())
+                  .setNewType(to.getParameter(0));
+          SingleNumberValue singleValue =
+              enumDataMap.getSingleNumberValueFromEnumType(
+                  abstractValueFactory, from.getHolderType());
+          if (singleValue != null) {
+            typeInfoBuilder.setSingleValue(singleValue);
+          }
+          builder.addArgumentInfo(0, typeInfoBuilder.build()).setIsConvertedToStaticMethod();
         } else {
           assert to.getParameter(0).isIntType();
           assert !fromStatic;
@@ -327,7 +372,7 @@
           originalCheckNotNullMethodSignature, checkNotNullMethod.getReference());
     }
 
-    public EnumUnboxingLens build(AppView<?> appView) {
+    public EnumUnboxingLens build(AppView<?> appView, Set<DexMethod> dispatchMethods) {
       assert !typeMap.isEmpty();
       return new EnumUnboxingLens(
           appView,
@@ -335,7 +380,8 @@
           newMethodSignatures,
           typeMap,
           methodMap,
-          ImmutableMap.copyOf(prototypeChangesPerMethod));
+          ImmutableMap.copyOf(prototypeChangesPerMethod),
+          dispatchMethods);
     }
   }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
index cbbd4b9..789c542 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
@@ -37,6 +37,7 @@
 import com.android.tools.r8.ir.code.NewUnboxedEnumInstance;
 import com.android.tools.r8.ir.code.Phi;
 import com.android.tools.r8.ir.code.StaticGet;
+import com.android.tools.r8.ir.code.TypeAndLocalInfoSupplier;
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.conversion.MethodProcessor;
 import com.android.tools.r8.ir.optimize.enums.EnumInstanceFieldData.EnumInstanceFieldKnownData;
@@ -46,9 +47,9 @@
 import com.android.tools.r8.utils.InternalOptions;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Sets;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.IdentityHashMap;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -87,9 +88,10 @@
   }
 
   private Map<Instruction, DexType> createInitialConvertedEnums(
-      IRCode code, RewrittenPrototypeDescription prototypeChanges) {
+      IRCode code, RewrittenPrototypeDescription prototypeChanges, Set<Phi> affectedPhis) {
     Map<Instruction, DexType> convertedEnums = new IdentityHashMap<>();
-    Iterator<Instruction> iterator = code.entryBlock().iterator();
+    List<Instruction> extraConstants = new ArrayList<>();
+    InstructionListIterator iterator = code.entryBlock().listIterator(code);
     int originalNumberOfArguments =
         code.getNumberOfArguments()
             + prototypeChanges.getArgumentInfoCollection().numberOfRemovedArguments();
@@ -105,11 +107,36 @@
         RewrittenTypeInfo rewrittenTypeInfo = argumentInfo.asRewrittenTypeInfo();
         DexType enumType =
             getEnumClassTypeOrNull(rewrittenTypeInfo.getOldType().toBaseType(factory));
-        if (enumType != null) {
+        if (rewrittenTypeInfo.hasSingleValue()
+            && rewrittenTypeInfo.getSingleValue().isSingleNumberValue()) {
+          assert rewrittenTypeInfo
+              .getSingleValue()
+              .isMaterializableInContext(appView, code.context());
+          Instruction materializingInstruction =
+              rewrittenTypeInfo
+                  .getSingleValue()
+                  .createMaterializingInstruction(
+                      appView,
+                      code,
+                      TypeAndLocalInfoSupplier.create(
+                          rewrittenTypeInfo.getNewType().toTypeElement(appView),
+                          next.getLocalInfo()));
+          materializingInstruction.setPosition(next.getPosition());
+          extraConstants.add(materializingInstruction);
+          affectedPhis.addAll(next.outValue().uniquePhiUsers());
+          next.outValue().replaceUsers(materializingInstruction.outValue());
+          convertedEnums.put(materializingInstruction, enumType);
+        } else if (enumType != null) {
           convertedEnums.put(next, enumType);
         }
       }
     }
+    if (!extraConstants.isEmpty()) {
+      assert extraConstants.size() == 1; // So far this is used only for unboxed enums "this".
+      for (Instruction extraConstant : extraConstants) {
+        iterator.add(extraConstant);
+      }
+    }
     return convertedEnums;
   }
 
@@ -125,8 +152,9 @@
     assert code.isConsistentSSABeforeTypesAreCorrect(appView);
     ProgramMethod context = code.context();
     EnumUnboxerMethodProcessorEventConsumer eventConsumer = methodProcessor.getEventConsumer();
-    Map<Instruction, DexType> convertedEnums = createInitialConvertedEnums(code, prototypeChanges);
     Set<Phi> affectedPhis = Sets.newIdentityHashSet();
+    Map<Instruction, DexType> convertedEnums =
+        createInitialConvertedEnums(code, prototypeChanges, affectedPhis);
     BasicBlockIterator blocks = code.listIterator();
     Set<BasicBlock> seenBlocks = Sets.newIdentityHashSet();
     Set<Instruction> instructionsToRemove = Sets.newIdentityHashSet();
@@ -237,6 +265,23 @@
             } else if (invokedMethod == factory.objectMembers.getClass) {
               rewriteNullCheck(iterator, invoke, context, eventConsumer);
               continue;
+            } else if (invoke.isInvokeVirtual() || invoke.isInvokeInterface()) {
+              DexMethod refinedDispatchMethodReference =
+                  enumUnboxingLens.lookupRefinedDispatchMethod(
+                      invokedMethod,
+                      context.getReference(),
+                      invoke.getType(),
+                      enumUnboxingLens.getPrevious(),
+                      invoke.getArgument(0).getAbstractValue(appView, context),
+                      enumType);
+              if (refinedDispatchMethodReference != null) {
+                DexClassAndMethod refinedDispatchMethod =
+                    appView.definitionFor(refinedDispatchMethodReference);
+                assert refinedDispatchMethod != null;
+                assert refinedDispatchMethod.isProgramMethod();
+                replaceEnumInvoke(iterator, invoke, refinedDispatchMethod.asProgramMethod());
+              }
+              continue;
             }
           } else if (invokedMethod == factory.stringBuilderMethods.appendObject
               || invokedMethod == factory.stringBufferMethods.appendObject) {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingTreeFixer.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingTreeFixer.java
index e08b479..6ac7697 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingTreeFixer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingTreeFixer.java
@@ -115,7 +115,8 @@
     this.factory = appView.dexItemFactory();
     this.unboxedEnumHierarchy = unboxedEnums;
     this.lensBuilder =
-        EnumUnboxingLens.enumUnboxingLensBuilder(appView).mapUnboxedEnums(getUnboxedEnums());
+        EnumUnboxingLens.enumUnboxingLensBuilder(appView, enumDataMap)
+            .mapUnboxedEnums(getUnboxedEnums());
     this.utilityClasses = utilityClasses;
     this.prunedItemsBuilder = PrunedItems.concurrentBuilder();
     this.profileCollectionAdditions = ProfileCollectionAdditions.create(appView);
@@ -146,7 +147,9 @@
         .fixupClassesConcurrentlyByConnectedProgramComponents(Timing.empty(), executorService);
 
     // Install the new graph lens before processing any checkNotZero() methods.
-    EnumUnboxingLens lens = lensBuilder.build(appView);
+    Set<DexMethod> dispatchMethodReferences = Sets.newIdentityHashSet();
+    dispatchMethods.forEach((method, code) -> dispatchMethodReferences.add(method.getReference()));
+    EnumUnboxingLens lens = lensBuilder.build(appView, dispatchMethodReferences);
     appView.rewriteWithLens(lens);
 
     // Rewrite outliner with lens.
diff --git a/src/test/java/com/android/tools/r8/enumunboxing/enummerging/EnumMergingInstantiatingTest.java b/src/test/java/com/android/tools/r8/enumunboxing/enummerging/EnumMergingInstantiatingTest.java
new file mode 100644
index 0000000..074e0e9
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/enumunboxing/enummerging/EnumMergingInstantiatingTest.java
@@ -0,0 +1,190 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.enumunboxing.enummerging;
+
+import com.android.tools.r8.AlwaysInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NoHorizontalClassMerging;
+import com.android.tools.r8.NoReturnTypeStrengthening;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.ToolHelper;
+import com.android.tools.r8.enumunboxing.EnumUnboxingTestBase;
+import com.android.tools.r8.utils.DescriptorUtils;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class EnumMergingInstantiatingTest extends EnumUnboxingTestBase {
+
+  private static Collection<byte[]> PROGRAM_CLASSES_DATA;
+
+  private final TestParameters parameters;
+  private final boolean enumValueOptimization;
+  private final EnumKeepRules enumKeepRules;
+
+  @Parameters(name = "{0} valueOpt: {1} keep: {2}")
+  public static List<Object[]> data() {
+    return enumUnboxingTestParameters();
+  }
+
+  @BeforeClass
+  public static void setup() throws IOException {
+    PROGRAM_CLASSES_DATA = tranformedInputs();
+  }
+
+  public EnumMergingInstantiatingTest(
+      TestParameters parameters, boolean enumValueOptimization, EnumKeepRules enumKeepRules) {
+    this.parameters = parameters;
+    this.enumValueOptimization = enumValueOptimization;
+    this.enumKeepRules = enumKeepRules;
+  }
+
+  @Test
+  public void testEnumUnboxing() throws Exception {
+    testForR8(parameters.getBackend())
+        .addProgramClassFileData(PROGRAM_CLASSES_DATA)
+        .addKeepMainRule(Main.class)
+        .addKeepRules(enumKeepRules.getKeepRules())
+        .addOptionsModification(opt -> opt.testing.enableEnumWithSubtypesUnboxing = true)
+        .addEnumUnboxingInspector(inspector -> inspector.assertUnboxed(InstantiatingEnum.class))
+        .enableInliningAnnotations()
+        .enableAlwaysInliningAnnotations()
+        .enableNoHorizontalClassMergingAnnotations()
+        .enableNoReturnTypeStrengtheningAnnotations()
+        .addOptionsModification(opt -> enableEnumOptions(opt, enumValueOptimization))
+        .setMinApi(parameters)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("AC/DC", "AC/DC", "AC/DC");
+  }
+
+  private static String getNewDescriptorForName(String name, String descriptor) {
+    String descr = DescriptorUtils.javaTypeToDescriptor(InstantiatingEnum.class.getTypeName());
+    String prefix = descr.substring(0, descr.length() - 1);
+    if (name.equals("A")) {
+      return prefix + "$1;";
+    }
+    if (name.equals("D")) {
+      return prefix + "$2;";
+    }
+    return descriptor;
+  }
+
+  public static Collection<byte[]> tranformedInputs() throws IOException {
+    Collection<Path> classFilesForInnerClasses =
+        ToolHelper.getClassFilesForInnerClasses(EnumMergingInstantiatingTest.class);
+    ArrayList<byte[]> bytes = new ArrayList<>();
+    for (Path classFilesForInnerClass : classFilesForInnerClasses) {
+      bytes.add(transform(classFilesForInnerClass));
+    }
+    return bytes;
+  }
+
+  public static byte[] transform(Path path) throws IOException {
+    return transformer(path, null)
+        .changeFieldType(
+            name -> name.equals("A") || name.equals("D"),
+            EnumMergingInstantiatingTest::getNewDescriptorForName)
+        .transform();
+  }
+
+  public static class C {
+
+    @AlwaysInline
+    public static C dispatch(InstantiatingEnum theEnum) {
+      // Only after inlining can the invoke be respecialized to each of the subtype.
+      return theEnum.newEntry();
+    }
+  }
+
+  @NoHorizontalClassMerging
+  public static class AC extends C {
+
+    @Override
+    public String toString() {
+      return "AC";
+    }
+  }
+
+  @NoHorizontalClassMerging
+  public static class DC extends C {
+
+    @Override
+    public String toString() {
+      return "DC";
+    }
+  }
+
+  enum InstantiatingEnum {
+    A {
+      @NeverInline
+      @Override
+      public C newEntry() {
+        return new AC();
+      }
+
+      @NeverInline
+      @Override
+      public C newEntryThroughDispatch() {
+        return C.dispatch(this);
+      }
+    },
+    D {
+      @NeverInline
+      @Override
+      public C newEntry() {
+        return new DC();
+      }
+
+      @NeverInline
+      @Override
+      public C newEntryThroughDispatch() {
+        return C.dispatch(this);
+      }
+    };
+
+    @NeverInline
+    public abstract C newEntry();
+
+    @NeverInline
+    public abstract C newEntryThroughDispatch();
+  }
+
+  public static class Main {
+
+    public static void main(String[] args) {
+      System.out.println(getAC() + "/" + getDC());
+      System.out.println(getC(InstantiatingEnum.A) + "/" + getC(InstantiatingEnum.D));
+      System.out.println(
+          InstantiatingEnum.A.newEntryThroughDispatch()
+              + "/"
+              + InstantiatingEnum.D.newEntryThroughDispatch());
+    }
+
+    @NeverInline
+    @NoReturnTypeStrengthening
+    private static C getAC() {
+      return InstantiatingEnum.A.newEntry();
+    }
+
+    @NeverInline
+    @NoReturnTypeStrengthening
+    private static C getDC() {
+      return InstantiatingEnum.D.newEntry();
+    }
+
+    @NeverInline
+    private static C getC(InstantiatingEnum e) {
+      return e.newEntry();
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/transformers/ClassFileTransformer.java b/src/test/java/com/android/tools/r8/transformers/ClassFileTransformer.java
index ce200a0..63573d2 100644
--- a/src/test/java/com/android/tools/r8/transformers/ClassFileTransformer.java
+++ b/src/test/java/com/android/tools/r8/transformers/ClassFileTransformer.java
@@ -37,6 +37,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.function.BiFunction;
 import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.Predicate;
@@ -887,6 +888,34 @@
         });
   }
 
+  public ClassFileTransformer changeFieldType(
+      Predicate<String> fieldPredicate,
+      BiFunction<String, String, String> newDescriptorTransformer) {
+    return addClassTransformer(
+            new ClassTransformer() {
+              @Override
+              public FieldVisitor visitField(
+                  int access, String name, String descriptor, String signature, Object value) {
+                String newDescriptor =
+                    fieldPredicate.test(name)
+                        ? newDescriptorTransformer.apply(name, descriptor)
+                        : descriptor;
+                return super.visitField(access, name, newDescriptor, signature, value);
+              }
+            })
+        .addMethodTransformer(
+            new MethodTransformer() {
+              @Override
+              public void visitFieldInsn(int opcode, String owner, String name, String descriptor) {
+                String newDescriptor =
+                    fieldPredicate.test(name)
+                        ? newDescriptorTransformer.apply(name, descriptor)
+                        : descriptor;
+                super.visitFieldInsn(opcode, owner, name, newDescriptor);
+              }
+            });
+  }
+
   public ClassFileTransformer renameAndRemapField(String oldName, String newName) {
     FieldSignaturePredicate matchPredicate = (name, signature) -> oldName.equals(name);
     remapField(matchPredicate, newName);