Enum unboxing: equals and compareTo support

Bug: 157112269
Change-Id: I53f9766b9c17470af795ddaa0fce062698ec9a81
diff --git a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
index 3b635e6..1a318c0 100644
--- a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
+++ b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
@@ -1268,6 +1268,7 @@
     public final DexMethod name;
     public final DexMethod toString;
     public final DexMethod compareTo;
+    public final DexMethod equals;
 
     public final DexMethod constructor =
         createMethod(enumType, createProto(voidType, stringType, intType), constructorMethodName);
@@ -1301,10 +1302,13 @@
               DexString.EMPTY_ARRAY);
       compareTo =
           createMethod(
+              enumDescriptor, compareToMethodName, intDescriptor, new DexString[] {enumDescriptor});
+      equals =
+          createMethod(
               enumDescriptor,
-              compareToMethodName,
-              stringDescriptor,
-              new DexString[] {enumDescriptor});
+              equalsMethodName,
+              booleanDescriptor,
+              new DexString[] {objectDescriptor});
     }
 
     public boolean isValuesMethod(DexMethod method, DexClass enumClass) {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
index 3461a0f..379fc28 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
@@ -446,14 +446,12 @@
       if (dexClass.type != factory.enumType) {
         return Reason.UNSUPPORTED_LIBRARY_CALL;
       }
-      // TODO(b/147860220): Methods toString(), name(), compareTo(), EnumSet and EnumMap may
-      // be interesting to model. A the moment rewrite only Enum#ordinal() and Enum#valueOf.
-      if (debugLogEnabled) {
-        if (singleTarget == factory.enumMethods.compareTo) {
-          return Reason.COMPARE_TO_INVOKE;
-        }
-      }
-      if (singleTarget == factory.enumMethods.name) {
+      // TODO(b/147860220): EnumSet and EnumMap may be interesting to model.
+      if (singleTarget == factory.enumMethods.compareTo) {
+        return Reason.ELIGIBLE;
+      } else if (singleTarget == factory.enumMethods.equals) {
+        return Reason.ELIGIBLE;
+      } else if (singleTarget == factory.enumMethods.name) {
         return Reason.ELIGIBLE;
       } else if (singleTarget == factory.enumMethods.toString) {
         return Reason.ELIGIBLE;
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
index e417ad9..541cd92 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
@@ -54,26 +54,25 @@
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
+import java.util.function.Function;
 
 public class EnumUnboxingRewriter {
 
   public static final String ENUM_UNBOXING_UTILITY_CLASS_NAME = "$r8$EnumUnboxingUtility";
-  public static final String ENUM_UNBOXING_UTILITY_ORDINAL = "$enumboxing$ordinal";
-  private static final String ENUM_UNBOXING_UTILITY_VALUES = "$enumboxing$values";
+  public static final String ENUM_UNBOXING_UTILITY_METHOD_PREFIX = "$enumboxing$";
   private static final int REQUIRED_CLASS_FILE_VERSION = 52;
 
   private final AppView<AppInfoWithLiveness> appView;
   private final DexItemFactory factory;
   private final EnumValueInfoMapCollection enumsToUnbox;
-  private final Map<DexMethod, DexEncodedMethod> extraUtilityMethods = new ConcurrentHashMap<>();
+  private final Map<DexMethod, DexEncodedMethod> utilityMethods = new ConcurrentHashMap<>();
   private final Map<DexField, DexEncodedField> extraUtilityFields = new ConcurrentHashMap<>();
 
   private final DexMethod ordinalUtilityMethod;
+  private final DexMethod equalsUtilityMethod;
+  private final DexMethod compareToUtilityMethod;
   private final DexMethod valuesUtilityMethod;
 
-  private boolean requiresOrdinalUtilityMethod = false;
-  private boolean requiresValuesUtilityMethod = false;
-
   EnumUnboxingRewriter(AppView<AppInfoWithLiveness> appView, Set<DexType> enumsToUnbox) {
     this.appView = appView;
     this.factory = appView.dexItemFactory();
@@ -83,16 +82,28 @@
     }
     this.enumsToUnbox = builder.build();
 
+    // Custom methods for java.lang.Enum methods ordinal, equals and compareTo.
     this.ordinalUtilityMethod =
         factory.createMethod(
             factory.enumUnboxingUtilityType,
             factory.createProto(factory.intType, factory.intType),
-            ENUM_UNBOXING_UTILITY_ORDINAL);
+            ENUM_UNBOXING_UTILITY_METHOD_PREFIX + "ordinal");
+    this.equalsUtilityMethod =
+        factory.createMethod(
+            factory.enumUnboxingUtilityType,
+            factory.createProto(factory.booleanType, factory.intType, factory.intType),
+            ENUM_UNBOXING_UTILITY_METHOD_PREFIX + "equals");
+    this.compareToUtilityMethod =
+        factory.createMethod(
+            factory.enumUnboxingUtilityType,
+            factory.createProto(factory.intType, factory.intType, factory.intType),
+            ENUM_UNBOXING_UTILITY_METHOD_PREFIX + "compareTo");
+    // Custom methods for generated field $VALUES initialization.
     this.valuesUtilityMethod =
         factory.createMethod(
             factory.enumUnboxingUtilityType,
             factory.createProto(factory.intArrayType, factory.intType),
-            ENUM_UNBOXING_UTILITY_VALUES);
+            ENUM_UNBOXING_UTILITY_METHOD_PREFIX + "values");
   }
 
   public EnumValueInfoMapCollection getEnumsToUnbox() {
@@ -119,17 +130,23 @@
         DexType enumType = getEnumTypeOrNull(invokeMethod.getReceiver(), convertedEnums);
         if (enumType != null) {
           if (invokedMethod == factory.enumMethods.ordinal) {
-            instruction =
-                new InvokeStatic(
-                    ordinalUtilityMethod, invokeMethod.outValue(), invokeMethod.inValues());
-            iterator.replaceCurrentInstruction(instruction);
-            requiresOrdinalUtilityMethod = true;
+            replaceEnumInvoke(
+                iterator, invokeMethod, ordinalUtilityMethod, m -> synthesizeOrdinalMethod());
+            continue;
+          } else if (invokedMethod == factory.enumMethods.equals) {
+            replaceEnumInvoke(
+                iterator, invokeMethod, equalsUtilityMethod, m -> synthesizeEqualsMethod());
+            continue;
+          } else if (invokedMethod == factory.enumMethods.compareTo) {
+            replaceEnumInvoke(
+                iterator, invokeMethod, compareToUtilityMethod, m -> synthesizeCompareToMethod());
             continue;
           } else if (invokedMethod == factory.enumMethods.name
               || invokedMethod == factory.enumMethods.toString) {
             DexMethod toStringMethod = computeDefaultToStringUtilityMethod(enumType);
             iterator.replaceCurrentInstruction(
-                new InvokeStatic(toStringMethod, invokeMethod.outValue(), invokeMethod.inValues()));
+                new InvokeStatic(
+                    toStringMethod, invokeMethod.outValue(), invokeMethod.arguments()));
             continue;
           }
         }
@@ -175,11 +192,12 @@
           affectedPhis.addAll(staticGet.outValue().uniquePhiUsers());
           EnumValueInfo enumValueInfo = enumValueInfoMap.getEnumValueInfo(staticGet.getField());
           if (enumValueInfo == null && staticGet.getField().name == factory.enumValuesFieldName) {
-            requiresValuesUtilityMethod = true;
+            utilityMethods.computeIfAbsent(
+                valuesUtilityMethod, m -> synthesizeValuesUtilityMethod());
             DexField fieldValues = createValuesField(holder);
             extraUtilityFields.computeIfAbsent(fieldValues, this::computeValuesEncodedField);
             DexMethod methodValues = createValuesMethod(holder);
-            extraUtilityMethods.computeIfAbsent(
+            utilityMethods.computeIfAbsent(
                 methodValues,
                 m -> computeValuesEncodedMethod(m, fieldValues, enumValueInfoMap.size()));
             Value rewrittenOutValue =
@@ -215,6 +233,17 @@
     return affectedPhis;
   }
 
+  private void replaceEnumInvoke(
+      InstructionListIterator iterator,
+      InvokeMethodWithReceiver invokeMethod,
+      DexMethod method,
+      Function<DexMethod, DexEncodedMethod> synthesizor) {
+    utilityMethods.computeIfAbsent(method, synthesizor);
+    Instruction instruction =
+        new InvokeStatic(method, invokeMethod.outValue(), invokeMethod.arguments());
+    iterator.replaceCurrentInstruction(instruction);
+  }
+
   private boolean validateArrayAccess(ArrayAccess arrayAccess) {
     ArrayTypeElement arrayType = arrayAccess.array().getType().asArrayType();
     if (arrayType == null) {
@@ -286,7 +315,7 @@
             factory.enumUnboxingUtilityType,
             factory.createProto(factory.intType, factory.stringType),
             "valueOf" + compatibleName(type));
-    extraUtilityMethods.computeIfAbsent(valueOf, m -> synthesizeValueOfUtilityMethod(m, type));
+    utilityMethods.computeIfAbsent(valueOf, m -> synthesizeValueOfUtilityMethod(m, type));
     return valueOf;
   }
 
@@ -297,7 +326,7 @@
             factory.enumUnboxingUtilityType,
             factory.createProto(factory.stringType, factory.intType),
             "toString" + compatibleName(type));
-    extraUtilityMethods.computeIfAbsent(toString, m -> synthesizeToStringUtilityMethod(m, type));
+    utilityMethods.computeIfAbsent(toString, m -> synthesizeToStringUtilityMethod(m, type));
     return toString;
   }
 
