Correctly handle constructors in Enum subtypes

Correctly handle dispatch on abstract enum methods

Bug: b/271385332
Change-Id: I848410971e3a88666c195eadb4b4b176a088cd2a
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingTreeFixer.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingTreeFixer.java
index 06185fc..6e75300 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingTreeFixer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingTreeFixer.java
@@ -408,8 +408,11 @@
               instructionsToRemove.put(constructorInvoke, Optional.empty());
             }
 
-            ProgramMethod constructor =
-                unboxedEnum.lookupProgramMethod(lookupResult.getReference());
+            DexProgramClass holder =
+                newInstance.getType() == unboxedEnum.getType()
+                    ? unboxedEnum
+                    : appView.programDefinitionFor(newInstance.getType(), classInitializer);
+            ProgramMethod constructor = holder.lookupProgramMethod(lookupResult.getReference());
             assert constructor != null;
 
             InstanceFieldInitializationInfo ordinalInitializationInfo =
@@ -646,8 +649,7 @@
     if (superMethod.isProgramMethod()) {
       superUtilityMethod =
           installLocalUtilityMethod(
-                  localUtilityClass, localUtilityMethods, superMethod.asProgramMethod())
-              .getReference();
+              localUtilityClass, localUtilityMethods, superMethod.asProgramMethod());
     } else {
       // All methods but toString() are final or non-virtual.
       // We could support other cases by setting correctly the superUtilityMethod here.
@@ -656,10 +658,10 @@
     }
     Map<DexMethod, DexMethod> overrideToUtilityMethods = new IdentityHashMap<>();
     for (ProgramMethod subMethod : subimplementations) {
-      DexEncodedMethod subEnumLocalUtilityMethod =
+      DexMethod subEnumLocalUtilityMethod =
           installLocalUtilityMethod(localUtilityClass, localUtilityMethods, subMethod);
-      overrideToUtilityMethods.put(
-          subMethod.getReference(), subEnumLocalUtilityMethod.getReference());
+      assert subEnumLocalUtilityMethod != null;
+      overrideToUtilityMethods.put(subMethod.getReference(), subEnumLocalUtilityMethod);
     }
     DexMethod dispatch =
         installDispatchMethod(
@@ -687,16 +689,20 @@
       LocalEnumUnboxingUtilityClass localUtilityClass,
       Map<DexMethod, DexEncodedMethod> localUtilityMethods,
       ProgramMethod method) {
-    DexEncodedMethod utilityMethod =
+    DexMethod utilityMethod =
         installLocalUtilityMethod(localUtilityClass, localUtilityMethods, method);
-    lensBuilder.moveAndMap(
-        method.getReference(), utilityMethod.getReference(), method.getDefinition().isStatic());
+    assert utilityMethod != null;
+    lensBuilder.moveAndMap(method.getReference(), utilityMethod, method.getDefinition().isStatic());
   }
 
   public void recordEmulatedDispatch(DexMethod from, DexMethod move, DexMethod dispatch) {
     // Move is used for getRenamedSignature and to remap invoke-super.
     // Map is used to remap all the other invokes.
-    lensBuilder.moveVirtual(from, move);
+    assert from != null;
+    assert dispatch != null;
+    if (move != null) {
+      lensBuilder.moveVirtual(from, move);
+    }
     lensBuilder.mapToDispatch(from, dispatch);
   }
 
@@ -749,10 +755,13 @@
     return newLocalUtilityMethod;
   }
 
