Prefer removal of edges to <clinit> in call graph over actual call edges

Change-Id: I9bb2c9ffdf3d377277ce5594a52e714f0e558c9b
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/CallGraphBuilderBase.java b/src/main/java/com/android/tools/r8/ir/conversion/CallGraphBuilderBase.java
index 32be662..4e7aab7 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/CallGraphBuilderBase.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/CallGraphBuilderBase.java
@@ -10,7 +10,6 @@
 import com.android.tools.r8.graph.DexCallSite;
 import com.android.tools.r8.graph.DexClass;
 import com.android.tools.r8.graph.DexClassAndMethod;
-import com.android.tools.r8.graph.DexEncodedField;
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexMethod;
@@ -20,6 +19,7 @@
 import com.android.tools.r8.graph.FieldAccessInfoCollection;
 import com.android.tools.r8.graph.GraphLens.MethodLookupResult;
 import com.android.tools.r8.graph.LookupResult;
+import com.android.tools.r8.graph.ProgramField;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.graph.ResolutionResult;
 import com.android.tools.r8.graph.UseRegistry;
@@ -44,6 +44,7 @@
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
+import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.Predicate;
 
@@ -185,7 +186,7 @@
           assert !context.getDefinition().isBridge()
               || singleTarget.getDefinition() != context.getDefinition();
           // For static invokes, the class could be initialized.
-          if (type == Invoke.Type.STATIC) {
+          if (type.isStatic()) {
             addClassInitializerTarget(singleTarget.getHolder());
           }
           addCallEdge(singleTarget, false);
@@ -252,28 +253,22 @@
       }
     }
 