@@ -323,16 +352,9 @@
       throws ExecutionException {
     // Synthesize a class which holds various utility methods that may be called from the IR
     // rewriting. If any of these methods are not used, they will be removed by the Enqueuer.
-    List<DexEncodedMethod> requiredMethods = new ArrayList<>(extraUtilityMethods.values());
+    List<DexEncodedMethod> requiredMethods = new ArrayList<>(utilityMethods.values());
     // Sort for deterministic order.
     requiredMethods.sort((m1, m2) -> m1.method.name.slowCompareTo(m2.method.name));
-    if (requiresOrdinalUtilityMethod) {
-      requiredMethods.add(synthesizeOrdinalMethod());
-    }
-    if (requiresValuesUtilityMethod) {
-      requiredMethods.add(synthesizeValuesUtilityMethod());
-    }
-    // TODO(b/147860220): synthesize also other enum methods.
     if (requiredMethods.isEmpty()) {
       return;
     }
@@ -412,6 +434,19 @@
     return synthesizeUtilityMethod(cfCode, ordinalUtilityMethod, false);
   }
 
+  private DexEncodedMethod synthesizeEqualsMethod() {
+    CfCode cfCode =
+        EnumUnboxingCfMethods.EnumUnboxingMethods_equals(appView.options(), equalsUtilityMethod);
+    return synthesizeUtilityMethod(cfCode, equalsUtilityMethod, false);
+  }
+
+  private DexEncodedMethod synthesizeCompareToMethod() {
+    CfCode cfCode =
+        EnumUnboxingCfMethods.EnumUnboxingMethods_compareTo(
+            appView.options(), compareToUtilityMethod);
+    return synthesizeUtilityMethod(cfCode, compareToUtilityMethod, false);
+  }
+
   private DexEncodedMethod synthesizeValuesUtilityMethod() {
     CfCode cfCode =
         EnumUnboxingCfMethods.EnumUnboxingMethods_values(appView.options(), valuesUtilityMethod);
diff --git a/src/test/java/com/android/tools/r8/enumunboxing/EqualsCompareToEnumUnboxingTest.java b/src/test/java/com/android/tools/r8/enumunboxing/EqualsCompareToEnumUnboxingTest.java
new file mode 100644
index 0000000..de86bd7
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/enumunboxing/EqualsCompareToEnumUnboxingTest.java
@@ -0,0 +1,108 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.enumunboxing;
+
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.R8TestRunResult;
+import com.android.tools.r8.TestParameters;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class EqualsCompareToEnumUnboxingTest extends EnumUnboxingTestBase {
+
+  private final TestParameters parameters;
+  private final boolean enumValueOptimization;
+  private final KeepRule enumKeepRules;
+
+  @Parameters(name = "{0} valueOpt: {1} keep: {2}")
+  public static List<Object[]> data() {
+    return enumUnboxingTestParameters();
+  }
+
+  public EqualsCompareToEnumUnboxingTest(
+      TestParameters parameters, boolean enumValueOptimization, KeepRule enumKeepRules) {
+    this.parameters = parameters;
+    this.enumValueOptimization = enumValueOptimization;
+    this.enumKeepRules = enumKeepRules;
+  }
+
+  @Test
+  public void testEnumUnboxing() throws Exception {
+    Class<?> success = EnumEqualscompareTo.class;
+    R8TestRunResult run =
+        testForR8(parameters.getBackend())
+            .addInnerClasses(EqualsCompareToEnumUnboxingTest.class)
+            .addKeepMainRule(EnumEqualscompareTo.class)
+            .enableNeverClassInliningAnnotations()
+            .addKeepRules(enumKeepRules.getKeepRule())
+            .addOptionsModification(opt -> enableEnumOptions(opt, enumValueOptimization))
+            .allowDiagnosticInfoMessages()
+            .setMinApi(parameters.getApiLevel())
+            .compile()
+            .inspectDiagnosticMessages(
+                m ->
+                    assertEnumIsUnboxed(
+                        success.getDeclaredClasses()[0], success.getSimpleName(), m))
+            .run(parameters.getRuntime(), success)
+            .assertSuccess();
+    assertLines2By2Correct(run.getStdOut());
+  }
+
+  static class EnumEqualscompareTo {
+
+    @NeverClassInline
+    enum MyEnum {
+      A,
+      B
+    }
+
+    public static void main(String[] args) {
+      equalsTest();
+      compareToTest();
+    }
+
+    @SuppressWarnings({"ConstantConditions", "EqualsWithItself", "ResultOfMethodCallIgnored"})
+    private static void equalsTest() {
+      System.out.println(MyEnum.A.equals(MyEnum.B));
+      System.out.println(false);
+      System.out.println(MyEnum.A.equals(MyEnum.A));
+      System.out.println(true);
+      System.out.println(MyEnum.A.equals(null));
+      System.out.println(false);
+      try {
+        ((MyEnum) null).equals(null);
+      } catch (NullPointerException npe) {
+        System.out.println("npe " + npe.getMessage());
+        System.out.println("npe " + npe.getMessage());
+      }
+    }
+
+    @SuppressWarnings({"ConstantConditions", "EqualsWithItself", "ResultOfMethodCallIgnored"})
+    private static void compareToTest() {
+      System.out.println(MyEnum.B.compareTo(MyEnum.A) > 0);
+      System.out.println(true);
+      System.out.println(MyEnum.A.compareTo(MyEnum.B) < 0);
+      System.out.println(true);
+      System.out.println(MyEnum.A.compareTo(MyEnum.A) == 0);
+      System.out.println(true);
+      try {
+        ((MyEnum) null).equals(null);
+      } catch (NullPointerException npe) {
+        System.out.println("npe " + npe.getMessage());
+        System.out.println("npe " + npe.getMessage());
+      }
+      try {
+        MyEnum.A.compareTo(null);
+      } catch (NullPointerException npe) {
+        System.out.println("npe " + npe.getMessage());
+        System.out.println("npe " + npe.getMessage());
+      }
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/R8InliningTest.java b/src/test/java/com/android/tools/r8/ir/optimize/R8InliningTest.java
index 8ff71b8..75c45d1 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/R8InliningTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/R8InliningTest.java
@@ -374,7 +374,7 @@
               .getMethod()
               .name
               .toString()
-              .equals(EnumUnboxingRewriter.ENUM_UNBOXING_UTILITY_ORDINAL)) {
+              .startsWith(EnumUnboxingRewriter.ENUM_UNBOXING_UTILITY_METHOD_PREFIX)) {
         ++invokeCount;
       }
     }