Enum unboxing: fix rebinding issues

Bug: 170798502
Change-Id: Ifcc96ad9ae5a4b24967ebc444bbbeeb33775d67f
diff --git a/src/main/java/com/android/tools/r8/graph/DexEncodedMethod.java b/src/main/java/com/android/tools/r8/graph/DexEncodedMethod.java
index 1e628e8..84019a9 100644
--- a/src/main/java/com/android/tools/r8/graph/DexEncodedMethod.java
+++ b/src/main/java/com/android/tools/r8/graph/DexEncodedMethod.java
@@ -179,6 +179,10 @@
     assert !obsolete;
   }
 
+  public CompilationState getCompilationState() {
+    return compilationState;
+  }
+
   public boolean isObsolete() {
     // Do not be cheating. This util can be used only if you're going to do appropriate action,
     // e.g., using GraphLens#mapDexEncodedMethod to look up the correct, up-to-date instance.
@@ -1448,6 +1452,10 @@
       }
     }
 
+    public void setCompilationState(CompilationState compilationState) {
+      this.compilationState = compilationState;
+    }
+
     public void setMethod(DexMethod method) {
       this.method = method;
     }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
index fbe8514..f65b476 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
@@ -156,7 +156,7 @@
         }
         switch (instruction.opcode()) {
           case Opcodes.CONST_CLASS:
-            analyzeConstClass(instruction.asConstClass(), eligibleEnums);
+            analyzeConstClass(instruction.asConstClass(), eligibleEnums, code.context());
             break;
           case Opcodes.CHECK_CAST:
             analyzeCheckCast(instruction.asCheckCast(), eligibleEnums);
@@ -242,7 +242,8 @@
         TypeElement.fromDexType(checkCast.getType(), definitelyNotNull(), appView));
   }
 
-  private void analyzeConstClass(ConstClass constClass, Set<DexType> eligibleEnums) {
+  private void analyzeConstClass(
+      ConstClass constClass, Set<DexType> eligibleEnums, ProgramMethod context) {
     // We are using the ConstClass of an enum, which typically means the enum cannot be unboxed.
     // We however allow unboxing if the ConstClass is used only:
     // - as an argument to Enum#valueOf, to allow unboxing of:
@@ -267,15 +268,18 @@
           && isUnboxableNameMethod(user.asInvokeVirtual().getInvokedMethod())) {
         continue;
       }
-      if (!(user.isInvokeStatic()
-          && user.asInvokeStatic().getInvokedMethod() == factory.enumMembers.valueOf)) {
-        markEnumAsUnboxable(Reason.CONST_CLASS, enumClass);
-        return;
+      if (user.isInvokeStatic()) {
+        DexEncodedMethod singleTarget = user.asInvokeStatic().lookupSingleTarget(appView, context);
+        if (singleTarget != null && singleTarget.getReference() == factory.enumMembers.valueOf) {
+          // The name data is required for the correct mapping from the enum name to the ordinal in
+          // the valueOf utility method.
+          addRequiredNameData(enumType);
+          continue;
+        }
       }
+      markEnumAsUnboxable(Reason.CONST_CLASS, enumClass);
+      return;
     }
-    // The name data is required for the correct mapping from the enum name to the ordinal in the
-    // valueOf utility method.
-    addRequiredNameData(enumType);
     eligibleEnums.add(enumType);
   }
 
@@ -370,6 +374,7 @@
     NestedGraphLens enumUnboxingLens =
         new EnumUnboxingTreeFixer(appView, enumsToUnbox, relocator, enumUnboxerRewriter)
             .fixupTypeReferences();
+    enumUnboxerRewriter.setEnumUnboxingLens(enumUnboxingLens);
     appView.setUnboxedEnums(enumUnboxerRewriter.getEnumsToUnbox());
     GraphLens previousLens = appView.graphLens();
     appView.rewriteWithLensAndApplication(enumUnboxingLens, appBuilder.build());
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
index 02b521f..1bca720 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
@@ -24,6 +24,7 @@
 import com.android.tools.r8.graph.EnumValueInfoMapCollection.EnumValueInfo;
 import com.android.tools.r8.graph.EnumValueInfoMapCollection.EnumValueInfoMap;
 import com.android.tools.r8.graph.FieldAccessFlags;
+import com.android.tools.r8.graph.GraphLens.NestedGraphLens;
 import com.android.tools.r8.graph.MethodAccessFlags;
 import com.android.tools.r8.graph.ParameterAnnotationsList;
 import com.android.tools.r8.graph.ProgramMethod;
@@ -75,6 +76,7 @@
   private final EnumValueInfoMapCollection enumsToUnbox;
   private final EnumInstanceFieldDataMap unboxedEnumsInstanceFieldData;
   private final UnboxedEnumMemberRelocator relocator;
