Account for control flow in loops correctly

Bug: b/237674850
Change-Id: I3e3d260884bc9c3ae8e1aa1e3cc1ac0eac08ba2a
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/string/StringBuilderAppendOptimizer.java b/src/main/java/com/android/tools/r8/ir/optimize/string/StringBuilderAppendOptimizer.java
index a2a90b4..ee85c13 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/string/StringBuilderAppendOptimizer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/string/StringBuilderAppendOptimizer.java
@@ -35,6 +35,8 @@
 import com.android.tools.r8.ir.optimize.string.StringBuilderNode.InitNode;
 import com.android.tools.r8.ir.optimize.string.StringBuilderNode.InitOrAppend;
 import com.android.tools.r8.ir.optimize.string.StringBuilderNode.LoopNode;
+import com.android.tools.r8.ir.optimize.string.StringBuilderNode.NewInstanceNode;
+import com.android.tools.r8.ir.optimize.string.StringBuilderNode.SplitReferenceNode;
 import com.android.tools.r8.ir.optimize.string.StringBuilderNodeMuncher.MunchingState;
 import com.android.tools.r8.ir.optimize.string.StringBuilderOracle.DefaultStringBuilderOracle;
 import com.android.tools.r8.utils.DepthFirstSearchWorkListBase.DepthFirstSearchWorkList;
@@ -42,6 +44,9 @@
 import com.android.tools.r8.utils.TraversalContinuation;
 import com.android.tools.r8.utils.WorkList;
 import com.google.common.collect.Sets;
