filled-new-array: Add support for InvokeNewArray to proto shrinker

Bug: 246971330
Change-Id: Ic5d26a0ca645a581bd2cb4e4e5b6dabcc43c0fa1
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java
index 74a3fb5..b010a58 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java
@@ -33,6 +33,7 @@
 import com.android.tools.r8.ir.code.InvokeDirect;
 import com.android.tools.r8.ir.code.InvokeMethod;
 import com.android.tools.r8.ir.code.InvokeMethodWithReceiver;
+import com.android.tools.r8.ir.code.InvokeNewArray;
 import com.android.tools.r8.ir.code.MemberType;
 import com.android.tools.r8.ir.code.NewArrayEmpty;
 import com.android.tools.r8.ir.code.NewInstance;
@@ -46,6 +47,7 @@
 import com.android.tools.r8.utils.Timing;
 import com.android.tools.r8.utils.collections.ProgramMethodSet;
 import com.google.common.collect.Sets;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
@@ -316,17 +318,36 @@
     Value sizeValue =
         instructionIterator.insertConstIntInstruction(code, appView.options(), objects.size());
     Value newObjectsValue = code.createValue(objectArrayType);
-    instructionIterator.add(
-        new NewArrayEmpty(newObjectsValue, sizeValue, appView.dexItemFactory().objectArrayType));
 
     // Populate the `objects` array.
-    for (int i = 0; i < objects.size(); i++) {
-      Value indexValue = instructionIterator.insertConstIntInstruction(code, appView.options(), i);
-      Instruction materializingInstruction = objects.get(i).buildIR(appView, code);
-      instructionIterator.add(materializingInstruction);
+    if (appView.options().rewriteArrayOptions().experimentalNewFilledArraySupport
+        && appView.options().rewriteArrayOptions().canUseFilledNewArrayOfObjects()) {
+      List<Value> arrayValues = new ArrayList<>(objects.size());
+      for (int i = 0; i < objects.size(); i++) {
+        Instruction materializingInstruction = objects.get(i).buildIR(appView, code);
+        instructionIterator.add(materializingInstruction);
+        arrayValues.add(materializingInstruction.outValue());
+      }
       instructionIterator.add(
-          new ArrayPut(
-              MemberType.OBJECT, newObjectsValue, indexValue, materializingInstruction.outValue()));
+          new InvokeNewArray(
+              appView.dexItemFactory().objectArrayType, newObjectsValue, arrayValues));
+    } else {
+      instructionIterator.add(
+          new NewArrayEmpty(newObjectsValue, sizeValue, appView.dexItemFactory().objectArrayType));
+
+      // Populate the `objects` array.
+      for (int i = 0; i < objects.size(); i++) {
+        Value indexValue =
+            instructionIterator.insertConstIntInstruction(code, appView.options(), i);
+        Instruction materializingInstruction = objects.get(i).buildIR(appView, code);
+        instructionIterator.add(materializingInstruction);
+        instructionIterator.add(
+            new ArrayPut(
+                MemberType.OBJECT,
+                newObjectsValue,
+                indexValue,
+                materializingInstruction.outValue()));
+      }
     }
 
     // Pass the newly created `objects` array to RawMessageInfo.<init>(...) or
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/RawMessageInfoDecoder.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/RawMessageInfoDecoder.java
index fbe3ff5..c25c456 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/RawMessageInfoDecoder.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/RawMessageInfoDecoder.java
@@ -31,6 +31,7 @@
 import com.android.tools.r8.ir.code.Instruction;
 import com.android.tools.r8.ir.code.InstructionIterator;
 import com.android.tools.r8.ir.code.InvokeMethod;
+import com.android.tools.r8.ir.code.InvokeNewArray;
 import com.android.tools.r8.ir.code.InvokeStatic;
 import com.android.tools.r8.ir.code.NewArrayEmpty;
 import com.android.tools.r8.ir.code.StaticGet;
