Enum unboxer: fix enum arrays on various instructions

Bug: b/261900447
Change-Id: If67ae8d31e3c832e02abb90142cc0ba19aa9fdbf
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
index 7327231..d0db72d 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
@@ -102,7 +102,8 @@
       assert next.isArgument();
       if (argumentInfo.isRewrittenTypeInfo()) {
         RewrittenTypeInfo rewrittenTypeInfo = argumentInfo.asRewrittenTypeInfo();
-        DexType enumType = getEnumTypeOrNull(rewrittenTypeInfo.getOldType().toBaseType(factory));
+        DexType enumType =
+            getEnumClassTypeOrNull(rewrittenTypeInfo.getOldType().toBaseType(factory));
         if (enumType != null) {
           convertedEnums.put(next, enumType);
         }
@@ -142,7 +143,7 @@
 
         if (instruction.isInitClass()) {
           InitClass initClass = instruction.asInitClass();
-          DexType enumType = getEnumTypeOrNull(initClass.getClassValue());
+          DexType enumType = getEnumClassTypeOrNull(initClass.getClassValue());
           if (enumType != null) {
             iterator.removeOrReplaceByDebugLocalRead();
           }
@@ -154,7 +155,7 @@
           if (!ifInstruction.isZeroTest()) {
             for (int operandIndex = 0; operandIndex < 2; operandIndex++) {
               Value operand = ifInstruction.getOperand(operandIndex);
-              DexType enumType = getEnumTypeOrNull(operand, convertedEnums);
+              DexType enumType = getEnumClassTypeOrNull(operand, convertedEnums);
               if (enumType != null) {
                 int otherOperandIndex = 1 - operandIndex;
                 Value otherOperand = ifInstruction.getOperand(otherOperandIndex);
@@ -179,7 +180,7 @@
         //   also in the unboxed enum class.
         if (instruction.isInvokeMethodWithReceiver()) {
           InvokeMethodWithReceiver invoke = instruction.asInvokeMethodWithReceiver();
-          DexType enumType = getEnumTypeOrNull(invoke.getReceiver(), convertedEnums);
+          DexType enumType = getEnumClassTypeOrNull(invoke.getReceiver(), convertedEnums);
           DexMethod invokedMethod = invoke.getInvokedMethod();
           if (enumType != null) {
             if (invokedMethod == factory.enumMembers.ordinalMethod
@@ -217,7 +218,7 @@
             // Rewrites stringBuilder.append(enumInstance) as if it was
             // stringBuilder.append(String.valueOf(unboxedEnumInstance));
             Value enumArg = invoke.getArgument(1);
-            DexType enumArgType = getEnumTypeOrNull(enumArg, convertedEnums);
+            DexType enumArgType = getEnumClassTypeOrNull(enumArg, convertedEnums);
             if (enumArgType != null) {
               ProgramMethod stringValueOfMethod =
                   getLocalUtilityClass(enumArgType).ensureStringValueOfMethod(appView);
@@ -330,7 +331,7 @@
         // Rewrite array accesses from MyEnum[] (OBJECT) to int[] (INT).
         if (instruction.isArrayAccess()) {
           ArrayAccess arrayAccess = instruction.asArrayAccess();
-          DexType enumType = getEnumTypeOrNull(arrayAccess, convertedEnums);
+          DexType enumType = getEnumArrayTypeOrNull(arrayAccess, convertedEnums);
           if (enumType != null) {
             if (arrayAccess.hasOutValue()) {
               affectedPhis.addAll(arrayAccess.outValue().uniquePhiUsers());
@@ -403,14 +404,14 @@
       if (invokedMethod == factory.objectsMethods.requireNonNull) {
         assert invoke.arguments().size() == 1;
         Value argument = invoke.getFirstArgument();
-        DexType enumType = getEnumTypeOrNull(argument, convertedEnums);
+        DexType enumType = getEnumClassTypeOrNull(argument, convertedEnums);
         if (enumType != null) {
           rewriteNullCheck(instructionIterator, invoke);
         }
       } else if (invokedMethod == factory.objectsMethods.requireNonNullWithMessage) {
         assert invoke.arguments().size() == 2;
         Value argument = invoke.getFirstArgument();
-        DexType enumType = getEnumTypeOrNull(argument, convertedEnums);
+        DexType enumType = getEnumClassTypeOrNull(argument, convertedEnums);
         if (enumType != null) {
           replaceEnumInvoke(
               instructionIterator,
@@ -426,7 +427,7 @@
       if (invokedMethod == factory.stringMembers.valueOf) {
         assert invoke.arguments().size() == 1;
         Value argument = invoke.getFirstArgument();
-        DexType enumType = getEnumTypeOrNull(argument, convertedEnums);
+        DexType enumType = getEnumClassTypeOrNull(argument, convertedEnums);
         if (enumType != null) {
           ProgramMethod stringValueOfMethod =
               getLocalUtilityClass(enumType).ensureStringValueOfMethod(appView);
@@ -445,7 +446,7 @@
       } else if (invokedMethod == factory.javaLangSystemMembers.identityHashCode) {
         assert invoke.arguments().size() == 1;
         Value argument = invoke.getFirstArgument();
-        DexType enumType = getEnumTypeOrNull(argument, convertedEnums);
+        DexType enumType = getEnumClassTypeOrNull(argument, convertedEnums);
         if (enumType != null) {
           invoke.outValue().replaceUsers(argument);
           instructionIterator.removeOrReplaceByDebugLocalRead();
@@ -470,7 +471,7 @@
           CheckNotNullEnumUnboxerMethodClassification checkNotNullClassification =
               classification.asCheckNotNullClassification();
           Value argument = invoke.getArgument(checkNotNullClassification.getArgumentIndex());
-          DexType enumType = getEnumTypeOrNull(argument, convertedEnums);
+          DexType enumType = getEnumClassTypeOrNull(argument, convertedEnums);
           if (enumType != null) {
             InvokeStatic replacement =
                 InvokeStatic.builder()
@@ -529,7 +530,7 @@
 
   private Value fixNullsInBlockPhis(IRCode code, BasicBlock block, Value zeroConstValue) {
     for (Phi phi : block.getPhis()) {
-      if (getEnumTypeOrNull(phi.getType()) != null) {
+      if (getEnumClassTypeOrNull(phi.getType()) != null) {
         for (int i = 0; i < phi.getOperands().size(); i++) {
           Value operand = phi.getOperand(i);
           if (operand.getType().isNullType()) {
@@ -585,26 +586,26 @@
     return true;
   }
 
-  private DexType getEnumTypeOrNull(Value receiver, Map<Instruction, DexType> convertedEnums) {
+  private DexType getEnumClassTypeOrNull(Value receiver, Map<Instruction, DexType> convertedEnums) {
     TypeElement type = receiver.getType();
-    if (type.isInt() || (type.isArrayType() && type.asArrayType().getBaseType().isInt())) {
+    if (type.isInt()) {
       return receiver.isPhi() ? null : convertedEnums.get(receiver.getDefinition());
     }
-    return getEnumTypeOrNull(type);
+    return getEnumClassTypeOrNull(type);
   }
 
-  private DexType getEnumTypeOrNull(TypeElement type) {
+  private DexType getEnumClassTypeOrNull(TypeElement type) {
     if (!type.isClassType()) {
       return null;
     }
-    return getEnumTypeOrNull(type.asClassType().getClassType());
+    return getEnumClassTypeOrNull(type.asClassType().getClassType());
   }
 
-  private DexType getEnumTypeOrNull(DexType type) {
+  private DexType getEnumClassTypeOrNull(DexType type) {
     return unboxedEnumsData.isUnboxedEnum(type) ? type : null;
   }
 
-  private DexType getEnumTypeOrNull(
+  private DexType getEnumArrayTypeOrNull(
       ArrayAccess arrayAccess, Map<Instruction, DexType> convertedEnums) {
     ArrayTypeElement arrayType = arrayAccess.array().getType().asArrayType();
     if (arrayType == null) {
@@ -616,9 +617,13 @@
     }
     TypeElement baseType = arrayType.getBaseType();
     if (baseType.isClassType()) {
-      DexType classType = baseType.asClassType().getClassType();
-      return unboxedEnumsData.isUnboxedEnum(classType) ? classType : null;
+      return getEnumClassTypeOrNull(baseType.asClassType().getClassType());
     }
-    return getEnumTypeOrNull(arrayAccess.array(), convertedEnums);
+    if (arrayType.getBaseType().isInt()) {
+      return arrayAccess.array().isPhi()
+          ? null
+          : convertedEnums.get(arrayAccess.array().getDefinition());
+    }
+    return null;
   }
 }
diff --git a/src/test/java/com/android/tools/r8/enumunboxing/NullAssignmentToArrayArgEnumUnboxingTest.java b/src/test/java/com/android/tools/r8/enumunboxing/NullAssignmentToArrayArgEnumUnboxingTest.java
new file mode 100644
index 0000000..fa5450f
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/enumunboxing/NullAssignmentToArrayArgEnumUnboxingTest.java
@@ -0,0 +1,90 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.enumunboxing;
+
+import com.android.tools.r8.TestParameters;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class NullAssignmentToArrayArgEnumUnboxingTest extends EnumUnboxingTestBase {
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  @Parameter(1)
+  public boolean enumValueOptimization;
+
+  @Parameter(2)
+  public EnumKeepRules enumKeepRules;
+
+  @Parameters(name = "{0}, value opt.: {1}, keep: {2}")
+  public static List<Object[]> data() {
+    return enumUnboxingTestParameters();
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .addKeepRules(enumKeepRules.getKeepRules())
+        .addEnumUnboxingInspector(inspector -> inspector.assertUnboxed(MyEnum.class))
+        // We need to disable entirely inlining since not only the methods checkNotNull and
+        // contains should not be inlined, but also the synthetic method with the zero check
+        // replacing the checkNotNull method should not be inlined.
+        .addOptionsModification(opt -> opt.inlinerOptions().enableInlining = false)
+        .addOptionsModification(opt -> enableEnumOptions(opt, enumValueOptimization))
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("false", "true", "npe", "npe");
+  }
+
+  static class Main {
+
+    public static void main(String[] args) {
+      System.out.println(contains(MyEnum.A, new MyEnum[] {MyEnum.B, MyEnum.C}));
+      System.out.println(contains(MyEnum.B, new MyEnum[] {MyEnum.B, MyEnum.C}));
+      try {
+        System.out.println(contains(MyEnum.B, null));
+      } catch (NullPointerException npe) {
+        System.out.println("npe");
+      }
+      try {
+        System.out.println(contains(null, new MyEnum[] {MyEnum.B, MyEnum.C}));
+      } catch (NullPointerException npe) {
+        System.out.println("npe");
+      }
+    }
+
+    static void checkNotNull(Object o, String msg) {
+      if (o == null) {
+        throw new NullPointerException(msg);
+      }
+    }
+
+    static boolean contains(MyEnum e, MyEnum[] contents) {
+      checkNotNull(e, "elem");
+      checkNotNull(contents, "array");
+      for (MyEnum content : contents) {
+        if (content == e) {
+          return true;
+        }
+      }
+      return false;
+    }
+  }
+
+  enum MyEnum {
+    A,
+    B,
+    C;
+  }
+}