Update indices in proto schemas after tree shaking

Change-Id: I4295abb40e57c324377bc3aef7e6aaf72db8db1e
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoUtils.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoUtils.java
index 06a7bfe..4b8747e 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoUtils.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoUtils.java
@@ -10,6 +10,8 @@
 
 public class ProtoUtils {
 
+  public static final int IS_PROTO_2_MASK = 0x1;
+
   static Value getInfoValueFromMessageInfoConstructionInvoke(
       InvokeMethod invoke, ProtoReferences references) {
     assert references.isMessageInfoConstructionMethod(invoke.getInvokedMethod());
@@ -30,4 +32,8 @@
     int adjustment = BooleanUtils.intValue(invoke.isInvokeDirect());
     invoke.replaceValue(2 + adjustment, newObjectsValue);
   }
+
+  public static boolean isProto2(int flags) {
+    return (flags & IS_PROTO_2_MASK) != 0;
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/RawMessageInfoDecoder.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/RawMessageInfoDecoder.java
index dd627f8..83607ca 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/RawMessageInfoDecoder.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/RawMessageInfoDecoder.java
@@ -67,8 +67,6 @@
  */
 public class RawMessageInfoDecoder {
 
-  public static final int IS_PROTO_2_MASK = 0x1;
-
   private final ProtoFieldTypeFactory factory;
   private final ProtoReferences references;
 
@@ -135,7 +133,7 @@
                 objectIterator.computeNextIfAbsent(this::invalidObjectsFailure), context));
       }
 
-      boolean isProto2 = (flags & IS_PROTO_2_MASK) != 0;
+      boolean isProto2 = ProtoUtils.isProto2(flags);
       for (int i = 0; i < fieldCount; i++) {
         // Extract field-specific portion of "info" string.
         int fieldNumber = infoIterator.nextIntComputeIfAbsent(this::invalidInfoFailure);
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldInfo.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldInfo.java
index d9ef276..a8e75f5 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldInfo.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoFieldInfo.java
@@ -20,7 +20,7 @@
    * Index into {@link ProtoMessageInfo#oneOfObjects} or {@link ProtoMessageInfo#hasBitsObjects}.
    * Only used for oneof and proto2 singular fields.
    */
-  private final OptionalInt auxData;
+  private OptionalInt auxData;
 
   /**
    * For any non-oneof field, the first entry will be a reference to a java.lang.String literal. For
@@ -46,6 +46,11 @@
     return auxData.getAsInt();
   }
 
+  void setAuxData(int value) {
+    assert hasAuxData();
+    auxData = OptionalInt.of(value);
+  }
+
   public int getNumber() {
     return number;
   }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoMessageInfo.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoMessageInfo.java
index aaea56f..c6f95fe 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoMessageInfo.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoMessageInfo.java
@@ -4,9 +4,10 @@
 
 package com.android.tools.r8.ir.analysis.proto.schema;
 
-import static com.android.tools.r8.ir.analysis.proto.RawMessageInfoDecoder.IS_PROTO_2_MASK;
-
+import com.android.tools.r8.ir.analysis.proto.ProtoUtils;
 import com.android.tools.r8.utils.Pair;
+import it.unimi.dsi.fastutil.ints.Int2IntArrayMap;
+import it.unimi.dsi.fastutil.ints.Int2IntMap;
 import it.unimi.dsi.fastutil.ints.IntArrayList;
 import it.unimi.dsi.fastutil.ints.IntList;
 import java.util.Iterator;
@@ -72,44 +73,64 @@
     }
 
     private void removeUnusedSharedData() {
+      if (fields == null || fields.isEmpty()) {
+        oneOfObjects = null;
+        hasBitsObjects = null;
+        return;
+      }
+
       // Gather used "oneof" and "hasbits" indices.
-      IntList usedOneofIndices = new IntArrayList();
+      IntList usedOneOfIndices = new IntArrayList();
       IntList usedHasBitsIndices = new IntArrayList();
-      if (fields != null) {
-        for (ProtoFieldInfo field : fields) {
-          if (field.hasAuxData()) {
-            if (field.getType().isOneOf()) {
-              usedOneofIndices.add(field.getAuxData());
-            } else {
-              usedHasBitsIndices.add(field.getAuxData() / BITS_PER_HAS_BITS_WORD);
-            }
+      for (ProtoFieldInfo field : fields) {
+        if (field.hasAuxData()) {
+          if (field.getType().isOneOf()) {
+            usedOneOfIndices.add(field.getAuxData());
+          } else {
+            assert ProtoUtils.isProto2(flags) && field.getType().isSingular();
+            usedHasBitsIndices.add(field.getAuxData() / BITS_PER_HAS_BITS_WORD);
           }
         }
       }
 
       // Remove unused parts of "oneof" vector.
+      Int2IntMap newOneOfObjectIndices = new Int2IntArrayMap();
       if (oneOfObjects != null) {
         Iterator<Pair<ProtoObject, ProtoObject>> oneOfObjectIterator = oneOfObjects.iterator();
-        for (int i = 0; i < oneOfObjects.size(); i++) {
+        for (int i = 0, numberOfRemovedOneOfObjects = 0; i < oneOfObjects.size(); i++) {
           oneOfObjectIterator.next();
-          if (!usedOneofIndices.contains(i)) {
+          if (usedOneOfIndices.contains(i)) {
+            newOneOfObjectIndices.put(i, i - numberOfRemovedOneOfObjects);
+          } else {
             oneOfObjectIterator.remove();
+            numberOfRemovedOneOfObjects++;
           }
         }
       }
 
       // Remove unused parts of "hasbits" vector.
+      Int2IntMap newHasBitsObjectIndices = new Int2IntArrayMap();
       if (hasBitsObjects != null) {
         Iterator<ProtoObject> hasBitsObjectIterator = hasBitsObjects.iterator();
-        for (int i = 0; i < hasBitsObjects.size(); i++) {
+        for (int i = 0, numberOfRemovedHasBitsObjects = 0; i < hasBitsObjects.size(); i++) {
           hasBitsObjectIterator.next();
-          if (!usedHasBitsIndices.contains(i)) {
+          if (usedHasBitsIndices.contains(i)) {
+            newHasBitsObjectIndices.put(i, i - numberOfRemovedHasBitsObjects);
+          } else {
             hasBitsObjectIterator.remove();
+            numberOfRemovedHasBitsObjects++;
           }
         }
       }
 
-      // TODO(b/112437944): Fix up references + add a test that fails when references are not fixed.
+      // Fix up references.
+      for (ProtoFieldInfo field : fields) {
+        if (field.hasAuxData()) {
+          Int2IntMap indexMapping =
+              field.getType().isOneOf() ? newOneOfObjectIndices : newHasBitsObjectIndices;
+          field.setAuxData(indexMapping.get(field.getAuxData()));
+        }
+      }
     }
   }
 
@@ -135,7 +156,7 @@
   }
 
   public boolean isProto2() {
-    return (flags & IS_PROTO_2_MASK) != 0;
+    return ProtoUtils.isProto2(flags);
   }
 
   public List<ProtoFieldInfo> getFields() {