Adding return state to DFS worklist helpers

Change-Id: I9d20cfa71fde1ad54c6aa891d66afbeabe683902
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/SimpleDominatingEffectAnalysis.java b/src/main/java/com/android/tools/r8/ir/optimize/SimpleDominatingEffectAnalysis.java
index 75e94fb..9027ee1 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/SimpleDominatingEffectAnalysis.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/SimpleDominatingEffectAnalysis.java
@@ -219,16 +219,20 @@
       result = ResultState.NOT_SATISFIED;
     }
 
-    public void addSatisfyingInstruction(Instruction instruction) {
+    public SimpleEffectAnalysisResultBuilder addSatisfyingInstruction(Instruction instruction) {
       satisfyingInstructions.add(instruction);
+      return this;
     }
 
-    public void setFailingBlocksForPartialResults(List<BasicBlock> basicBlocks) {
+    public SimpleEffectAnalysisResultBuilder setFailingBlocksForPartialResults(
+        List<BasicBlock> basicBlocks) {
       this.failingBlocksForPartialResults = basicBlocks;
+      return this;
     }
 
-    public void setResult(ResultState result) {
+    public SimpleEffectAnalysisResultBuilder setResult(ResultState result) {
       this.result = result;
+      return this;
     }
 
     public SimpleEffectAnalysisResult build() {
@@ -246,63 +250,69 @@
   public static SimpleEffectAnalysisResult run(IRCode code, InstructionAnalysis analysis) {
     SimpleEffectAnalysisResultBuilder builder = SimpleEffectAnalysisResult.builder();
     IntBox visitedInstructions = new IntBox();
-    new StatefulDepthFirstSearchWorkList<BasicBlock, ResultStateWithPartialBlocks>() {
+    TraversalContinuation<Void, ResultStateWithPartialBlocks> runResult =
+        new StatefulDepthFirstSearchWorkList<BasicBlock, ResultStateWithPartialBlocks, Void>() {
 
-      @Override
-      protected TraversalContinuation<?, ?> process(
-          DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks> node,
-          Function<BasicBlock, DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks>>
-              childNodeConsumer) {
-        InstructionEffect effect = NO_EFFECT;
-        for (Instruction instruction : node.getNode().getInstructions()) {
-          if (visitedInstructions.getAndIncrement() > analysis.maxNumberOfInstructions()) {
-            builder.fail();
-            return doBreak();
-          }
-          effect = analysis.analyze(instruction);
-          if (!effect.isNoEffect()) {
-            if (effect.isDesired()) {
-              builder.addSatisfyingInstruction(instruction);
+          @Override
+          protected TraversalContinuation<Void, ResultStateWithPartialBlocks> process(
+              DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks> node,
+              Function<BasicBlock, DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks>>
+                  childNodeConsumer) {
+            InstructionEffect effect = NO_EFFECT;
+            for (Instruction instruction : node.getNode().getInstructions()) {
+              if (visitedInstructions.getAndIncrement() > analysis.maxNumberOfInstructions()) {
+                return doBreak();
+              }
+              effect = analysis.analyze(instruction);
+              if (!effect.isNoEffect()) {
+                if (effect.isDesired()) {
+                  builder.addSatisfyingInstruction(instruction);
+                }
+                break;
+              }
             }
-            break;
-          }
-        }
-        if (effect.isNoEffect()) {
-          List<BasicBlock> successors = analysis.getSuccessors(node.getNode());
-          for (BasicBlock successor : successors) {
-            DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks> childNode =
-                childNodeConsumer.apply(successor);
-            if (childNode.hasState()) {
-              // If we see a block where the children have not been processed we cannot guarantee
-              // all paths having the effect since - ex. we could have a non-terminating loop.
-              builder.fail();
-              return doBreak();
+            if (effect.isNoEffect()) {
+              List<BasicBlock> successors = analysis.getSuccessors(node.getNode());
+              for (BasicBlock successor : successors) {
+                DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks> childNode =
+                    childNodeConsumer.apply(successor);
+                if (childNode.hasState()) {
+                  // If we see a block where the children have not been processed we cannot
+                  // guarantee all paths having the effect since - ex. we could have a
+                  // non-terminating loop.
+                  return doBreak();
+                }
+              }
             }
+            node.setState(
+                new ResultStateWithPartialBlocks(effect.toResultState(), ImmutableList.of()));
+            return doContinue();
           }
-        }
-        node.setState(new ResultStateWithPartialBlocks(effect.toResultState(), ImmutableList.of()));
-        return doContinue();
-      }
 
-      @Override
-      protected TraversalContinuation<?, ?> joiner(
-          DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks> node,
-          List<DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks>> childNodes) {
-        ResultStateWithPartialBlocks resultState = node.getState();
-        if (resultState.state.isNotComputed()) {
-          resultState = resultState.joinChildren(childNodes);
-        } else {
-          assert resultState.state.isSatisfied() || resultState.state.isNotSatisfied();
-          assert childNodes.isEmpty();
-        }
-        node.setState(resultState);
-        if (node.getNode().isEntry()) {
-          builder.setResult(resultState.state);
-          builder.setFailingBlocksForPartialResults(resultState.failingBlocks);
-        }
-        return doContinue();
-      }
-    }.run(code.entryBlock());
+          @Override
+          protected TraversalContinuation<Void, ResultStateWithPartialBlocks> joiner(
+              DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks> node,
+              List<DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks>> childNodes) {
+            ResultStateWithPartialBlocks resultState = node.getState();
+            if (resultState.state.isNotComputed()) {
+              resultState = resultState.joinChildren(childNodes);
+            } else {
+              assert resultState.state.isSatisfied() || resultState.state.isNotSatisfied();
+              assert childNodes.isEmpty();
+            }
+            node.setState(resultState);
+            return doContinue(resultState);
+          }
+        }.run(code.entryBlock());
+
+    if (runResult.isBreak()) {
+      builder.fail();
+    } else {
+      ResultStateWithPartialBlocks resultState = runResult.asContinue().getValue();
+      builder
+          .setResult(resultState.state)
+          .setFailingBlocksForPartialResults(resultState.failingBlocks);
+    }
 
     return builder.build();
   }
diff --git a/src/main/java/com/android/tools/r8/utils/DepthFirstSearchWorkListBase.java b/src/main/java/com/android/tools/r8/utils/DepthFirstSearchWorkListBase.java
index c9a56da..c16b39f 100644
--- a/src/main/java/com/android/tools/r8/utils/DepthFirstSearchWorkListBase.java
+++ b/src/main/java/com/android/tools/r8/utils/DepthFirstSearchWorkListBase.java
@@ -13,12 +13,13 @@
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.IdentityHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.function.Function;
 
-public abstract class DepthFirstSearchWorkListBase<N, T extends DFSNodeImpl<N>> {
+public abstract class DepthFirstSearchWorkListBase<N, T extends DFSNodeImpl<N>, TB, TC> {
 
   public interface DFSNode<N> {
     N getNode();
@@ -103,36 +104,48 @@
   }
 
   private final ArrayDeque<T> workList = new ArrayDeque<>();
+  // This map is necessary ensure termination since we embed nodes into nodes with state.
+  private final Map<N, T> nodeToNodeWithStateMap = new IdentityHashMap<>();
 
   abstract T createDfsNode(N node);
 
   /** The initial processing of a node during forward search */
-  abstract TraversalContinuation<?, ?> internalOnVisit(T node);
+  abstract TraversalContinuation<TB, TC> internalOnVisit(T node);
 
   /** The joining of state during backtracking of the algorithm. */
-  abstract TraversalContinuation<?, ?> internalOnJoin(T node);
+  abstract TraversalContinuation<TB, TC> internalOnJoin(T node);
+
+  abstract List<TC> getFinalStateForRoots(Collection<N> roots);
 
   final T internalEnqueueNode(N value) {
-    T dfsNode = createDfsNode(value);
+    T dfsNode = nodeToNodeWithStateMap.computeIfAbsent(value, this::createDfsNode);
     if (dfsNode.isNotProcessed()) {
       workList.addLast(dfsNode);
     }
     return dfsNode;
   }
 
+  protected T getNodeStateForNode(N value) {
+    return nodeToNodeWithStateMap.get(value);
+  }
+
+  public final TraversalContinuation<TB, TC> run(N root) {
+    return run(Collections.singletonList(root)).map(Function.identity(), results -> results.get(0));
+  }
+
   @SafeVarargs
-  public final TraversalContinuation<?, ?> run(N... roots) {
+  public final TraversalContinuation<TB, List<TC>> run(N... roots) {
     return run(Arrays.asList(roots));
   }
 
-  public final TraversalContinuation<?, ?> run(Collection<N> roots) {
+  public final TraversalContinuation<TB, List<TC>> run(Collection<N> roots) {
     roots.forEach(this::internalEnqueueNode);
-    TraversalContinuation<?, ?> continuation = TraversalContinuation.doContinue();
     while (!workList.isEmpty()) {
       T node = workList.removeLast();
       if (node.isFinished()) {
         continue;
       }
+      TraversalContinuation<TB, TC> continuation;
       if (node.isNotProcessed()) {
         workList.addLast(node);
         node.setWaiting();
@@ -143,14 +156,15 @@
         node.setFinished();
       }
       if (continuation.shouldBreak()) {
-        return continuation;
+        return TraversalContinuation.doBreak(continuation.asBreak().getValue());
       }
     }
-    return continuation;
+
+    return TraversalContinuation.doContinue(getFinalStateForRoots(roots));
   }
 
-  public abstract static class DepthFirstSearchWorkList<N>
-      extends DepthFirstSearchWorkListBase<N, DFSNodeImpl<N>> {
+  public abstract static class DepthFirstSearchWorkList<N, TB, TC>
+      extends DepthFirstSearchWorkListBase<N, DFSNodeImpl<N>, TB, TC> {
 
     /**
      * The initial processing of the node when visiting the first time during the depth first
@@ -161,7 +175,7 @@
      *     before but not finished there is a cycle.
      * @return A value describing if the DFS algorithm should continue to run.
      */
-    protected abstract TraversalContinuation<?, ?> process(
+    protected abstract TraversalContinuation<TB, TC> process(
         DFSNode<N> node, Function<N, DFSNode<N>> childNodeConsumer);
 
     @Override
@@ -170,18 +184,18 @@
     }
 
     @Override
-    TraversalContinuation<?, ?> internalOnVisit(DFSNodeImpl<N> node) {
+    TraversalContinuation<TB, TC> internalOnVisit(DFSNodeImpl<N> node) {
       return process(node, this::internalEnqueueNode);
     }
 
     @Override
-    protected TraversalContinuation<?, ?> internalOnJoin(DFSNodeImpl<N> node) {
+    protected TraversalContinuation<TB, TC> internalOnJoin(DFSNodeImpl<N> node) {
       return TraversalContinuation.doContinue();
     }
   }
 
-  public abstract static class StatefulDepthFirstSearchWorkList<N, S>
-      extends DepthFirstSearchWorkListBase<N, DFSNodeWithStateImpl<N, S>> {
+  public abstract static class StatefulDepthFirstSearchWorkList<N, S, TB>
+      extends DepthFirstSearchWorkListBase<N, DFSNodeWithStateImpl<N, S>, TB, S> {
 
     private final Map<DFSNodeWithStateImpl<N, S>, List<DFSNodeWithState<N, S>>> childStateMap =
         new IdentityHashMap<>();
@@ -195,7 +209,7 @@
      *     before but not finished there is a cycle.
      * @return A value describing if the DFS algorithm should continue to run.
      */
-    protected abstract TraversalContinuation<?, ?> process(
+    protected abstract TraversalContinuation<TB, S> process(
         DFSNodeWithState<N, S> node, Function<N, DFSNodeWithState<N, S>> childNodeConsumer);
 
     /**
@@ -205,7 +219,7 @@
      * @param childStates The already computed child states.
      * @return A value describing if the DFS algorithm should continue to run.
      */
-    protected abstract TraversalContinuation<?, ?> joiner(
+    protected abstract TraversalContinuation<TB, S> joiner(
         DFSNodeWithState<N, S> node, List<DFSNodeWithState<N, S>> childStates);
 
     @Override
@@ -214,7 +228,7 @@
     }
 
     @Override
-    TraversalContinuation<?, ?> internalOnVisit(DFSNodeWithStateImpl<N, S> node) {
+    TraversalContinuation<TB, S> internalOnVisit(DFSNodeWithStateImpl<N, S> node) {
       List<DFSNodeWithState<N, S>> childStates = new ArrayList<>();
       List<DFSNodeWithState<N, S>> removedChildStates = childStateMap.put(node, childStates);
       assert removedChildStates == null;
@@ -228,7 +242,7 @@
     }
 
     @Override
-    protected TraversalContinuation<?, ?> internalOnJoin(DFSNodeWithStateImpl<N, S> node) {
+    protected TraversalContinuation<TB, S> internalOnJoin(DFSNodeWithStateImpl<N, S> node) {
       return joiner(
           node,
           childStateMap.computeIfAbsent(
@@ -238,5 +252,10 @@
                 return new ArrayList<>();
               }));
     }
+
+    @Override
+    List<S> getFinalStateForRoots(Collection<N> roots) {
+      return ListUtils.map(roots, root -> getNodeStateForNode(root).state);
+    }
   }
 }
diff --git a/src/main/java/com/android/tools/r8/utils/TraversalContinuation.java b/src/main/java/com/android/tools/r8/utils/TraversalContinuation.java
index f89727b..c33d020 100644
--- a/src/main/java/com/android/tools/r8/utils/TraversalContinuation.java
+++ b/src/main/java/com/android/tools/r8/utils/TraversalContinuation.java
@@ -26,6 +26,16 @@
     return null;
   }
 
+  public <TBx, TCx> TraversalContinuation<TBx, TCx> map(
+      Function<TB, TBx> mapBreak, Function<TC, TCx> mapContinue) {
+    if (isBreak()) {
+      return new Break<>(mapBreak.apply(asBreak().getValue()));
+    } else {
+      assert isContinue();
+      return new Continue<>(mapContinue.apply(asContinue().getValue()));
+    }
+  }
+
   public static class Continue<TB, TC> extends TraversalContinuation<TB, TC> {
     private static final TraversalContinuation.Continue<?, ?> CONTINUE_NO_VALUE =
         new Continue<Object, Object>(null) {