Encode info for RawMessageInfo to DexString

Change-Id: I5807ad14789f3a0745b09ce75f178c118acbb23e
diff --git a/src/main/java/com/android/tools/r8/graph/DexString.java b/src/main/java/com/android/tools/r8/graph/DexString.java
index 96a1aa2..dc1a5ec 100644
--- a/src/main/java/com/android/tools/r8/graph/DexString.java
+++ b/src/main/java/com/android/tools/r8/graph/DexString.java
@@ -175,20 +175,23 @@
 
   // Inspired from /dex/src/main/java/com/android/dex/Mutf8.java
   private static int countBytes(String string) {
-    int result = 0;
+    // We need an extra byte for the terminating '0'.
+    int result = 1;
     for (int i = 0; i < string.length(); ++i) {
-      char ch = string.charAt(i);
-      if (ch != 0 && ch <= 127) { // U+0000 uses two bytes.
-        ++result;
-      } else if (ch <= 2047) {
-        result += 2;
-      } else {
-        result += 3;
-      }
+      result += countBytes(string.charAt(i));
       assert result > 0;
     }
-    // We need an extra byte for the terminating '0'.
-    return result + 1;
+    return result;
+  }
+
+  public static int countBytes(char ch) {
+    if (ch != 0 && ch <= 127) { // U+0000 uses two bytes.
+      return 1;
+    }
+    if (ch <= 2047) {
+      return 2;
+    }
+    return 3;
   }
 
   // Inspired from /dex/src/main/java/com/android/dex/Mutf8.java
@@ -196,22 +199,26 @@
     byte[] result = new byte[countBytes(string)];
     int offset = 0;
     for (int i = 0; i < string.length(); i++) {
-      char ch = string.charAt(i);
-      if (ch != 0 && ch <= 127) { // U+0000 uses two bytes.
-        result[offset++] = (byte) ch;
-      } else if (ch <= 2047) {
-        result[offset++] = (byte) (0xc0 | (0x1f & (ch >> 6)));
-        result[offset++] = (byte) (0x80 | (0x3f & ch));
-      } else {
-        result[offset++] = (byte) (0xe0 | (0x0f & (ch >> 12)));
-        result[offset++] = (byte) (0x80 | (0x3f & (ch >> 6)));
-        result[offset++] = (byte) (0x80 | (0x3f & ch));
-      }
+      offset = encodeToMutf8(string.charAt(i), result, offset);
     }
     result[offset] = 0;
     return result;
   }
 