-  private DexEncodedMethod installLocalUtilityMethod(
+  private DexMethod installLocalUtilityMethod(
       LocalEnumUnboxingUtilityClass localUtilityClass,
       Map<DexMethod, DexEncodedMethod> localUtilityMethods,
       ProgramMethod method) {
+    if (method.getAccessFlags().isAbstract()) {
+      return null;
+    }
     DexEncodedMethod newLocalUtilityMethod =
         createLocalUtilityMethod(
             method,
@@ -760,7 +769,7 @@
             newMethodSignature -> !localUtilityMethods.containsKey(newMethodSignature));
     assert !localUtilityMethods.containsKey(newLocalUtilityMethod.getReference());
     localUtilityMethods.put(newLocalUtilityMethod.getReference(), newLocalUtilityMethod);
-    return newLocalUtilityMethod;
+    return newLocalUtilityMethod.getReference();
   }
 
   private DexEncodedMethod createLocalUtilityMethod(
diff --git a/src/main/java/com/android/tools/r8/ir/synthetic/EnumUnboxingCfCodeProvider.java b/src/main/java/com/android/tools/r8/ir/synthetic/EnumUnboxingCfCodeProvider.java
index af9e41c..545e609 100644
--- a/src/main/java/com/android/tools/r8/ir/synthetic/EnumUnboxingCfCodeProvider.java
+++ b/src/main/java/com/android/tools/r8/ir/synthetic/EnumUnboxingCfCodeProvider.java
@@ -36,6 +36,7 @@
 import com.android.tools.r8.ir.code.ValueType;
 import com.android.tools.r8.ir.optimize.enums.EnumDataMap.EnumData;
 import com.android.tools.r8.ir.optimize.enums.EnumInstanceFieldData.EnumInstanceFieldMappingData;
+import com.android.tools.r8.utils.IntBox;
 import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
 import java.util.ArrayList;
 import java.util.List;
@@ -90,27 +91,43 @@
       // TODO(b/167942775): Should use a table-switch for large enums (maybe same threshold in the
       //  rewriter of switchmaps).
 
+      assert !methodMap.isEmpty();
       DexItemFactory factory = appView.dexItemFactory();
-      int returnInvokeSize = superEnumMethod.getParameters().size() + 2;
-      List<CfInstruction> instructions =
-          new ArrayList<>(methodMap.size() * (returnInvokeSize + 5) + returnInvokeSize);
+      boolean hasDefaultCase = superEnumMethod != null;
+      DexMethod representative = methodMap.values().iterator().next();
+
+      int invokeSize = representative.getParameters().size() + 2;
+      int branchSize = 5;
+      int instructionsSize =
+          methodMap.size() * (invokeSize + branchSize)
+              + (hasDefaultCase ? invokeSize : -branchSize);
+      List<CfInstruction> instructions = new ArrayList<>(instructionsSize);
 
       CfFrame.Builder frameBuilder = CfFrame.builder();
-      for (DexType parameter : superEnumMethod.getParameters()) {
+      for (DexType parameter : representative.getParameters()) {
         frameBuilder.appendLocal(FrameType.initialized(parameter));
       }
+      IntBox index = new IntBox();
       methodMap.forEach(
           (unboxedEnumValue, method) -> {
-            CfLabel dest = new CfLabel();
-            instructions.add(new CfLoad(ValueType.fromDexType(factory.intType), 0));
-            instructions.add(new CfConstNumber(unboxedEnumValue, ValueType.INT));
-            instructions.add(new CfIfCmp(IfType.NE, ValueType.INT, dest));
-            addReturnInvoke(instructions, method);
-            instructions.add(dest);
-            instructions.add(frameBuilder.build());
+            boolean lastCase = index.incrementAndGet() == methodMap.size() && !hasDefaultCase;
+            if (!lastCase) {
+              CfLabel dest = new CfLabel();
+              instructions.add(new CfLoad(ValueType.fromDexType(factory.intType), 0));
+              instructions.add(new CfConstNumber(unboxedEnumValue, ValueType.INT));
+              instructions.add(new CfIfCmp(IfType.NE, ValueType.INT, dest));
+              addReturnInvoke(instructions, method);
+              instructions.add(dest);
+              instructions.add(frameBuilder.build());
+            } else {
+              addReturnInvoke(instructions, method);
+            }
           });
 
-      addReturnInvoke(instructions, superEnumMethod);
+      if (hasDefaultCase) {
+        addReturnInvoke(instructions, superEnumMethod);
+      }
+      assert instructions.size() == instructionsSize;
       return new CfCodeWithLens(getHolder(), defaultMaxStack(), defaultMaxLocals(), instructions);
     }
 
diff --git a/src/test/java/com/android/tools/r8/enumunboxing/enummerging/AbstractEnumMergingTest.java b/src/test/java/com/android/tools/r8/enumunboxing/enummerging/AbstractEnumMergingTest.java
new file mode 100644
index 0000000..9335996
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/enumunboxing/enummerging/AbstractEnumMergingTest.java
@@ -0,0 +1,87 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.enumunboxing.enummerging;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.enumunboxing.EnumUnboxingTestBase;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class AbstractEnumMergingTest extends EnumUnboxingTestBase {
+
+  private final TestParameters parameters;
+  private final boolean enumValueOptimization;
+  private final EnumKeepRules enumKeepRules;
+
+  @Parameters(name = "{0} valueOpt: {1} keep: {2}")
+  public static List<Object[]> data() {
+    return enumUnboxingTestParameters();
+  }
+
+  public AbstractEnumMergingTest(
+      TestParameters parameters, boolean enumValueOptimization, EnumKeepRules enumKeepRules) {
+    this.parameters = parameters;
+    this.enumValueOptimization = enumValueOptimization;
+    this.enumKeepRules = enumKeepRules;
+  }
+
+  @Test
+  public void testEnumUnboxing() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .addKeepRules(enumKeepRules.getKeepRules())
+        .addOptionsModification(opt -> opt.testing.enableEnumWithSubtypesUnboxing = true)
+        .addEnumUnboxingInspector(inspector -> inspector.assertUnboxed(MyEnum.class))
+        .enableInliningAnnotations()
+        .addOptionsModification(opt -> enableEnumOptions(opt, enumValueOptimization))
+        .setMinApi(parameters)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("336", "74", "96", "44");
+  }
+
+  enum MyEnum {
+    A(8) {
+      @NeverInline
+      @Override
+      public long operate(long another) {
+        return num * another;
+      }
+    },
+    B(32) {
+      @NeverInline
+      @Override
+      public long operate(long another) {
+        return num + another;
+      }
+    };
+    final long num;
+
+    MyEnum(long num) {
+      this.num = num;
+    }
+
+    public abstract long operate(long another);
+  }
+
+  static class Main {
+
+    public static void main(String[] args) {
+      System.out.println(MyEnum.A.operate(42));
+      System.out.println(MyEnum.B.operate(42));
+      System.out.println(indirect(MyEnum.A));
+      System.out.println(indirect(MyEnum.B));
+    }
+
+    @NeverInline
+    public static long indirect(MyEnum e) {
+      return e.operate(12);
+    }
+  }
+}