+import it.unimi.dsi.fastutil.objects.Reference2IntLinkedOpenHashMap;
+import it.unimi.dsi.fastutil.objects.Reference2IntMap;
+import it.unimi.dsi.fastutil.objects.Reference2IntMap.Entry;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.IdentityHashMap;
@@ -416,10 +421,13 @@
               DFSNodeWithState<BasicBlock, StringBuilderGraphState> node,
               List<DFSNodeWithState<BasicBlock, StringBuilderGraphState>> childStates) {
             StringBuilderGraphState state = node.getState();
+            Reference2IntMap<Value> rootsInChildStateCounts =
+                new Reference2IntLinkedOpenHashMap<>();
             for (DFSNodeWithState<BasicBlock, StringBuilderGraphState> childState : childStates) {
               StringBuilderGraphState childGraphState = childState.getState();
               childGraphState.roots.forEach(
                   (value, sbNode) -> {
+                    rootsInChildStateCounts.put(value, rootsInChildStateCounts.getInt(value) + 1);
                     StringBuilderNode currentRoot = state.roots.get(value);
                     StringBuilderNode currentTail = state.tails.get(value);
                     if (currentRoot == null) {
@@ -446,6 +454,18 @@
                 childGraphState.isPartOfLoop = true;
               }
             }
+            // To ensure that we account for control flow correctly, we insert split reference nodes
+            // for all roots we've seen in only a subset of child states.
+            for (Entry<Value> valueEntry : rootsInChildStateCounts.reference2IntEntrySet()) {
+              assert valueEntry.getIntValue() <= childStates.size();
+              if (valueEntry.getIntValue() < childStates.size()) {
+                SplitReferenceNode splitNode = StringBuilderNode.createSplitReferenceNode();
+                StringBuilderNode tail = state.tails.get(valueEntry.getKey());
+                assert tail != null;
+                splitNode.addPredecessor(tail);
+                tail.addSuccessor(splitNode);
+              }
+            }
             if (state.isPartOfLoop) {
               state.roots.replaceAll(
                   (value, currentRoot) -> {
@@ -475,6 +495,7 @@
       Map<Value, StringBuilderNode> stringBuilderGraphs) {
     Map<Instruction, StringBuilderAction> actions = new IdentityHashMap<>();
     // Build state to allow munching over the string builder graphs.
+    Map<StringBuilderNode, NewInstanceNode> newInstances = new IdentityHashMap<>();
     Set<StringBuilderNode> inspectingCapacity = Sets.newIdentityHashSet();
     Set<StringBuilderNode> looping = Sets.newIdentityHashSet();
     Map<StringBuilderNode, Set<StringBuilderNode>> materializing = new IdentityHashMap<>();
@@ -492,6 +513,10 @@
           while (workList.hasNext()) {
             StringBuilderNode next = workList.next();
             nodeToRoots.put(next, root);
+            if (next.isNewInstanceNode()) {
+              StringBuilderNode existing = newInstances.put(root, next.asNewInstanceNode());
+              assert existing == null;
+            }
             if (next.isInitOrAppend()) {
               ImplicitToStringNode dependency = next.asInitOrAppend().getImplicitToStringNode();
               if (dependency != null) {
@@ -518,7 +543,8 @@
         });
 
     MunchingState munchingState =
-        new MunchingState(actions, escaping, inspectingCapacity, looping, materializing, oracle);
+        new MunchingState(
+            actions, escaping, inspectingCapacity, looping, materializing, newInstances, oracle);
 
     boolean keepMunching = true;
     for (int i = 0; i < NUMBER_OF_MUNCHING_PASSES && keepMunching; i++) {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/string/StringBuilderNodeMuncher.java b/src/main/java/com/android/tools/r8/ir/optimize/string/StringBuilderNodeMuncher.java
index 74b654d..0659a8a 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/string/StringBuilderNodeMuncher.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/string/StringBuilderNodeMuncher.java
@@ -13,6 +13,7 @@
 import com.android.tools.r8.ir.optimize.string.StringBuilderNode.ImplicitToStringNode;
 import com.android.tools.r8.ir.optimize.string.StringBuilderNode.InitNode;
 import com.android.tools.r8.ir.optimize.string.StringBuilderNode.InitOrAppend;
+import com.android.tools.r8.ir.optimize.string.StringBuilderNode.NewInstanceNode;
 import com.android.tools.r8.ir.optimize.string.StringBuilderNode.StringBuilderInstruction;
 import com.android.tools.r8.ir.optimize.string.StringBuilderNode.ToStringNode;
 import com.android.tools.r8.utils.WorkList;
@@ -39,6 +40,7 @@
     private final Set<StringBuilderNode> inspectingCapacity;
     private final Set<StringBuilderNode> looping;
     private final Map<StringBuilderNode, Set<StringBuilderNode>> materializingInstructions;
+    private final Map<StringBuilderNode, NewInstanceNode> newInstances;
     private final Map<Value, String> optimizedStrings = new IdentityHashMap<>();
 
     MunchingState(
@@ -47,14 +49,24 @@
         Set<StringBuilderNode> inspectingCapacity,
         Set<StringBuilderNode> looping,
         Map<StringBuilderNode, Set<StringBuilderNode>> materializingInstructions,
+        Map<StringBuilderNode, NewInstanceNode> newInstances,
         StringBuilderOracle oracle) {
       this.actions = actions;
       this.escaping = escaping;
       this.inspectingCapacity = inspectingCapacity;
       this.looping = looping;
       this.materializingInstructions = materializingInstructions;
+      this.newInstances = newInstances;
       this.oracle = oracle;
     }
+
+    public NewInstanceNode getNewInstanceNode(StringBuilderNode root) {
+      return newInstances.get(root);
+    }
+
+    public boolean isLooping(StringBuilderNode root) {
+      return looping.contains(root);
+    }
   }
 
   private interface PeepholePattern {
@@ -132,11 +144,11 @@
       if (!currentNode.isToStringNode() && !currentNode.isImplicitToStringNode()) {
         return false;
       }
-      StringBuilderNode root = findFirstNonSentinelRoot(originalRoot);
-      if (!root.isNewInstanceNode() || !root.hasSingleSuccessor()) {
+      NewInstanceNode newInstanceNode = munchingState.getNewInstanceNode(originalRoot);
+      if (newInstanceNode == null || !newInstanceNode.hasSingleSuccessor()) {
         return false;
       }
-      StringBuilderNode init = root.getSingleSuccessor();
+      StringBuilderNode init = newInstanceNode.getSingleSuccessor();
       String rootConstantArgument = getConstantArgumentForNode(init, munchingState);
       if (rootConstantArgument == null || !init.isInitNode()) {
         return false;
@@ -185,24 +197,6 @@
     }
   }
 
-  /**
-   * Find the first non split reference node or loop-node, which are nodes inserted to track
-   * control-flow.
-   */
-  private static StringBuilderNode findFirstNonSentinelRoot(StringBuilderNode root) {
-    WorkList<StringBuilderNode> workList = WorkList.newIdentityWorkList(root);
-    while (workList.hasNext()) {
-      StringBuilderNode node = workList.next();
-      if (!node.isSplitReferenceNode() && !node.isLoopNode()) {
-        return node;
-      }
-      if (node.hasSingleSuccessor()) {
-        workList.addIfNotSeen(node.getSingleSuccessor());
-      }
-    }
-    return root;
-  }
-
   private static String getConstantArgumentForNode(
       StringBuilderNode node, MunchingState munchingState) {
     if (node.isAppendNode()) {
@@ -249,7 +243,7 @@
         // Remove appends if the string builder do not escape, is not inspected or materialized
         // and if it is not part of a loop.
         removeNode = false;
-        if (currentNode.isSplitReferenceNode()) {
+        if (currentNode.isSplitReferenceNode() && !munchingState.isLooping(root)) {
           removeNode = currentNode.getSuccessors().isEmpty() || currentNode.hasSinglePredecessor();
         } else if (currentNode.isAppendNode() && !isEscaping) {
           AppendNode appendNode = currentNode.asAppendNode();
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/string/StringBuilderTests.java b/src/test/java/com/android/tools/r8/ir/optimize/string/StringBuilderTests.java
index e7c5301..86fcb86 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/string/StringBuilderTests.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/string/StringBuilderTests.java
@@ -217,17 +217,21 @@
         .apply(parameters)
         .inspect(
             inspector -> {
-              if (parameters.isCfRuntime()) {
-                // TODO(b/114002137): for now, string concatenation depends on rewriteMoveResult.
-                return;
-              }
               MethodSubject method = inspector.method(stringBuilderTest.method);
               assertThat(method, isPresent());
               FoundMethodSubject foundMethodSubject = method.asFoundMethodSubject();
               assertEquals(
                   stringBuilderTest.stringBuilders, countStringBuilderInits(foundMethodSubject));
-              assertEquals(
-                  stringBuilderTest.appends, countStringBuilderAppends(foundMethodSubject));
+              if (parameters.isCfRuntime()
+                  && (stringBuilderTest.getMethodName().equals("diamondWithUseTest")
+                      || stringBuilderTest.getMethodName().equals("intoPhiTest"))) {
+                // We are not doing block suffix optimization in CF.
+                assertEquals(
+                    stringBuilderTest.appends + 1, countStringBuilderAppends(foundMethodSubject));
+              } else {
+                assertEquals(
+                    stringBuilderTest.appends, countStringBuilderAppends(foundMethodSubject));
+              }
               assertEquals(
                   stringBuilderTest.toStrings, countStringBuilderToStrings(foundMethodSubject));
             })
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/string/StringBuilderWithConstantsBeforeAndInOpenEndedIfInLoopTest.java b/src/test/java/com/android/tools/r8/ir/optimize/string/StringBuilderWithConstantsBeforeAndInOpenEndedIfInLoopTest.java
index 2fdfaa6..67d3d66 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/string/StringBuilderWithConstantsBeforeAndInOpenEndedIfInLoopTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/string/StringBuilderWithConstantsBeforeAndInOpenEndedIfInLoopTest.java
@@ -37,8 +37,7 @@
                 options.itemFactory.libraryMethodsReturningReceiver = Sets.newIdentityHashSet())
         .setMinApi(parameters.getApiLevel())
         .run(parameters.getRuntime(), TestClass.class, "1", "3")
-        // TODO(b/237674850): Should be 0.1.2.3.
-        .assertSuccessWithOutputLines("0.1.2.3.2.");
+        .assertSuccessWithOutputLines("0.1.2.3.");
   }
 
   static class TestClass {