+  public static int encodeToMutf8(char ch, byte[] array, int offset) {
+    if (ch != 0 && ch <= 127) { // U+0000 uses two bytes.
+      array[offset++] = (byte) ch;
+    } else if (ch <= 2047) {
+      array[offset++] = (byte) (0xc0 | (0x1f & (ch >> 6)));
+      array[offset++] = (byte) (0x80 | (0x3f & ch));
+    } else {
+      array[offset++] = (byte) (0xe0 | (0x0f & (ch >> 12)));
+      array[offset++] = (byte) (0x80 | (0x3f & (ch >> 6)));
+      array[offset++] = (byte) (0x80 | (0x3f & ch));
+    }
+    return offset;
+  }
+
   @Override
   public void collectIndexedItems(IndexedItemCollection indexedItems,
       DexMethod method, int instructionOffset) {
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java
index c676478..c7eaa02 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java
@@ -11,6 +11,7 @@
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexString;
 import com.android.tools.r8.ir.analysis.proto.schema.ProtoFieldTypeFactory;
+import com.android.tools.r8.ir.analysis.proto.schema.ProtoMessageInfo;
 import com.android.tools.r8.ir.analysis.type.Nullability;
 import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
 import com.android.tools.r8.ir.code.BasicBlock.ThrowingInfo;
@@ -31,14 +32,20 @@
 public class GeneratedMessageLiteShrinker {
 
   private final AppView<AppInfoWithLiveness> appView;
+  private final RawMessageInfoDecoder decoder;
+  private final RawMessageInfoEncoder encoder;
   private final ProtoReferences references;
+  private final TypeLatticeElement stringType;
   private final ThrowingInfo throwingInfo;
 
   private final ProtoFieldTypeFactory factory = new ProtoFieldTypeFactory();
 
   public GeneratedMessageLiteShrinker(AppView<AppInfoWithLiveness> appView) {
     this.appView = appView;
+    this.decoder = new RawMessageInfoDecoder(factory);
+    this.encoder = new RawMessageInfoEncoder(appView.dexItemFactory());
     this.references = appView.protoShrinker().references;
+    this.stringType = TypeLatticeElement.stringClassType(appView, Nullability.definitelyNotNull());
     this.throwingInfo = ThrowingInfo.defaultForConstString(appView.options());
   }
 
@@ -62,28 +69,14 @@
       return;
     }
 
-    List<ConstString> rewritingCandidates = new ArrayList<>();
+    InvokeStatic newMessageInfoInvoke = null;
     for (Instruction instruction : code.instructions()) {
       if (instruction.isInvokeStatic()) {
         InvokeStatic invoke = instruction.asInvokeStatic();
-        if (invoke.getInvokedMethod() != references.newMessageInfoMethod) {
-          continue;
+        if (invoke.getInvokedMethod() == references.newMessageInfoMethod) {
+          newMessageInfoInvoke = invoke;
+          break;
         }
-        Value infoValue = invoke.inValues().get(1);
-        Value objectsValue = invoke.inValues().get(2);
-        for (Instruction user : objectsValue.uniqueUsers()) {
-          if (!user.isArrayPut()) {
-            continue;
-          }
-          Value rewritingCandidate = user.asArrayPut().value().getAliasedValue();
-          if (rewritingCandidate.isPhi() || !rewritingCandidate.definition.isConstString()) {
-            continue;
-          }
-          rewritingCandidates.add(rewritingCandidate.definition.asConstString());
-        }
-
-        // For now, just verify that we can actually decode the raw message info from the IR.
-        assert new RawMessageInfoDecoder(factory, infoValue, objectsValue).run() != null;
       }
 
       // Implicitly check that the method newMessageInfo() has not been inlined. In that case,
@@ -93,24 +86,54 @@
           || instruction.asNewInstance().clazz != references.rawMessageInfoType;
     }
 
-    boolean changed = false;
-    for (ConstString rewritingCandidate : rewritingCandidates) {
-      DexString fieldName = rewritingCandidate.getValue();
-      DexField field = uniqueInstanceFieldWithName(clazz, fieldName, code.origin);
-      if (field == null) {
-        continue;
-      }
-      Value newValue =
-          code.createValue(
-              TypeLatticeElement.stringClassType(appView, Nullability.definitelyNotNull()));
-      rewritingCandidate.replace(
-          new DexItemBasedConstString(
-              newValue, field, FieldNameComputationInfo.forFieldName(), throwingInfo));
-      changed = true;
-    }
+    if (newMessageInfoInvoke != null) {
+      Value infoValue = newMessageInfoInvoke.inValues().get(1).getAliasedValue();
+      Value objectsValue = newMessageInfoInvoke.inValues().get(2).getAliasedValue();
 
-    if (changed) {
-      method.getMutableOptimizationInfo().markUseIdentifierNameString();
+      // TODO(b/112437944): If we regenerate the arguments to newMessageInfo() entirely, then we can
+      //  simply generate DexItemBasedConstString instructions at that point. That way the block
+      //  below will not be needed.
+      {
+        List<ConstString> identifierNameStringCandidates = new ArrayList<>();
+        for (Instruction user : objectsValue.uniqueUsers()) {
+          if (user.isArrayPut()) {
+            Value rewritingCandidate = user.asArrayPut().value().getAliasedValue();
+            if (!rewritingCandidate.isPhi() && rewritingCandidate.definition.isConstString()) {
+              identifierNameStringCandidates.add(rewritingCandidate.definition.asConstString());
+            }
+          }
+        }
+
+        boolean changed = false;
+        for (ConstString rewritingCandidate : identifierNameStringCandidates) {
+          DexString fieldName = rewritingCandidate.getValue();
+          DexField field = uniqueInstanceFieldWithName(clazz, fieldName, code.origin);
+          if (field == null) {
+            continue;
+          }
+          Value newValue = code.createValue(stringType);
+          rewritingCandidate.replace(
+              new DexItemBasedConstString(
+                  newValue, field, FieldNameComputationInfo.forFieldName(), throwingInfo));
+          changed = true;
+        }
+
+        if (changed) {
+          method.getMutableOptimizationInfo().markUseIdentifierNameString();
+        }
+      }
+
+      // Decode the arguments passed to newMessageInfo().
+      ProtoMessageInfo protoMessageInfo = decoder.run(infoValue, objectsValue);
+      if (protoMessageInfo != null) {
+        // Rewrite the arguments to newMessageInfo().
+        infoValue.definition.replace(
+            new ConstString(
+                code.createValue(stringType), encoder.encodeInfo(protoMessageInfo), throwingInfo));
+      } else {
+        // We should generally be able to decode the arguments passed to newMessageInfo().
+        assert false;
+      }
     }
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/RawMessageInfoDecoder.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/RawMessageInfoDecoder.java
index d57f353..fdc3934 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/RawMessageInfoDecoder.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/RawMessageInfoDecoder.java
@@ -49,19 +49,16 @@
   private static final int IS_PROTO_2_MASK = 0x1;
 
   private final ProtoFieldTypeFactory factory;
-  private final Value infoValue;
-  private final Value objectsValue;
 
-  public RawMessageInfoDecoder(ProtoFieldTypeFactory factory, Value infoValue, Value objectsValue) {
+  RawMessageInfoDecoder(ProtoFieldTypeFactory factory) {
     this.factory = factory;
-    this.infoValue = infoValue.getAliasedValue();
-    this.objectsValue = objectsValue.getAliasedValue();
   }
 
-  public ProtoMessageInfo run() {
+  public ProtoMessageInfo run(Value infoValue, Value objectsValue) {
     try {
       ProtoMessageInfo.Builder builder = ProtoMessageInfo.builder();
-      ThrowingIntIterator<InvalidRawMessageInfoException> infoIterator = createInfoIterator();
+      ThrowingIntIterator<InvalidRawMessageInfoException> infoIterator =
+          createInfoIterator(infoValue);
 
       // flags := info[0].
       int flags = infoIterator.nextIntComputeIfAbsent(this::invalidInfoFailure);
@@ -86,11 +83,13 @@
       // repeatedFieldCount := info[8].
       // checkInitialized   := info[9].
       for (int i = 4; i < 10; i++) {
+        // No need to store these values, since they can be computed from the rest (and need to be
+        // recomputed if the proto is changed).
         infoIterator.nextIntComputeIfAbsent(this::invalidInfoFailure);
       }
 
       ThrowingIterator<Value, InvalidRawMessageInfoException> objectIterator =
-          createObjectIterator();
+          createObjectIterator(objectsValue);
 
       if (numberOfOneOfObjects > 0) {
         builder.setNumberOfOneOfObjects(numberOfOneOfObjects);
@@ -153,8 +152,8 @@
     throw new InvalidRawMessageInfoException();
   }
 
-  private ThrowingIntIterator<InvalidRawMessageInfoException> createInfoIterator()
-      throws InvalidRawMessageInfoException {
+  private static ThrowingIntIterator<InvalidRawMessageInfoException> createInfoIterator(
+      Value infoValue) throws InvalidRawMessageInfoException {
     if (!infoValue.isPhi() && infoValue.definition.isConstString()) {
       return createInfoIterator(infoValue.definition.asConstString().getValue());
     }
@@ -162,7 +161,8 @@
   }
 
   /** Returns an iterator that yields the integers that results from decoding the given string. */
-  private ThrowingIntIterator<InvalidRawMessageInfoException> createInfoIterator(DexString info) {
+  private static ThrowingIntIterator<InvalidRawMessageInfoException> createInfoIterator(
+      DexString info) {
     return new ThrowingIntIterator<InvalidRawMessageInfoException>() {
 
       private final ThrowingCharIterator<UTFDataFormatException> charIterator = info.iterator();
@@ -207,8 +207,8 @@
    * passed to GeneratedMessageLite.newMessageInfo(). The array values are returned in-order, i.e.,
    * the value objects[i] will be returned prior to the value objects[i+1].
    */
-  private ThrowingIterator<Value, InvalidRawMessageInfoException> createObjectIterator()
-      throws InvalidRawMessageInfoException {
+  private static ThrowingIterator<Value, InvalidRawMessageInfoException> createObjectIterator(
+      Value objectsValue) throws InvalidRawMessageInfoException {
     if (objectsValue.isPhi() || !objectsValue.definition.isNewArrayEmpty()) {
       throw new InvalidRawMessageInfoException();
     }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/RawMessageInfoEncoder.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/RawMessageInfoEncoder.java
new file mode 100644
index 0000000..8487210
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/RawMessageInfoEncoder.java
@@ -0,0 +1,110 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.analysis.proto;
+
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexString;
+import com.android.tools.r8.ir.analysis.proto.schema.ProtoFieldInfo;
+import com.android.tools.r8.ir.analysis.proto.schema.ProtoFieldType;
+import com.android.tools.r8.ir.analysis.proto.schema.ProtoMessageInfo;
+import it.unimi.dsi.fastutil.ints.IntArrayList;
+import it.unimi.dsi.fastutil.ints.IntList;
+import it.unimi.dsi.fastutil.ints.IntListIterator;
+
+public class RawMessageInfoEncoder {
+
+  private final DexItemFactory dexItemFactory;
+
+  RawMessageInfoEncoder(DexItemFactory dexItemFactory) {
+    this.dexItemFactory = dexItemFactory;
+  }
+
+  DexString encodeInfo(ProtoMessageInfo protoMessageInfo) {
+    IntList info = new IntArrayList();
+    info.add(protoMessageInfo.flags());
+    info.add(protoMessageInfo.numberOfFields());
+
+    if (protoMessageInfo.hasFields()) {
+      int minFieldNumber = Integer.MAX_VALUE;
+      int maxFieldNumber = Integer.MIN_VALUE;
+      int mapFieldCount = 0;
+      int repeatedFieldCount = 0;
+      int checkInitialized = 0;
+
+      for (ProtoFieldInfo protoFieldInfo : protoMessageInfo.fields()) {
+        int fieldNumber = protoFieldInfo.getNumber();
+        if (fieldNumber < minFieldNumber) {
+          minFieldNumber = fieldNumber;
+        }
+        if (fieldNumber > maxFieldNumber) {
+          maxFieldNumber = fieldNumber;
+        }
+        ProtoFieldType fieldType = protoFieldInfo.getType();
+        if (fieldType.id() == ProtoFieldType.MAP_ID) {
+          mapFieldCount++;
+        } else if (!fieldType.isSingular()) {
+          repeatedFieldCount++;
+        }
+        if (fieldType.needsIsInitializedCheck()) {
+          checkInitialized++;
+        }
+      }
+
+      info.add(protoMessageInfo.numberOfOneOfObjects());
+      info.add(protoMessageInfo.numberOfHasBitsObjects());
+      info.add(minFieldNumber);
+      info.add(maxFieldNumber);
+      info.add(protoMessageInfo.numberOfFields());
+      info.add(mapFieldCount);
+      info.add(repeatedFieldCount);
+      info.add(checkInitialized);
+
+      for (ProtoFieldInfo protoFieldInfo : protoMessageInfo.fields()) {
+        info.add(protoFieldInfo.getNumber());
+        info.add(protoFieldInfo.getType().serialize());
+        if (protoFieldInfo.hasAuxData()) {
+          info.add(protoFieldInfo.getAuxData());
+        }
+      }
+    }
+
+    return encodeInfo(info);
+  }
+
+  private DexString encodeInfo(IntList info) {
+    byte[] result = new byte[countBytes(info)];
+    int numberOfExtraChars = 0;
+    int offset = 0;
+    IntListIterator iterator = info.iterator();
+    while (iterator.hasNext()) {
+      int value = iterator.nextInt();
+      while (value >= 0xD800) {
+        char c = (char) (0xE000 | (value & 0x1FFF));
+        offset = DexString.encodeToMutf8(c, result, offset);
+        numberOfExtraChars++;
+        value >>= 13;
+      }
+      offset = DexString.encodeToMutf8((char) value, result, offset);
+    }
+    result[offset] = 0;
+    return dexItemFactory.createString(info.size() + numberOfExtraChars, result);
+  }
+
+  private static int countBytes(IntList info) {
+    // We need an extra byte for the terminating '0'.
+    int result = 1;
+    IntListIterator iterator = info.iterator();
+    while (iterator.hasNext()) {
+      int value = iterator.nextInt();
+      while (value >= 0xD800) {
+        char c = (char) (0xE000 | (value & 0x1FFF));
+        value >>= 13;
+        result += DexString.countBytes(c);
+      }
+      result += DexString.countBytes((char) value);
+    }
+    return result;
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldInfo.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldInfo.java
index a908cc5..d701c02 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldInfo.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldInfo.java
@@ -24,4 +24,21 @@
     this.auxData = auxData;
     this.objects = objects;
   }
+
+  public boolean hasAuxData() {
+    return auxData.isPresent();
+  }
+
+  public int getAuxData() {
+    assert hasAuxData();
+    return auxData.getAsInt();
+  }
+
+  public int getNumber() {
+    return number;
+  }
+
+  public ProtoFieldType getType() {
+    return type;
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoMessageInfo.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoMessageInfo.java
index 25ff283..45bfdc0 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoMessageInfo.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoMessageInfo.java
@@ -81,4 +81,28 @@
   public static ProtoMessageInfo.Builder builder() {
     return new ProtoMessageInfo.Builder();
   }
+
+  public List<ProtoFieldInfo> fields() {
+    return fields;
+  }
+
+  public int flags() {
+    return flags;
+  }
+
+  public boolean hasFields() {
+    return fields != null && !fields.isEmpty();
+  }
+
+  public int numberOfFields() {
+    return fields != null ? fields.size() : 0;
+  }
+
+  public int numberOfHasBitsObjects() {
+    return hasBitsObjects != null ? hasBitsObjects.size() : 0;
+  }
+
+  public int numberOfOneOfObjects() {
+    return oneOfObjects != null ? oneOfObjects.size() : 0;
+  }
 }