Reland Enum unboxing: unbox null checks

Bug: 167843161
Bug: 166397278
Change-Id: I9e2d102fb8e2cb6355289b7573b5fc96b1017ac4
diff --git a/src/main/java/com/android/tools/r8/L8.java b/src/main/java/com/android/tools/r8/L8.java
index 25052e8..ef71e03 100644
--- a/src/main/java/com/android/tools/r8/L8.java
+++ b/src/main/java/com/android/tools/r8/L8.java
@@ -99,7 +99,11 @@
           });
       assert !options.cfToCfDesugar;
       if (shrink) {
-        R8.run(r8Command);
+        AndroidApp r8CommandInputApp = r8Command.getInputApp();
+        InternalOptions r8CommandInternalOptions = r8Command.getInternalOptions();
+        // TODO(b/167843161): Disable temporarily enum unboxing in L8 due to naming issues.
+        r8CommandInternalOptions.enableEnumUnboxing = false;
+        R8.runForTesting(r8CommandInputApp, r8CommandInternalOptions);
       } else if (d8Command != null) {
         D8.run(d8Command, executorService);
       }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
index 30773a4..3d7d397 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
@@ -872,9 +872,20 @@
       assert dexClass.isLibraryClass();
       if (dexClass.type != factory.enumType) {
         // System.identityHashCode(Object) is supported for proto enums.
+        // Object#getClass without outValue and Objects.requireNonNull are supported since R8
+        // rewrites explicit null checks to such instructions.
         if (singleTarget == factory.javaLangSystemMethods.identityHashCode) {
           return Reason.ELIGIBLE;
         }
+        if (singleTarget == factory.objectMembers.getClass
+            && (!invokeMethod.hasOutValue() || !invokeMethod.outValue().hasAnyUsers())) {
+          // This is a hidden null check.
+          return Reason.ELIGIBLE;
+        }
+        if (singleTarget == factory.objectsMethods.requireNonNull
+            || singleTarget == factory.objectsMethods.requireNonNullWithMessage) {
+          return Reason.ELIGIBLE;
+        }
         return Reason.UNSUPPORTED_LIBRARY_CALL;
       }
       // TODO(b/147860220): EnumSet and EnumMap may be interesting to model.
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingCfMethods.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingCfMethods.java
index daf1ad4..aa7ec4c 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingCfMethods.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingCfMethods.java
@@ -24,6 +24,7 @@
 import com.android.tools.r8.cf.code.CfNew;
 import com.android.tools.r8.cf.code.CfNewArray;
 import com.android.tools.r8.cf.code.CfReturn;
+import com.android.tools.r8.cf.code.CfReturnVoid;
 import com.android.tools.r8.cf.code.CfStackInstruction;
 import com.android.tools.r8.cf.code.CfStore;
 import com.android.tools.r8.cf.code.CfThrow;
@@ -260,4 +261,69 @@
         ImmutableList.of(),
         ImmutableList.of());
   }
