Check for member type to be of array type

Bug: b/283715197
Change-Id: Iae3cc98b8a5f5d5eaeb49bf652c0bc4c652c4baf
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index 3988896..26da1f1 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -36,6 +36,7 @@
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.horizontalclassmerging.HorizontalClassMergerUtils;
 import com.android.tools.r8.ir.analysis.equivalence.BasicBlockBehavioralSubsumption;
+import com.android.tools.r8.ir.analysis.type.ArrayTypeElement;
 import com.android.tools.r8.ir.analysis.type.DynamicTypeWithUpperBound;
 import com.android.tools.r8.ir.analysis.type.Nullability;
 import com.android.tools.r8.ir.analysis.type.TypeAnalysis;
@@ -2288,10 +2289,11 @@
     if (!options.canUseSubTypesInFilledNewArray()
         && arrayType != dexItemFactory.objectArrayType
         && !arrayType.isPrimitiveArrayType()) {
-      DexType baseType = arrayType.toBaseType(dexItemFactory);
+      DexType elementType = arrayType.toArrayElementType(dexItemFactory);
       for (Instruction uniqueUser : newArrayEmpty.outValue().uniqueUsers()) {
         if (uniqueUser.isArrayPut()
-            && !uniqueUser.asArrayPut().value().getType().isClassType(baseType)) {
+            && uniqueUser.asArrayPut().array() == newArrayEmpty.outValue()
+            && !checkTypeOfArrayPut(uniqueUser.asArrayPut(), elementType)) {
           return null;
         }
       }
@@ -2299,6 +2301,41 @@
     return new FilledArrayCandidate(newArrayEmpty, size, encodeAsFilledNewArray);
   }
 
+  private boolean checkTypeOfArrayPut(ArrayPut arrayPut, DexType elementType) {
+    TypeElement valueType = arrayPut.value().getType();
+    if (!valueType.isPrimitiveType() && elementType == dexItemFactory.objectType) {
+      return true;
+    }
+    if (valueType.isNullType() && !elementType.isPrimitiveType()) {
+      return true;
+    }
+    if (elementType.isArrayType()) {
+      if (valueType.isNullType()) {
+        return true;
+      }
+      ArrayTypeElement arrayTypeElement = valueType.asArrayType();
+      if (arrayTypeElement == null
+          || arrayTypeElement.getNesting() != elementType.getNumberOfLeadingSquareBrackets()) {
+        return false;
+      }
+      valueType = arrayTypeElement.getBaseType();
+      elementType = elementType.toBaseType(dexItemFactory);
+    }
+    assert !valueType.isArrayType();
+    assert !elementType.isArrayType();
+    if (valueType.isPrimitiveType() && !elementType.isPrimitiveType()) {
+      return false;
+    }
+    if (valueType.isPrimitiveType()) {
+      return true;
+    }
+    DexClass clazz = appView.definitionFor(elementType);
+    if (clazz == null) {
+      return false;
+    }
+    return clazz.isInterface() || valueType.isClassType(elementType);
+  }
+
   private boolean canUseFilledNewArray(DexType arrayType, int size, RewriteArrayOptions options) {
     if (size < options.minSizeForFilledNewArray) {
       return false;
diff --git a/src/test/java/com/android/tools/r8/rewrite/arrays/SimplifyArrayConstructionTest.java b/src/test/java/com/android/tools/r8/rewrite/arrays/SimplifyArrayConstructionTest.java
index 38870bd..76fcda9 100644
--- a/src/test/java/com/android/tools/r8/rewrite/arrays/SimplifyArrayConstructionTest.java
+++ b/src/test/java/com/android/tools/r8/rewrite/arrays/SimplifyArrayConstructionTest.java
@@ -170,6 +170,8 @@
         mainClass.uniqueMethodWithOriginalName("referenceArraysNoCasts");
     MethodSubject referenceArraysWithSubclasses =
         mainClass.uniqueMethodWithOriginalName("referenceArraysWithSubclasses");
+    MethodSubject referenceArraysWithInterfaceImplementations =
+        mainClass.uniqueMethodWithOriginalName("referenceArraysWithInterfaceImplementations");
     MethodSubject interfaceArrayWithRawObject =
         mainClass.uniqueMethodWithOriginalName("interfaceArrayWithRawObject");
 
@@ -220,18 +222,23 @@
     if (parameters.getApiLevel().isLessThan(AndroidApiLevel.N)) {
       assertArrayTypes(referenceArraysNoCasts, DexNewArray.class);
       assertArrayTypes(referenceArraysWithSubclasses, DexNewArray.class);
+      assertArrayTypes(referenceArraysWithInterfaceImplementations, DexNewArray.class);
       assertArrayTypes(phiFilledNewArray, DexNewArray.class);
       assertArrayTypes(objectArraysFilledNewArrayRange, DexNewArray.class);
       assertArrayTypes(twoDimensionalArrays, DexNewArray.class);
       assertArrayTypes(assumedValues, DexNewArray.class);
     } else {
       assertArrayTypes(referenceArraysNoCasts, DexFilledNewArray.class);
-      // TODO(b/246971330): Add support for arrays with subtypes.
-      if (isR8) {
+      if (isR8 && parameters.canUseSubTypesInFilledNewArray()) {
         assertArrayTypes(referenceArraysWithSubclasses, DexFilledNewArray.class);
       } else {
         assertArrayTypes(referenceArraysWithSubclasses, DexNewArray.class);
       }
+      if (isR8) {
+        assertArrayTypes(referenceArraysWithInterfaceImplementations, DexFilledNewArray.class);
+      } else {
+        assertArrayTypes(referenceArraysWithInterfaceImplementations, DexNewArray.class);
+      }
 
       // TODO(b/246971330): Add support for arrays whose values have conditionals.
       // assertArrayTypes(phiFilledNewArray, DexFilledNewArray.class);
@@ -298,6 +305,7 @@
       stringArrays();
       referenceArraysNoCasts();
       referenceArraysWithSubclasses();
+      referenceArraysWithInterfaceImplementations();
       interfaceArrayWithRawObject();
       phiFilledNewArray();
       intsThatUseFilledNewArray();
@@ -338,9 +346,13 @@
     }
 
     @NeverInline
-    private static void referenceArraysWithSubclasses() {
+    private static void referenceArraysWithInterfaceImplementations() {
       Serializable[] interfaceArr = {1, null, 2};
       System.out.println(Arrays.toString(interfaceArr));
+    }
+
+    @NeverInline
+    private static void referenceArraysWithSubclasses() {
       Number[] objArray = {1, null, 2};
       System.out.println(Arrays.toString(objArray));
     }