Relax assertions in proto optimization

Change-Id: I1e3a70782f5d016a54a99bd00e786e595cb75cbd
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoReferences.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoReferences.java
index 932e81d..812c0ee 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoReferences.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoReferences.java
@@ -196,11 +196,16 @@
     return isDynamicMethodBridge(method.getReference());
   }
 
-  @SuppressWarnings("ReferenceEquality")
   public boolean isFindLiteExtensionByNumberMethod(DexMethod method) {
-    return method.proto == findLiteExtensionByNumberProto
-        && method.name.startsWith(findLiteExtensionByNumberName)
-        && method.holder != extensionRegistryLiteType;
+    return method.getProto().isIdenticalTo(findLiteExtensionByNumberProto)
+        && method.getName().startsWith(findLiteExtensionByNumberName)
+        && method.getHolderType().isNotIdenticalTo(extensionRegistryLiteType);
+  }
+
+  public boolean isFindLiteExtensionByNumberBridgeMethod(DexMethod method) {
+    return method.getProto().isIdenticalTo(findLiteExtensionByNumberProto)
+        && method.getName().startsWith(findLiteExtensionByNumberName)
+        && method.getHolderType().isIdenticalTo(extensionRegistryLiteType);
   }
 
   public boolean isFindLiteExtensionByNumberMethod(ProgramMethod method) {
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoEnqueuerExtension.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoEnqueuerExtension.java
index c81b099..c9a6021 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoEnqueuerExtension.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoEnqueuerExtension.java
@@ -5,11 +5,11 @@
 package com.android.tools.r8.ir.analysis.proto.schema;
 
 import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
+import static com.android.tools.r8.utils.MapUtils.ignoreKey;
 
 import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexClassAndField;
-import com.android.tools.r8.graph.DexEncodedField;
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexMethod;
@@ -45,6 +45,8 @@
 import com.android.tools.r8.utils.BitUtils;
 import com.android.tools.r8.utils.OptionalBool;
 import com.android.tools.r8.utils.Timing;
+import com.android.tools.r8.utils.collections.ProgramFieldMap;
+import com.android.tools.r8.utils.collections.ProgramFieldSet;
 import com.android.tools.r8.utils.collections.ProgramMethodSet;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Sets;
@@ -52,7 +54,6 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.function.Consumer;
 import java.util.function.Predicate;
 
 // TODO(b/112437944): Handle cycles in the graph + add a test that fails with the current
@@ -224,7 +225,7 @@
    * inlining.
    */
   private void populateExtensionGraph(Enqueuer enqueuer) {
-    collectExtensionFields()
+    collectProgramExtensionFields()
         .forEach(
             (clazz, extensionFields) -> {
               ProgramMethod clinit = clazz.getProgramClassInitializer();
@@ -234,9 +235,9 @@
               }
 
               IRCode code = clinit.buildIR(appView, MethodConversionOptions.nonConverting());
-              Map<DexEncodedField, StaticPut> uniqueStaticPuts =
+              ProgramFieldMap<StaticPut> uniqueStaticPuts =
                   IRCodeUtils.findUniqueStaticPuts(appView, code, extensionFields);
-              for (DexEncodedField extensionField : extensionFields) {
+              for (ProgramField extensionField : extensionFields) {
                 StaticPut staticPut = uniqueStaticPuts.get(extensionField);
                 if (staticPut == null) {
                   // Could happen after we have optimized the code.
@@ -252,37 +253,33 @@
   }
 
   /**
-   * Finds the extension fields referenced in the methods in {@link
+   * Finds the program extension fields referenced in the methods in {@link
    * #findLiteExtensionByNumberMethods}.
    */
-  private Map<DexProgramClass, Set<DexEncodedField>> collectExtensionFields() {
-    Map<DexProgramClass, Set<DexEncodedField>> extensionFieldsByClass = new IdentityHashMap<>();
+  private Map<DexProgramClass, ProgramFieldSet> collectProgramExtensionFields() {
+    Map<DexProgramClass, ProgramFieldSet> extensionFieldsByClass = new IdentityHashMap<>();
     for (ProgramMethod findLiteExtensionByNumberMethod : findLiteExtensionByNumberMethods) {
       IRCode code =
           findLiteExtensionByNumberMethod.buildIR(appView, MethodConversionOptions.nonConverting());
       Set<Phi> seenPhis = Sets.newIdentityHashSet();
       for (BasicBlock block : code.blocks(BasicBlock::isReturnBlock)) {
         Value returnValue = block.exit().asReturn().returnValue();
-        collectExtensionFieldsFromValue(
-            returnValue,
-            seenPhis,
-            field ->
-                extensionFieldsByClass
-                    .computeIfAbsent(field.getHolder(), ignore -> Sets.newIdentityHashSet())
-                    .add(field.getDefinition()));
+        collectExtensionFieldsFromValue(returnValue, seenPhis, extensionFieldsByClass);
       }
     }
     return extensionFieldsByClass;
   }
 
   private void collectExtensionFieldsFromValue(
-      Value returnValue, Set<Phi> seenPhis, Consumer<ProgramField> consumer) {
+      Value returnValue,
+      Set<Phi> seenPhis,
+      Map<DexProgramClass, ProgramFieldSet> extensionFieldsByClass) {
     Value root = returnValue.getAliasedValue();
     if (root.isPhi()) {
       Phi phi = root.asPhi();
       if (seenPhis.add(phi)) {
         for (Value operand : phi.getOperands()) {
-          collectExtensionFieldsFromValue(operand, seenPhis, consumer);
+          collectExtensionFieldsFromValue(operand, seenPhis, extensionFieldsByClass);
         }
       }
       return;
@@ -292,22 +289,34 @@
       return;
     }
 
-    Instruction definition = root.definition;
+    Instruction definition = root.getDefinition();
     if (definition.isStaticGet()) {
       StaticGet staticGet = definition.asStaticGet();
       DexClassAndField field =
           appView.appInfo().resolveField(staticGet.getField()).getResolutionPair();
-      if (field == null || !field.isProgramField()) {
-        assert false;
+      if (field == null) {
+        assert staticGet.getField().getType().isIdenticalTo(references.generatedExtensionType);
         return;
       }
-      consumer.accept(field.asProgramField());
-      return;
+      if (field.isProgramField()) {
+        ProgramField programField = field.asProgramField();
+        extensionFieldsByClass
+            .computeIfAbsent(programField.getHolder(), ignoreKey(ProgramFieldSet::create))
+            .add(programField);
+      } else {
+        assert false;
+      }
+    } else {
+      assert verifyIsCallToFindLiteExtensionByNumberMethod(definition);
     }
+  }
 
-    assert definition.isInvokeMethod()
-        && references.isFindLiteExtensionByNumberMethod(
-            definition.asInvokeMethod().getInvokedMethod());
+  private boolean verifyIsCallToFindLiteExtensionByNumberMethod(Instruction instruction) {
+    assert instruction.isInvokeMethod();
+    DexMethod invokedMethod = instruction.asInvokeMethod().getInvokedMethod();
+    assert references.isFindLiteExtensionByNumberMethod(invokedMethod)
+        || references.isFindLiteExtensionByNumberBridgeMethod(invokedMethod);
+    return true;
   }
 
   /**
diff --git a/src/main/java/com/android/tools/r8/ir/code/IRCodeUtils.java b/src/main/java/com/android/tools/r8/ir/code/IRCodeUtils.java
index 7caf4ec..e918ad4 100644
--- a/src/main/java/com/android/tools/r8/ir/code/IRCodeUtils.java
+++ b/src/main/java/com/android/tools/r8/ir/code/IRCodeUtils.java
@@ -6,14 +6,14 @@
 
 import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
 import com.android.tools.r8.graph.AppView;
-import com.android.tools.r8.graph.DexEncodedField;
 import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.ProgramField;
 import com.android.tools.r8.utils.DequeUtils;
+import com.android.tools.r8.utils.collections.ProgramFieldMap;
+import com.android.tools.r8.utils.collections.ProgramFieldSet;
 import com.google.common.collect.Sets;
 import java.util.ArrayDeque;
 import java.util.Deque;
-import java.util.IdentityHashMap;
-import java.util.Map;
 import java.util.Set;
 
 public class IRCodeUtils {
@@ -44,15 +44,12 @@
    * Finds the single assignment to the fields in {@param fields} in {@param code}. Note that this
    * does not guarantee that the assignments found dominate all the normal exits.
    */
-  public static Map<DexEncodedField, StaticPut> findUniqueStaticPuts(
-      AppView<? extends AppInfoWithClassHierarchy> appView,
-      IRCode code,
-      Set<DexEncodedField> fields) {
-    Set<DexEncodedField> writtenMoreThanOnce = Sets.newIdentityHashSet();
-    Map<DexEncodedField, StaticPut> uniqueStaticPuts = new IdentityHashMap<>();
+  public static ProgramFieldMap<StaticPut> findUniqueStaticPuts(
+      AppView<? extends AppInfoWithClassHierarchy> appView, IRCode code, ProgramFieldSet fields) {
+    ProgramFieldSet writtenMoreThanOnce = ProgramFieldSet.create();
+    ProgramFieldMap<StaticPut> uniqueStaticPuts = ProgramFieldMap.create();
     for (StaticPut staticPut : code.<StaticPut>instructions(Instruction::isStaticPut)) {
-      DexEncodedField field =
-          appView.appInfo().resolveField(staticPut.getField()).getResolvedField();
+      ProgramField field = appView.appInfo().resolveField(staticPut.getField()).getProgramField();
       if (field == null || !fields.contains(field) || writtenMoreThanOnce.contains(field)) {
         continue;
       }