+
+  public static CfCode EnumUnboxingMethods_zeroCheck(InternalOptions options, DexMethod method) {
+    CfLabel label0 = new CfLabel();
+    CfLabel label1 = new CfLabel();
+    CfLabel label2 = new CfLabel();
+    CfLabel label3 = new CfLabel();
+    return new CfCode(
+        method.holder,
+        2,
+        1,
+        ImmutableList.of(
+            label0,
+            new CfLoad(ValueType.INT, 0),
+            new CfIf(If.Type.NE, ValueType.INT, label2),
+            label1,
+            new CfNew(options.itemFactory.createType("Ljava/lang/NullPointerException;")),
+            new CfStackInstruction(CfStackInstruction.Opcode.Dup),
+            new CfInvoke(
+                183,
+                options.itemFactory.createMethod(
+                    options.itemFactory.createType("Ljava/lang/NullPointerException;"),
+                    options.itemFactory.createProto(options.itemFactory.voidType),
+                    options.itemFactory.createString("<init>")),
+                false),
+            new CfThrow(),
+            label2,
+            new CfReturnVoid(),
+            label3),
+        ImmutableList.of(),
+        ImmutableList.of());
+  }
+
+  public static CfCode EnumUnboxingMethods_zeroCheckMessage(
+      InternalOptions options, DexMethod method) {
+    CfLabel label0 = new CfLabel();
+    CfLabel label1 = new CfLabel();
+    CfLabel label2 = new CfLabel();
+    CfLabel label3 = new CfLabel();
+    return new CfCode(
+        method.holder,
+        3,
+        2,
+        ImmutableList.of(
+            label0,
+            new CfLoad(ValueType.INT, 0),
+            new CfIf(If.Type.NE, ValueType.INT, label2),
+            label1,
+            new CfNew(options.itemFactory.createType("Ljava/lang/NullPointerException;")),
+            new CfStackInstruction(CfStackInstruction.Opcode.Dup),
+            new CfLoad(ValueType.OBJECT, 1),
+            new CfInvoke(
+                183,
+                options.itemFactory.createMethod(
+                    options.itemFactory.createType("Ljava/lang/NullPointerException;"),
+                    options.itemFactory.createProto(
+                        options.itemFactory.voidType, options.itemFactory.stringType),
+                    options.itemFactory.createString("<init>")),
+                false),
+            new CfThrow(),
+            label2,
+            new CfReturnVoid(),
+            label3),
+        ImmutableList.of(),
+        ImmutableList.of());
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
index 54d9aaf..ca106ab 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
@@ -77,6 +77,8 @@
   private final DexMethod equalsUtilityMethod;
   private final DexMethod compareToUtilityMethod;
   private final DexMethod valuesUtilityMethod;
+  private final DexMethod zeroCheckMethod;
+  private final DexMethod zeroCheckMessageMethod;
 
   EnumUnboxingRewriter(
       AppView<AppInfoWithLiveness> appView,
@@ -114,6 +116,17 @@
             factory.enumUnboxingUtilityType,
             factory.createProto(factory.intArrayType, factory.intType),
             ENUM_UNBOXING_UTILITY_METHOD_PREFIX + "values");
+    // Custom methods for Object#getClass without outValue and Objects.requireNonNull.
+    this.zeroCheckMethod =
+        factory.createMethod(
+            factory.enumUnboxingUtilityType,
+            factory.createProto(factory.voidType, factory.intType),
+            ENUM_UNBOXING_UTILITY_METHOD_PREFIX + "zeroCheck");
+    this.zeroCheckMessageMethod =
+        factory.createMethod(
+            factory.enumUnboxingUtilityType,
+            factory.createProto(factory.voidType, factory.intType, factory.stringType),
+            ENUM_UNBOXING_UTILITY_METHOD_PREFIX + "zeroCheckMessage");
   }
 
   public EnumValueInfoMapCollection getEnumsToUnbox() {
@@ -160,6 +173,10 @@
                 new InvokeStatic(
                     toStringMethod, invokeMethod.outValue(), invokeMethod.arguments()));
             continue;
+          } else if (invokedMethod == factory.objectMembers.getClass) {
+            assert !invokeMethod.hasOutValue() || !invokeMethod.outValue().hasAnyUsers();
+            replaceEnumInvoke(
+                iterator, invokeMethod, zeroCheckMethod, m -> synthesizeZeroCheckMethod());
           }
         }
         // TODO(b/147860220): rewrite also other enum methods.
@@ -195,6 +212,25 @@
             invokeStatic.outValue().replaceUsers(argument);
             iterator.removeOrReplaceByDebugLocalRead();
           }
+        } else if (invokedMethod == factory.objectsMethods.requireNonNull) {
+          assert invokeStatic.inValues().size() == 1;
+          Value argument = invokeStatic.getArgument(0);
+          DexType enumType = getEnumTypeOrNull(argument, convertedEnums);
+          if (enumType != null) {
+            replaceEnumInvoke(
+                iterator, invokeStatic, zeroCheckMethod, m -> synthesizeZeroCheckMethod());
+          }
+        } else if (invokedMethod == factory.objectsMethods.requireNonNullWithMessage) {
+          assert invokeStatic.inValues().size() == 2;
+          Value argument = invokeStatic.getArgument(0);
+          DexType enumType = getEnumTypeOrNull(argument, convertedEnums);
+          if (enumType != null) {
+            replaceEnumInvoke(
+                iterator,
+                invokeStatic,
+                zeroCheckMessageMethod,
+                m -> synthesizeZeroCheckMessageMethod());
+          }
         }
       }
       if (instruction.isStaticGet()) {
@@ -497,6 +533,19 @@
     return false;
   }
 