+  private NestedGraphLens enumUnboxingLens;
 
   private final Map<DexMethod, DexEncodedMethod> utilityMethods = new ConcurrentHashMap<>();
   private final Map<DexField, DexEncodedField> utilityFields = new ConcurrentHashMap<>();
@@ -138,6 +140,10 @@
             ENUM_UNBOXING_UTILITY_METHOD_PREFIX + "zeroCheckMessage");
   }
 
+  public void setEnumUnboxingLens(NestedGraphLens enumUnboxingLens) {
+    this.enumUnboxingLens = enumUnboxingLens;
+  }
+
   public EnumValueInfoMapCollection getEnumsToUnbox() {
     return enumsToUnbox;
   }
@@ -149,6 +155,7 @@
       return Sets.newIdentityHashSet();
     }
     assert code.isConsistentSSABeforeTypesAreCorrect();
+    ProgramMethod context = code.context();
     Map<Instruction, DexType> convertedEnums = new IdentityHashMap<>();
     Set<Phi> affectedPhis = Sets.newIdentityHashSet();
     ListIterator<BasicBlock> blocks = code.listIterator();
@@ -160,18 +167,23 @@
       while (iterator.hasNext()) {
         Instruction instruction = iterator.next();
         // Rewrites specific enum methods, such as ordinal, into their corresponding enum unboxed
-        // counterpart.
+        // counterpart. The rewriting (== or match) is based on the following:
+        // - name, ordinal and compareTo are final and implemented only on java.lang.Enum,
+        // - equals, hashCode are final and implemented in java.lang.Enum and java.lang.Object,
+        // - getClass is final and implemented only in java.lang.Object,
+        // - toString is non-final, implemented in java.lang.Object, java.lang.Enum and possibly
+        //   also in the unboxed enum class.
         if (instruction.isInvokeMethodWithReceiver()) {
           InvokeMethodWithReceiver invokeMethod = instruction.asInvokeMethodWithReceiver();
-          DexMethod invokedMethod = invokeMethod.getInvokedMethod();
           DexType enumType = getEnumTypeOrNull(invokeMethod.getReceiver(), convertedEnums);
           if (enumType != null) {
+            DexMethod invokedMethod = invokeMethod.getInvokedMethod();
             if (invokedMethod == factory.enumMembers.ordinalMethod
-                || invokedMethod == factory.enumMembers.hashCode) {
+                || invokedMethod.match(factory.enumMembers.hashCode)) {
               replaceEnumInvoke(
                   iterator, invokeMethod, ordinalUtilityMethod, m -> synthesizeOrdinalMethod());
               continue;
-            } else if (invokedMethod == factory.enumMembers.equals) {
+            } else if (invokedMethod.match(factory.enumMembers.equals)) {
               replaceEnumInvoke(
                   iterator, invokeMethod, equalsUtilityMethod, m -> synthesizeEqualsMethod());
               continue;
@@ -179,14 +191,18 @@
               replaceEnumInvoke(
                   iterator, invokeMethod, compareToUtilityMethod, m -> synthesizeCompareToMethod());
               continue;
-            } else if (invokedMethod == factory.enumMembers.nameMethod
-                || invokedMethod == factory.enumMembers.toString) {
-              DexMethod toStringMethod =
-                  computeInstanceFieldUtilityMethod(enumType, factory.enumMembers.nameField);
-              iterator.replaceCurrentInstruction(
-                  new InvokeStatic(
-                      toStringMethod, invokeMethod.outValue(), invokeMethod.arguments()));
+            } else if (invokedMethod == factory.enumMembers.nameMethod) {
+              rewriteNameMethod(iterator, invokeMethod, enumType);
               continue;
+            } else if (invokedMethod.match(factory.enumMembers.toString)) {
+              DexMethod lookupMethod = enumUnboxingLens.lookupMethod(invokedMethod);
+              // If the lookupMethod is different, then a toString method was on the enumType
+              // class, which was moved, and the lens code rewriter will rewrite the invoke to
+              // that method.
+              if (invokeMethod.isInvokeSuper() || lookupMethod == invokedMethod) {
+                rewriteNameMethod(iterator, invokeMethod, enumType);
+                continue;
+              }
             } else if (invokedMethod == factory.objectMembers.getClass) {
               assert !invokeMethod.hasOutValue() || !invokeMethod.outValue().hasAnyUsers();
               replaceEnumInvoke(
@@ -195,11 +211,15 @@
           }
         } else if (instruction.isInvokeStatic()) {
           InvokeStatic invokeStatic = instruction.asInvokeStatic();
-          DexMethod invokedMethod = invokeStatic.getInvokedMethod();
+          DexEncodedMethod singleTarget = invokeStatic.lookupSingleTarget(appView, context);
+          if (singleTarget == null) {
+            continue;
+          }
+          DexMethod invokedMethod = singleTarget.getReference();
           if (invokedMethod == factory.enumMembers.valueOf
-              && invokeStatic.inValues().get(0).isConstClass()) {
+              && invokeStatic.getArgument(0).isConstClass()) {
             DexType enumType =
-                invokeStatic.inValues().get(0).getConstInstruction().asConstClass().getValue();
+                invokeStatic.getArgument(0).getConstInstruction().asConstClass().getValue();
             if (enumsToUnbox.containsEnum(enumType)) {
               DexMethod valueOfMethod = computeValueOfUtilityMethod(enumType);
               Value outValue = invokeStatic.outValue();
@@ -320,6 +340,9 @@
           ArrayAccess arrayAccess = instruction.asArrayAccess();
           DexType enumType = getEnumTypeOrNull(arrayAccess);
           if (enumType != null) {
+            if (arrayAccess.hasOutValue()) {
+              affectedPhis.addAll(arrayAccess.outValue().uniquePhiUsers());
+            }
             instruction = arrayAccess.withMemberType(MemberType.INT);
             iterator.replaceCurrentInstruction(instruction);
             convertedEnums.put(instruction, enumType);
@@ -332,6 +355,14 @@
     return affectedPhis;
   }
 
+  private void rewriteNameMethod(
+      InstructionListIterator iterator, InvokeMethodWithReceiver invokeMethod, DexType enumType) {
+    DexMethod toStringMethod =
+        computeInstanceFieldUtilityMethod(enumType, factory.enumMembers.nameField);
+    iterator.replaceCurrentInstruction(
+        new InvokeStatic(toStringMethod, invokeMethod.outValue(), invokeMethod.arguments()));
+  }
+
   private Value fixNullsInBlockPhis(IRCode code, BasicBlock block, Value zeroConstValue) {
     for (Phi phi : block.getPhis()) {
       if (getEnumTypeOrNull(phi.getType()) != null) {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingTreeFixer.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingTreeFixer.java
index 634d74e..9c3aaef 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingTreeFixer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingTreeFixer.java
@@ -113,7 +113,8 @@
     encodedMethod.accessFlags.promoteToStatic();
     encodedMethod.clearAnnotations();
     encodedMethod.clearParameterAnnotations();
-    return encodedMethod.toTypeSubstitutedMethod(newMethod);
+    return encodedMethod.toTypeSubstitutedMethod(
+        newMethod, builder -> builder.setCompilationState(encodedMethod.getCompilationState()));
   }
 
   private DexEncodedMethod fixupEncodedMethod(DexEncodedMethod encodedMethod) {
diff --git a/src/test/java/com/android/tools/r8/enumunboxing/PhiEnumUnboxingTest.java b/src/test/java/com/android/tools/r8/enumunboxing/PhiEnumUnboxingTest.java
index 7a0c8bb..d2851c1 100644
--- a/src/test/java/com/android/tools/r8/enumunboxing/PhiEnumUnboxingTest.java
+++ b/src/test/java/com/android/tools/r8/enumunboxing/PhiEnumUnboxingTest.java
@@ -65,6 +65,8 @@
     public static void main(String[] args) {
       nonNullTest();
       nullTest();
+      arrayGetAndPhiTest();
+      argumentsAndPhiTest();
     }
 
     private static void nonNullTest() {
@@ -81,6 +83,21 @@
       System.out.println(true);
     }
 
+    private static void arrayGetAndPhiTest() {
+      MyEnum[] values = MyEnum.values();
+      System.out.println(arrayGetAndPhi(values, true));
+      System.out.println(values[1].ordinal());
+      System.out.println(arrayGetAndPhi(values, false));
+      System.out.println(values[0].ordinal());
+    }
+
+    private static void argumentsAndPhiTest() {
+      System.out.println(argumentsAndPhi(MyEnum.A, MyEnum.B, true));
+      System.out.println(MyEnum.B.ordinal());
+      System.out.println(argumentsAndPhi(MyEnum.A, MyEnum.B, false));
+      System.out.println(MyEnum.A.ordinal());
+    }
+
     @NeverInline
     static MyEnum switchOn(int i) {
       MyEnum returnValue;
@@ -112,5 +129,15 @@
       }
       return returnValue;
     }
+
+    @NeverInline
+    static int arrayGetAndPhi(MyEnum[] enums, boolean b) {
+      return (b ? enums[1] : enums[0]).ordinal();
+    }
+
+    @NeverInline
+    static int argumentsAndPhi(MyEnum e0, MyEnum e1, boolean b) {
+      return (b ? e1 : e0).ordinal();
+    }
   }
 }