Allow phis in proto extension registry shrinking

Bug: 183734568
Change-Id: I7854d6a1821fff23eee8fb8bbfc57993b32a4f0b
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoEnqueuerExtension.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoEnqueuerExtension.java
index eb00100..e7d953a 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoEnqueuerExtension.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoEnqueuerExtension.java
@@ -8,6 +8,7 @@
 
 import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
 import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexClassAndField;
 import com.android.tools.r8.graph.DexEncodedField;
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexField;
@@ -30,6 +31,7 @@
 import com.android.tools.r8.ir.code.IRCodeUtils;
 import com.android.tools.r8.ir.code.Instruction;
 import com.android.tools.r8.ir.code.InvokeMethod;
+import com.android.tools.r8.ir.code.Phi;
 import com.android.tools.r8.ir.code.StaticGet;
 import com.android.tools.r8.ir.code.StaticPut;
 import com.android.tools.r8.ir.code.Value;
@@ -48,6 +50,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Consumer;
 import java.util.function.Predicate;
 
 // TODO(b/112437944): Handle cycles in the graph + add a test that fails with the current
@@ -247,48 +250,56 @@
     Map<DexProgramClass, Set<DexEncodedField>> extensionFieldsByClass = new IdentityHashMap<>();
     for (ProgramMethod findLiteExtensionByNumberMethod : findLiteExtensionByNumberMethods) {
       IRCode code = findLiteExtensionByNumberMethod.buildIR(appView);
+      Set<Phi> seenPhis = Sets.newIdentityHashSet();
       for (BasicBlock block : code.blocks(BasicBlock::isReturnBlock)) {
-        Value returnValue = block.exit().asReturn().returnValue().getAliasedValue();
-        if (returnValue.isPhi()) {
-          assert false;
-          continue;
-        }
-
-        if (returnValue.isZero()) {
-          continue; // OK.
-        }
-
-        Instruction definition = returnValue.definition;
-        if (definition.isStaticGet()) {
-          StaticGet staticGet = definition.asStaticGet();
-          DexEncodedField field =
-              appView.appInfo().resolveField(staticGet.getField()).getResolvedField();
-          if (field == null) {
-            assert false;
-            continue;
-          }
-
-          DexProgramClass holder =
-              asProgramClassOrNull(appView.definitionFor(field.getHolderType()));
-          if (holder == null) {
-            assert false;
-            continue;
-          }
-
-          extensionFieldsByClass
-              .computeIfAbsent(holder, ignore -> Sets.newIdentityHashSet())
-              .add(field);
-          continue;
-        }
-
-        assert definition.isInvokeMethodWithReceiver()
-            && references.isFindLiteExtensionByNumberMethod(
-                definition.asInvokeMethodWithReceiver().getInvokedMethod());
+        Value returnValue = block.exit().asReturn().returnValue();
+        collectExtensionFieldsFromValue(
+            returnValue,
+            seenPhis,
+            field ->
+                extensionFieldsByClass
+                    .computeIfAbsent(field.getHolder(), ignore -> Sets.newIdentityHashSet())
+                    .add(field.getDefinition()));
       }
     }
     return extensionFieldsByClass;
   }
 
+  private void collectExtensionFieldsFromValue(
+      Value returnValue, Set<Phi> seenPhis, Consumer<ProgramField> consumer) {
+    Value root = returnValue.getAliasedValue();
+    if (root.isPhi()) {
+      Phi phi = root.asPhi();
+      if (seenPhis.add(phi)) {
+        for (Value operand : phi.getOperands()) {
+          collectExtensionFieldsFromValue(operand, seenPhis, consumer);
+        }
+      }
+      return;
+    }
+
+    if (root.isZero()) {
+      return;
+    }
+
+    Instruction definition = root.definition;
+    if (definition.isStaticGet()) {
+      StaticGet staticGet = definition.asStaticGet();
+      DexClassAndField field =
+          appView.appInfo().resolveField(staticGet.getField()).getResolutionPair();
+      if (field == null || !field.isProgramField()) {
+        assert false;
+        return;
+      }
+      consumer.accept(field.asProgramField());
+      return;
+    }
+
+    assert definition.isInvokeMethodWithReceiver()
+        && references.isFindLiteExtensionByNumberMethod(
+            definition.asInvokeMethodWithReceiver().getInvokedMethod());
+  }
+
   /**
    * Updates {@link #extensionGraph} based on the definition of {@param staticPut}.
    *