+  private DexEncodedMethod synthesizeZeroCheckMethod() {
+    CfCode cfCode =
+        EnumUnboxingCfMethods.EnumUnboxingMethods_zeroCheck(appView.options(), zeroCheckMethod);
+    return synthesizeUtilityMethod(cfCode, zeroCheckMethod, false);
+  }
+
+  private DexEncodedMethod synthesizeZeroCheckMessageMethod() {
+    CfCode cfCode =
+        EnumUnboxingCfMethods.EnumUnboxingMethods_zeroCheckMessage(
+            appView.options(), zeroCheckMessageMethod);
+    return synthesizeUtilityMethod(cfCode, zeroCheckMessageMethod, false);
+  }
+
   private DexEncodedMethod synthesizeOrdinalMethod() {
     CfCode cfCode =
         EnumUnboxingCfMethods.EnumUnboxingMethods_ordinal(appView.options(), ordinalUtilityMethod);
diff --git a/src/test/java/com/android/tools/r8/enumunboxing/EnumUnboxingMethods.java b/src/test/java/com/android/tools/r8/enumunboxing/EnumUnboxingMethods.java
index e1d51cc..cc2aa0c 100644
--- a/src/test/java/com/android/tools/r8/enumunboxing/EnumUnboxingMethods.java
+++ b/src/test/java/com/android/tools/r8/enumunboxing/EnumUnboxingMethods.java
@@ -51,4 +51,17 @@
     }
     return unboxedEnum1 == unboxedEnum2;
   }
+
+  // Methods zeroCheck and zeroCheckMessage are used to replace null checks on unboxed enums.
+  public static void zeroCheck(int unboxedEnum) {
+    if (unboxedEnum == 0) {
+      throw new NullPointerException();
+    }
+  }
+
+  public static void zeroCheckMessage(int unboxedEnum, String message) {
+    if (unboxedEnum == 0) {
+      throw new NullPointerException(message);
+    }
+  }
 }
