Add utility for DepthFirstSearch with custom state

Change-Id: I228820e3200be916c876f2a224384e7ab9215cb3
diff --git a/src/main/java/com/android/tools/r8/utils/DepthFirstSearchWorkListBase.java b/src/main/java/com/android/tools/r8/utils/DepthFirstSearchWorkListBase.java
new file mode 100644
index 0000000..6cbd295
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/utils/DepthFirstSearchWorkListBase.java
@@ -0,0 +1,234 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.utils;
+
+import static com.android.tools.r8.utils.DepthFirstSearchWorkListBase.ProcessingState.FINISHED;
+import static com.android.tools.r8.utils.DepthFirstSearchWorkListBase.ProcessingState.NOT_PROCESSED;
+import static com.android.tools.r8.utils.DepthFirstSearchWorkListBase.ProcessingState.WAITING;
+
+import com.android.tools.r8.errors.Unreachable;
+import com.android.tools.r8.utils.DepthFirstSearchWorkListBase.DFSNode;
+import com.android.tools.r8.utils.DepthFirstSearchWorkListBase.DFSNodeImpl;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.IdentityHashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Consumer;
+import java.util.function.Function;
+
+public abstract class DepthFirstSearchWorkListBase<
+    N, TExp extends DFSNode<N>, TImpl extends DFSNodeImpl<N>> {
+
+  public interface DFSNode<N> {
+    N getNode();
+
+    boolean seenAndNotProcessed();
+  }
+
+  public interface DFSNodeWithState<N, S> extends DFSNode<N> {
+
+    S getState();
+
+    void setState(S backtrackState);
+  }
+
+  enum ProcessingState {
+    NOT_PROCESSED,
+    WAITING,
+    FINISHED;
+  }
+
+  static class DFSNodeImpl<N> implements DFSNode<N> {
+
+    private final N node;
+    private ProcessingState processingState = NOT_PROCESSED;
+
+    private DFSNodeImpl(N node) {
+      this.node = node;
+    }
+
+    boolean isNotProcessed() {
+      return processingState == NOT_PROCESSED;
+    }
+
+    boolean isFinished() {
+      return processingState == FINISHED;
+    }
+
+    void setWaiting() {
+      processingState = WAITING;
+    }
+
+    void setFinished() {
+      assert processingState != FINISHED;
+      processingState = FINISHED;
+    }
+
+    @Override
+    public N getNode() {
+      return node;
+    }
+
+    @Override
+    public boolean seenAndNotProcessed() {
+      return processingState == WAITING;
+    }
+  }
+
+  static class DFSNodeWithStateImpl<N, S> extends DFSNodeImpl<N> implements DFSNodeWithState<N, S> {
+
+    private S state;
+
+    private DFSNodeWithStateImpl(N node) {
+      super(node);
+    }
+
+    @Override
+    public S getState() {
+      return state;
+    }
+
+    @Override
+    public void setState(S state) {
+      this.state = state;
+    }
+  }
+
+  private final ArrayDeque<TImpl> workList = new ArrayDeque<>();
+  private final Map<N, TImpl> stateMap = new IdentityHashMap<>();
+  private final Map<TImpl, List<TExp>> childStateMap = new IdentityHashMap<>();
+
+  abstract TImpl newNode(N node);
+
+  abstract boolean isStateful();
+
+  /**
+   * The initial processing of the node when visiting the first time during the depth first search.
+   *
+   * @param node The current node.
+   * @param childNodeConsumer A consumer for adding child nodes. If an element has been seen before
+   *     but not finished there is a cycle.
+   * @return A value describing if the DFS algorithm should continue to run.
+   */
+  protected abstract TraversalContinuation process(TExp node, Function<N, TExp> childNodeConsumer);
+
+  /**
+   * The joining of state during backtracking of the algorithm.
+   *
+   * @param node The current node
+   * @param childStates The already computed child states.
+   * @return A value describing if the DFS algorithm should continue to run.
+   */
+  TraversalContinuation joiner(TExp node, List<TExp> childStates) {
+    throw new Unreachable("Should not be called");
+  }
+
+  @SafeVarargs
+  public final TraversalContinuation run(N... roots) {
+    return run(Arrays.asList(roots));
+  }
+
+  @SuppressWarnings("unchecked")
+  public final TraversalContinuation run(Collection<N> roots) {
+    for (N root : roots) {
+      TImpl newNode = newNode(root);
+      stateMap.put(root, newNode);
+      workList.addLast(newNode);
+    }
+    TraversalContinuation continuation = TraversalContinuation.CONTINUE;
+    while (!workList.isEmpty()) {
+      TImpl node = workList.removeLast();
+      if (node.isFinished()) {
+        continue;
+      }
+      TExp exposed = (TExp) node;
+      if (node.isNotProcessed()) {
+        workList.addLast(node);
+        List<TExp> childStates =
+            isStateful()
+                ? childStateMap.computeIfAbsent(node, FunctionUtils.ignoreArgument(ArrayList::new))
+                : null;
+        node.setWaiting();
+        continuation =
+            process(
+                exposed,
+                childNode -> {
+                  TImpl childImpl = stateMap.computeIfAbsent(childNode, this::newNode);
+                  if (childImpl.isNotProcessed()) {
+                    workList.addLast(childImpl);
+                  }
+                  TExp childExp = (TExp) childImpl;
+                  if (childStates != null) {
+                    childStates.add(childExp);
+                  }
+                  return (TExp) childImpl;
+                });
+      } else {
+        assert node.seenAndNotProcessed();
+        if (isStateful()) {
+          continuation = joiner((TExp) node, childStateMap.get(node));
+        }
+        node.setFinished();
+      }
+      if (continuation.shouldBreak()) {
+        return continuation;
+      }
+    }
+    assert continuation.shouldBreak() || stateMap.values().stream().allMatch(TImpl::isFinished);
+    return continuation;
+  }
+
+  @SuppressWarnings("unchecked")
+  public void unwindUntilInclusive(TExp exp, Consumer<TExp> nodesOnPathConsumer) {
+    assert !workList.isEmpty();
+    TImpl startOfLoop = (TImpl) exp;
+    Iterator<TImpl> descendingIterator = workList.descendingIterator();
+    while (descendingIterator.hasNext()) {
+      TImpl next = descendingIterator.next();
+      if (!next.seenAndNotProcessed()) {
+        nodesOnPathConsumer.accept((TExp) next);
+      }
+      if (next == startOfLoop) {
+        return;
+      }
+    }
+  }
+
+  public abstract static class DepthFirstSearchWorkList<N>
+      extends DepthFirstSearchWorkListBase<N, DFSNode<N>, DFSNodeImpl<N>> {
+
+    @Override
+    DFSNodeImpl<N> newNode(N node) {
+      return new DFSNodeImpl<>(node);
+    }
+
+    @Override
+    boolean isStateful() {
+      return false;
+    }
+  }
+
+  public abstract static class StatefulDepthFirstSearchWorkListBase<N, S>
+      extends DepthFirstSearchWorkListBase<N, DFSNodeWithState<N, S>, DFSNodeWithStateImpl<N, S>> {
+
+    @Override
+    DFSNodeWithStateImpl<N, S> newNode(N node) {
+      return new DFSNodeWithStateImpl<>(node);
+    }
+
+    @Override
+    boolean isStateful() {
+      return true;
+    }
+
+    @Override
+    public abstract TraversalContinuation joiner(
+        DFSNodeWithState<N, S> node, List<DFSNodeWithState<N, S>> childStates);
+  }
+}