Unbox enums with arrays and null entries

Change-Id: I70dded91e93c4133f4f511745fb757c14d923cce
diff --git a/src/main/java/com/android/tools/r8/ir/code/ArrayPut.java b/src/main/java/com/android/tools/r8/ir/code/ArrayPut.java
index c521bf5..5f07431 100644
--- a/src/main/java/com/android/tools/r8/ir/code/ArrayPut.java
+++ b/src/main/java/com/android/tools/r8/ir/code/ArrayPut.java
@@ -58,6 +58,10 @@
     return inValues.get(VALUE_INDEX);
   }
 
+  public void replacePutValue(Value newValue) {
+    replaceValue(VALUE_INDEX, newValue);
+  }
+
   @Override
   public MemberType getMemberType() {
     return type;
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxerImpl.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxerImpl.java
index 27cda2d..ee9878a 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxerImpl.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxerImpl.java
@@ -51,7 +51,6 @@
 import com.android.tools.r8.ir.analysis.type.ArrayTypeElement;
 import com.android.tools.r8.ir.analysis.type.ClassTypeElement;
 import com.android.tools.r8.ir.analysis.type.DynamicType;
-import com.android.tools.r8.ir.analysis.type.ReferenceTypeElement;
 import com.android.tools.r8.ir.analysis.type.TypeElement;
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
 import com.android.tools.r8.ir.analysis.value.objectstate.EnumValuesObjectState;
@@ -1182,24 +1181,14 @@
     return Reason.ELIGIBLE;
   }
 
-  private ReferenceTypeElement getValueBaseType(Value value, TypeElement arrayType) {
-    TypeElement valueBaseType = value.getType();
-    if (valueBaseType.isArrayType()) {
-      assert valueBaseType.asArrayType().getBaseType().isClassType();
-      assert valueBaseType.asArrayType().getNesting() == arrayType.asArrayType().getNesting() - 1;
-      valueBaseType = valueBaseType.asArrayType().getBaseType();
+  private boolean isAssignableToArray(Value value, ClassTypeElement arrayBaseType) {
+    TypeElement valueType = value.getType();
+    if (valueType.isNullType()) {
+      return true;
     }
-    assert valueBaseType.isClassType() || valueBaseType.isNullType();
-    return valueBaseType.asReferenceType();
-  }
-
-  private boolean areCompatibleArrayTypes(
-      ClassTypeElement arrayBaseType, ReferenceTypeElement valueBaseType) {
-    assert valueBaseType.isClassType() || valueBaseType.isNullType();
-    if (valueBaseType.isNullType()) {
-      // TODO(b/271385332): Allow nulls in enum arrays to be unboxed.
-      return false;
-    }
+    TypeElement valueBaseType =
+        valueType.isArrayType() ? valueType.asArrayType().getBaseType() : valueType;
+    assert valueBaseType.isClassType();
     return enumUnboxingCandidatesInfo.isAssignableTo(
         valueBaseType.asClassType().getClassType(), arrayBaseType.getClassType());
   }
@@ -1219,8 +1208,7 @@
     assert arrayType.isArrayType();
     assert arrayType.asArrayType().getBaseType().isClassType();
     ClassTypeElement arrayBaseType = arrayType.asArrayType().getBaseType().asClassType();
-    ReferenceTypeElement valueBaseType = getValueBaseType(arrayPut.value(), arrayType);
-    if (areCompatibleArrayTypes(arrayBaseType, valueBaseType)) {
+    if (isAssignableToArray(arrayPut.value(), arrayBaseType)) {
       return Reason.ELIGIBLE;
     }
     return Reason.INVALID_ARRAY_PUT;
@@ -1247,8 +1235,7 @@
     }
 
     for (Value value : invokeNewArray.inValues()) {
-      ReferenceTypeElement valueBaseType = getValueBaseType(value, arrayType);
-      if (!areCompatibleArrayTypes(arrayBaseType, valueBaseType)) {
+      if (!isAssignableToArray(value, arrayBaseType)) {
         return Reason.INVALID_INVOKE_NEW_ARRAY;
       }
     }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
index 4aa7645..cbbd4b9 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
@@ -19,6 +19,7 @@
 import com.android.tools.r8.ir.analysis.type.ArrayTypeElement;
 import com.android.tools.r8.ir.analysis.type.TypeElement;
 import com.android.tools.r8.ir.code.ArrayAccess;
+import com.android.tools.r8.ir.code.ArrayPut;
 import com.android.tools.r8.ir.code.BasicBlock;
 import com.android.tools.r8.ir.code.BasicBlockIterator;
 import com.android.tools.r8.ir.code.ConstNumber;
@@ -364,6 +365,14 @@
             arrayAccess = arrayAccess.withMemberType(MemberType.INT);
             iterator.replaceCurrentInstruction(arrayAccess);
             convertedEnums.put(arrayAccess, enumType);
+            if (arrayAccess.isArrayPut()) {
+              ArrayPut arrayPut = arrayAccess.asArrayPut();
+              if (arrayPut.value().getType().isNullType()) {
+                iterator.previous();
+                arrayPut.replacePutValue(iterator.insertConstIntInstruction(code, options, 0));
+                iterator.next();
+              }
+            }
           }
           assert validateArrayAccess(arrayAccess);
         }
diff --git a/src/test/java/com/android/tools/r8/enumunboxing/EnumUnboxingArrayTest.java b/src/test/java/com/android/tools/r8/enumunboxing/EnumUnboxingArrayTest.java
index c3919c7..ceaf088 100644
--- a/src/test/java/com/android/tools/r8/enumunboxing/EnumUnboxingArrayTest.java
+++ b/src/test/java/com/android/tools/r8/enumunboxing/EnumUnboxingArrayTest.java
@@ -20,9 +20,12 @@
   private static final Class<?>[] TESTS = {
     Enum2DimArrayReadWrite.class,
     EnumArrayNullRead.class,
-    EnumArrayReadWrite.class,
+    EnumArrayPutNull.class,
     EnumArrayReadWriteNoEscape.class,
     EnumVarArgs.class,
+    EnumArrayPutNull.class,
+    Enum2DimArrayPutNull.class,
+    Enum2DimArrayPutNullArray.class
   };
 
   private final TestParameters parameters;
@@ -51,14 +54,19 @@
             .enableNeverClassInliningAnnotations()
             .addKeepRules(enumKeepRules.getKeepRules())
             .addOptionsModification(opt -> enableEnumOptions(opt, enumValueOptimization))
+            .addOptionsModification(opt -> opt.testing.enableEnumUnboxingDebugLogs = true)
+            .allowDiagnosticInfoMessages()
             .addEnumUnboxingInspector(
                 inspector ->
                     inspector.assertUnboxed(
                         Enum2DimArrayReadWrite.MyEnum.class,
                         EnumArrayNullRead.MyEnum.class,
-                        EnumArrayReadWrite.MyEnum.class,
+                        EnumArrayPutNull.MyEnum.class,
                         EnumArrayReadWriteNoEscape.MyEnum.class,
-                        EnumVarArgs.MyEnum.class))
+                        EnumVarArgs.MyEnum.class,
+                        EnumArrayPutNull.MyEnum.class,
+                        Enum2DimArrayPutNull.MyEnum.class,
+                        Enum2DimArrayPutNullArray.MyEnum.class))
             .setMinApi(parameters)
             .compile();
     for (Class<?> main : TESTS) {
@@ -113,6 +121,37 @@
     }
   }
 