diff --git a/src/test/java/com/android/tools/r8/enumunboxing/NullCheckEnumUnboxingTest.java b/src/test/java/com/android/tools/r8/enumunboxing/NullCheckEnumUnboxingTest.java
new file mode 100644
index 0000000..25a3434
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/enumunboxing/NullCheckEnumUnboxingTest.java
@@ -0,0 +1,197 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.enumunboxing;
+
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.R8TestRunResult;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.utils.AndroidApiLevel;
+import java.util.List;
+import java.util.Objects;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class NullCheckEnumUnboxingTest extends EnumUnboxingTestBase {
+
+  private final TestParameters parameters;
+  private final boolean enumValueOptimization;
+  private final EnumKeepRules enumKeepRules;
+
+  @Parameters(name = "{0} valueOpt: {1} keep: {2}")
+  public static List<Object[]> data() {
+    return enumUnboxingTestParameters();
+  }
+
+  public NullCheckEnumUnboxingTest(
+      TestParameters parameters, boolean enumValueOptimization, EnumKeepRules enumKeepRules) {
+    this.parameters = parameters;
+    this.enumValueOptimization = enumValueOptimization;
+    this.enumKeepRules = enumKeepRules;
+  }
+
+  @Test
+  public void testEnumUnboxing() throws Exception {
+    R8TestRunResult run =
+        testForR8(parameters.getBackend())
+            .addInnerClasses(NullCheckEnumUnboxingTest.class)
+            .addKeepMainRule(MainNullTest.class)
+            .addKeepRules(enumKeepRules.getKeepRules())
+            .enableNeverClassInliningAnnotations()
+            .enableInliningAnnotations()
+            .addOptionsModification(opt -> enableEnumOptions(opt, enumValueOptimization))
+            .allowDiagnosticMessages()
+            .setMinApi(parameters.getApiLevel())
+            .compile()
+            .inspectDiagnosticMessages(
+                m -> {
+                  assertEnumIsUnboxed(MyEnum.class, MyEnum.class.getSimpleName(), m);
+                  // MyEnum19 is unboxed only if minAPI > 19 because Objects#requiredNonNull is then
+                  // present.
+                  if (parameters.getApiLevel().isGreaterThanOrEqualTo(AndroidApiLevel.K)) {
+                    assertEnumIsUnboxed(MyEnum19.class, MyEnum19.class.getSimpleName(), m);
+                  } else {
+                    assertEnumIsBoxed(MyEnum19.class, MyEnum19.class.getSimpleName(), m);
+                  }
+                })
+            .run(parameters.getRuntime(), MainNullTest.class)
+            .assertSuccess();
+    assertLines2By2Correct(run.getStdOut());
+  }
+
+  @NeverClassInline
+  enum MyEnum {
+    A,
+    B,
+    C
+  }
+
+  @NeverClassInline
+  enum MyEnum19 {
+    A,
+    B,
+    C
+  }
+
+  static class MainNullTest {
+
+    public static void main(String[] args) {
+      nullCheckTests();
+      nullCheckMessageTests();
+    }
+
+    private static void nullCheckTests() {
+      nullCheck0Test(MyEnum.A, false);
+      nullCheck0Test(MyEnum.B, false);
+      nullCheck0Test(null, true);
+
+      nullCheck1Test(MyEnum.A, false);
+      nullCheck1Test(MyEnum.B, false);
+      nullCheck1Test(null, true);
+
+      nullCheck2Test(MyEnum19.A, false);
+      nullCheck2Test(MyEnum19.B, false);
+      nullCheck2Test(null, true);
+    }
+
+    private static void nullCheckMessageTests() {
+      nullCheckMessage0Test(MyEnum.A, "myMessageA", false);
+      nullCheckMessage0Test(MyEnum.B, "myMessageB", false);
+      nullCheckMessage0Test(null, "myMessageN", true);
+
+      nullCheckMessage1Test(MyEnum19.A, "myMessageA", false);
+      nullCheckMessage1Test(MyEnum19.B, "myMessageB", false);
+      nullCheckMessage1Test(null, "myMessageN", true);
+    }
+
+    private static void nullCheck0Test(MyEnum input, boolean isNull) {
+      String result = "pass";
+      try {
+        nullCheck0(input);
+      } catch (NullPointerException ex8) {
+        result = "fail";
+      }
+      System.out.println(result);
+      System.out.println(isNull ? "fail" : "pass");
+    }
+
+    private static void nullCheck1Test(MyEnum input, boolean isNull) {
+      String result = "pass";
+      try {
+        nullCheck1(input);
+      } catch (NullPointerException ex8) {
+        result = "fail";
+      }
+      System.out.println(result);
+      System.out.println(isNull ? "fail" : "pass");
+    }
+
+    private static void nullCheck2Test(MyEnum19 input, boolean isNull) {
+      String result = "pass";
+      try {
+        nullCheck2(input);
+      } catch (NullPointerException ex) {
+        result = "fail";
+      }
+      System.out.println(result);
+      System.out.println(isNull ? "fail" : "pass");
+    }
+
+    private static void nullCheckMessage0Test(MyEnum input, String message, boolean isNull) {
+      String result = "pass";
+      try {
+        nullCheckMessage0(input, message);
+      } catch (NullPointerException ex) {
+        result = ex.getMessage();
+      }
+      System.out.println(result);
+      System.out.println(isNull ? message : "pass");
+    }
+
+    private static void nullCheckMessage1Test(MyEnum19 input, String message, boolean isNull) {
+      String result = "pass";
+      try {
+        nullCheckMessage1(input, message);
+      } catch (NullPointerException ex) {
+        result = ex.getMessage();
+      }
+      System.out.println(result);
+      System.out.println(isNull ? message : "pass");
+    }
+
+    @NeverInline
+    private static void nullCheck0(MyEnum e) {
+      if (e == null) {
+        throw new NullPointerException();
+      }
+    }
+
+    @SuppressWarnings("ResultOfMethodCallIgnored")
+    @NeverInline
+    private static void nullCheck1(MyEnum e) {
+      e.getClass();
+    }
+
+    @NeverInline
+    private static void nullCheck2(MyEnum19 e) {
+      Objects.requireNonNull(e);
+    }
+
+    @NeverInline
+    private static void nullCheckMessage0(MyEnum e, String message) {
+      if (e == null) {
+        throw new NullPointerException(message);
+      }
+    }
+
+    @NeverInline
+    private static void nullCheckMessage1(MyEnum19 e, String message) {
+      Objects.requireNonNull(e, message);
+    }
+  }
+}