@@ -302,18 +303,30 @@
    */
   private static ThrowingIterator<Value, InvalidRawMessageInfoException> createObjectIterator(
       Value objectsValue) throws InvalidRawMessageInfoException {
-    if (objectsValue.isPhi() || !objectsValue.definition.isNewArrayEmpty()) {
+    if (objectsValue.isPhi()) {
       throw new InvalidRawMessageInfoException();
     }
 
     NewArrayEmpty newArrayEmpty = objectsValue.definition.asNewArrayEmpty();
-    int expectedArraySize = objectsValue.uniqueUsers().size() - 1;
+    InvokeNewArray invokeNewArray = objectsValue.definition.asInvokeNewArray();
 
-    // Verify that the size is correct.
+    if (newArrayEmpty == null && invokeNewArray == null) {
+      throw new InvalidRawMessageInfoException();
+    }
+    // Verify that the array is used in only one spot.
+    if (invokeNewArray != null) {
+      if (objectsValue.uniqueUsers().size() != 1) {
+        throw new InvalidRawMessageInfoException();
+      }
+      return ThrowingIterator.fromIterator(invokeNewArray.inValues().iterator());
+    }
+
     Value sizeValue = newArrayEmpty.size().getAliasedValue();
-    if (sizeValue.isPhi()
-        || !sizeValue.definition.isConstNumber()
-        || sizeValue.definition.asConstNumber().getIntValue() != expectedArraySize) {
+    if (sizeValue.isPhi() || !sizeValue.definition.isConstNumber()) {
+      throw new InvalidRawMessageInfoException();
+    }
+    int arraySize = sizeValue.definition.asConstNumber().getIntValue();
+    if (arraySize != objectsValue.uniqueUsers().size() - 1) {
       throw new InvalidRawMessageInfoException();
     }
 
diff --git a/src/main/java/com/android/tools/r8/ir/code/IRCodeUtils.java b/src/main/java/com/android/tools/r8/ir/code/IRCodeUtils.java
index e06cdbee..99f92eb 100644
--- a/src/main/java/com/android/tools/r8/ir/code/IRCodeUtils.java
+++ b/src/main/java/com/android/tools/r8/ir/code/IRCodeUtils.java
@@ -73,20 +73,25 @@
    * <p>Use with caution!
    */
   public static void removeArrayAndTransitiveInputsIfNotUsed(IRCode code, Instruction definition) {
-    Deque<InstructionOrPhi> worklist = new ArrayDeque<>();
     if (definition.isConstNumber()) {
       // No need to explicitly remove `null`, it will be removed by ordinary dead code elimination
       // anyway.
       assert definition.asConstNumber().isZero();
       return;
     }
-
-    if (definition.isNewArrayEmpty()) {
-      Value arrayValue = definition.outValue();
-      if (arrayValue.hasPhiUsers() || arrayValue.hasDebugUsers()) {
-        return;
-      }
-
+    Value arrayValue = definition.outValue();
+    if (arrayValue.hasPhiUsers() || arrayValue.hasDebugUsers()) {
+      return;
+    }
+    if (!definition.isNewArrayEmptyOrInvokeNewArray()) {
+      assert false;
+      return;
+    }
+    Deque<InstructionOrPhi> worklist = new ArrayDeque<>();
+    InvokeNewArray invokeNewArray = definition.asInvokeNewArray();
+    if (invokeNewArray != null) {
+      worklist.add(definition);
+    } else if (definition.isNewArrayEmpty()) {
       for (Instruction user : arrayValue.uniqueUsers()) {
         // If we encounter an Assume instruction here, we also need to consider indirect users.
         assert !user.isAssume();
@@ -95,11 +100,10 @@
         }
         worklist.add(user);
       }
-      internalRemoveInstructionAndTransitiveInputsIfNotUsed(code, worklist);
-      return;
+    } else {
+      assert false;
     }
-
-    assert false;
+    internalRemoveInstructionAndTransitiveInputsIfNotUsed(code, worklist);
   }
 
   /**
diff --git a/src/main/java/com/android/tools/r8/utils/ThrowingIterator.java b/src/main/java/com/android/tools/r8/utils/ThrowingIterator.java
index 50d23c9..7d41ae1 100644
--- a/src/main/java/com/android/tools/r8/utils/ThrowingIterator.java
+++ b/src/main/java/com/android/tools/r8/utils/ThrowingIterator.java
@@ -5,6 +5,7 @@
 package com.android.tools.r8.utils;
 
 import java.util.ArrayList;
+import java.util.Iterator;
 import java.util.List;
 import java.util.NoSuchElementException;
 
@@ -32,4 +33,18 @@
     }
     return result;
   }
+
+  public static <T, E extends Exception> ThrowingIterator<T, E> fromIterator(Iterator<T> it) {
+    return new ThrowingIterator<>() {
+      @Override
+      public boolean hasNext() {
+        return it.hasNext();
+      }
+
+      @Override
+      public T next() throws E {
+        return it.next();
+      }
+    };
+  }
 }