+  static class EnumArrayPutNull {
+
+    public static void main(String[] args) {
+      MyEnum[] myEnums = getArray();
+      System.out.println(myEnums[1].ordinal());
+      System.out.println(1);
+      setNull(myEnums);
+      System.out.println(myEnums[0] == null);
+      System.out.println("true");
+    }
+
+    @NeverInline
+    public static void setNull(MyEnum[] myEnums) {
+      myEnums[0] = null;
+    }
+
+    @NeverInline
+    public static MyEnum[] getArray() {
+      MyEnum[] myEnums = new MyEnum[2];
+      myEnums[1] = MyEnum.B;
+      myEnums[0] = MyEnum.A;
+      return myEnums;
+    }
+
+    @NeverClassInline
+    enum MyEnum {
+      A,
+      B;
+    }
+  }
+
   static class EnumArrayReadWrite {
 
     public static void main(String[] args) {
@@ -193,4 +232,77 @@
       C;
     }
   }
+
+  static class Enum2DimArrayPutNull {
+
+    public static void main(String[] args) {
+      MyEnum[][] myEnums = getArray();
+      System.out.println(myEnums[1][1].ordinal());
+      System.out.println(1);
+      setNull(myEnums);
+      System.out.println(myEnums[0] == null);
+      System.out.println("true");
+    }
+
+    @NeverInline
+    public static void setNull(MyEnum[][] myEnums) {
+      myEnums[0] = null;
+    }
+
+    @NeverInline
+    public static MyEnum[][] getArray() {
+      MyEnum[][] myEnums = new MyEnum[2][2];
+      myEnums[0][1] = MyEnum.B;
+      myEnums[0][0] = MyEnum.A;
+      myEnums[1][1] = MyEnum.B;
+      myEnums[1][0] = MyEnum.A;
+      return myEnums;
+    }
+
+    @NeverClassInline
+    enum MyEnum {
+      A,
+      B;
+    }
+  }
+
+  static class Enum2DimArrayPutNullArray {
+
+    public static void main(String[] args) {
+      MyEnum[][] myEnums = getArray();
+      System.out.println(myEnums[1][1].ordinal());
+      System.out.println(1);
+      setNull(myEnums);
+      System.out.println(myEnums[0][0] == null);
+      System.out.println("true");
+    }
+
+    @NeverInline
+    public static void setNull(MyEnum[][] myEnums) {
+      MyEnum[] myEnums1 = new MyEnum[1];
+      setNull(myEnums1);
+      myEnums[0] = myEnums1;
+    }
+
+    @NeverInline
+    public static void setNull(MyEnum[] myEnums1) {
+      myEnums1[0] = null;
+    }
+
+    @NeverInline
+    public static MyEnum[][] getArray() {
+      MyEnum[][] myEnums = new MyEnum[2][2];
+      myEnums[0][1] = MyEnum.B;
+      myEnums[0][0] = MyEnum.A;
+      myEnums[1][1] = MyEnum.B;
+      myEnums[1][0] = MyEnum.A;
+      return myEnums;
+    }
+
+    @NeverClassInline
+    enum MyEnum {
+      A,
+      B;
+    }
+  }
 }