-    private void processFieldRead(DexField field) {
-      if (!field.holder.isClassType()) {
+    private void processFieldRead(DexField reference) {
+      if (!reference.holder.isClassType()) {
         return;
       }
 
-      DexEncodedField encodedField = appView.appInfo().resolveField(field).getResolvedField();
-      if (encodedField == null || appView.appInfo().isPinned(encodedField.getReference())) {
-        return;
-      }
-
-      DexProgramClass clazz =
-          asProgramClassOrNull(appView.definitionFor(encodedField.getHolderType()));
-      if (clazz == null) {
+      ProgramField field = appView.appInfo().resolveField(reference).getProgramField();
+      if (field == null || appView.appInfo().isPinned(field)) {
         return;
       }
 
       // Each static field access implicitly triggers the class initializer.
-      if (encodedField.isStatic()) {
-        addClassInitializerTarget(clazz);
+      if (field.getAccessFlags().isStatic()) {
+        addClassInitializerTarget(field.getHolder());
       }
 
-      FieldAccessInfo fieldAccessInfo = fieldAccessInfoCollection.get(encodedField.getReference());
+      FieldAccessInfo fieldAccessInfo = fieldAccessInfoCollection.get(field.getReference());
       if (fieldAccessInfo != null && fieldAccessInfo.hasKnownWriteContexts()) {
         if (fieldAccessInfo.getNumberOfWriteContexts() == 1) {
           fieldAccessInfo.forEachWriteContext(this::addFieldReadEdge);
@@ -281,12 +276,12 @@
       }
     }
 
-    private void processFieldWrite(DexField field) {
-      if (field.holder.isClassType()) {
-        DexEncodedField encodedField = appView.appInfo().resolveField(field).getResolvedField();
-        if (encodedField != null && encodedField.isStatic()) {
+    private void processFieldWrite(DexField reference) {
+      if (reference.getHolderType().isClassType()) {
+        ProgramField field = appView.appInfo().resolveField(reference).getProgramField();
+        if (field != null && field.getAccessFlags().isStatic()) {
           // Each static field access implicitly triggers the class initializer.
-          addClassInitializerTarget(field.holder);
+          addClassInitializerTarget(field.getHolder());
         }
       }
     }
@@ -426,6 +421,11 @@
     // Nodes on the DFS stack.
     private Map<Node, StackEntryInfo> stackEntryInfo = new IdentityHashMap<>();
 
+    // Subset of the DFS stack, where the nodes on the stack are class initializers.
+    //
+    // This stack is used to efficiently compute if there is a class initializer on the stack.
+    private Deque<Node> clinitStack = new ArrayDeque<>();
+
     // Subset of the DFS stack, where the nodes on the stack satisfy that the edge from the
     // predecessor to the node itself is a field read edge.
     //
@@ -471,6 +471,7 @@
 
     private void prepareForNewTraversal() {
       assert calleesToBeRemoved.isEmpty();
+      assert clinitStack.isEmpty();
       assert stack.isEmpty();
       assert stackEntryInfo.isEmpty();
       assert writersToBeRemoved.isEmpty();
@@ -480,6 +481,7 @@
     }
 
     private void reset() {
+      assert clinitStack.isEmpty();
       assert marked.isEmpty();
       assert revisit.isEmpty();
       assert stack.isEmpty();
@@ -624,28 +626,36 @@
 
         // Otherwise, it is a call edge. Check if there is a field read edge in the cycle, and if
         // so, remove that edge.
-        if (!writerStack.isEmpty()) {
-          Node lastKnownWriter = writerStack.peek();
-          StackEntryInfo lastKnownWriterStackEntryInfo = stackEntryInfo.get(lastKnownWriter);
-          boolean cycleContainsLastKnownWriter =
-              lastKnownWriterStackEntryInfo.index > calleeOrWriterStackEntryInfo.index;
-          if (cycleContainsLastKnownWriter) {
-            assert verifyCycleSatisfies(
+        if (!writerStack.isEmpty()
+            && removeIncomingEdgeOnStack(
+                writerStack.peek(),
                 calleeOrWriter,
-                cycle ->
-                    cycle.contains(lastKnownWriter)
-                        && cycle.contains(lastKnownWriterStackEntryInfo.predecessor));
-            if (!lastKnownWriterStackEntryInfo.processed) {
-              removeFieldReadEdge(lastKnownWriterStackEntryInfo.predecessor, lastKnownWriter);
-              revisit.add(lastKnownWriter);
-              lastKnownWriterStackEntryInfo.processed = true;
-            }
-            continue;
-          }
+                calleeOrWriterStackEntryInfo,
+                this::removeFieldReadEdge)) {
+          continue;
         }
 
-        // It is a call edge, and the cycle does not contain any field read edges. In this case, we
-        // remove the call edge if it is safe according to force inlining.
+        // It is a call edge and the cycle does not contain any field read edges.
+        // If it is a call edge to a <clinit>, then remove it.
+        if (calleeOrWriter.getMethod().isClassInitializer()) {
+          // Calls to class initializers are always safe to remove.
+          assert callEdgeRemovalIsSafe(callerOrReader, calleeOrWriter);
+          removeCallEdge(callerOrReader, calleeOrWriter);
+          continue;
+        }
+
+        // Otherwise, check if there is a call edge to a <clinit> method in the cycle, and if so,
+        // remove that edge.
+        if (!clinitStack.isEmpty()
+            && removeIncomingEdgeOnStack(
+                clinitStack.peek(),
+                calleeOrWriter,
+                calleeOrWriterStackEntryInfo,
+                this::removeCallEdge)) {
+          continue;
+        }
+
+        // Otherwise, we remove the call edge if it is safe according to force inlining.
         if (callEdgeRemovalIsSafe(callerOrReader, calleeOrWriter)) {
           // Break the cycle by removing the edge node->calleeOrWriter.
           // Need to remove `calleeOrWriter` from `node.callees` using the iterator to prevent a
@@ -681,8 +691,12 @@
       stack.push(node);
       assert !stackEntryInfo.containsKey(node);
       stackEntryInfo.put(node, new StackEntryInfo(stack.size() - 1, predecessor));
-      if (predecessor != null && predecessor.getWritersWithDeterministicOrder().contains(node)) {
-        writerStack.push(node);
+      if (predecessor != null) {
+        if (node.getMethod().isClassInitializer()) {
+          clinitStack.push(node);
+        } else if (predecessor.getWritersWithDeterministicOrder().contains(node)) {
+          writerStack.push(node);
+        }
       }
     }
 
@@ -691,7 +705,10 @@
       assert popped == node;
       assert stackEntryInfo.containsKey(node);
       stackEntryInfo.remove(node);
-      if (writerStack.peek() == popped) {
+      if (clinitStack.peek() == popped) {
+        assert writerStack.peek() != popped;
+        clinitStack.pop();
+      } else if (writerStack.peek() == popped) {
         writerStack.pop();
       }
     }
@@ -704,6 +721,28 @@
       writersToBeRemoved.computeIfAbsent(reader, ignore -> Sets.newIdentityHashSet()).add(writer);
     }
 
+    private boolean removeIncomingEdgeOnStack(
+        Node target,
+        Node currentCalleeOrWriter,
+        StackEntryInfo currentCalleeOrWriterStackEntryInfo,
+        BiConsumer<Node, Node> edgeRemover) {
+      StackEntryInfo targetStackEntryInfo = stackEntryInfo.get(target);
+      boolean cycleContainsTarget =
+          targetStackEntryInfo.index > currentCalleeOrWriterStackEntryInfo.index;
+      if (cycleContainsTarget) {
+        assert verifyCycleSatisfies(
+            currentCalleeOrWriter,
+            cycle -> cycle.contains(target) && cycle.contains(targetStackEntryInfo.predecessor));
+        if (!targetStackEntryInfo.processed) {
+          edgeRemover.accept(targetStackEntryInfo.predecessor, target);
+          revisit.add(target);
+          targetStackEntryInfo.processed = true;
+        }
+        return true;
+      }
+      return false;
+    }
+
     private LinkedList<Node> extractCycle(Node entry) {
       LinkedList<Node> cycle = new LinkedList<>();
       do {