Adding return state to DFS worklist helpers
Change-Id: I9d20cfa71fde1ad54c6aa891d66afbeabe683902
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/SimpleDominatingEffectAnalysis.java b/src/main/java/com/android/tools/r8/ir/optimize/SimpleDominatingEffectAnalysis.java
index 75e94fb..9027ee1 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/SimpleDominatingEffectAnalysis.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/SimpleDominatingEffectAnalysis.java
@@ -219,16 +219,20 @@
result = ResultState.NOT_SATISFIED;
}
- public void addSatisfyingInstruction(Instruction instruction) {
+ public SimpleEffectAnalysisResultBuilder addSatisfyingInstruction(Instruction instruction) {
satisfyingInstructions.add(instruction);
+ return this;
}
- public void setFailingBlocksForPartialResults(List<BasicBlock> basicBlocks) {
+ public SimpleEffectAnalysisResultBuilder setFailingBlocksForPartialResults(
+ List<BasicBlock> basicBlocks) {
this.failingBlocksForPartialResults = basicBlocks;
+ return this;
}
- public void setResult(ResultState result) {
+ public SimpleEffectAnalysisResultBuilder setResult(ResultState result) {
this.result = result;
+ return this;
}
public SimpleEffectAnalysisResult build() {
@@ -246,63 +250,69 @@
public static SimpleEffectAnalysisResult run(IRCode code, InstructionAnalysis analysis) {
SimpleEffectAnalysisResultBuilder builder = SimpleEffectAnalysisResult.builder();
IntBox visitedInstructions = new IntBox();
- new StatefulDepthFirstSearchWorkList<BasicBlock, ResultStateWithPartialBlocks>() {
+ TraversalContinuation<Void, ResultStateWithPartialBlocks> runResult =
+ new StatefulDepthFirstSearchWorkList<BasicBlock, ResultStateWithPartialBlocks, Void>() {
- @Override
- protected TraversalContinuation<?, ?> process(
- DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks> node,
- Function<BasicBlock, DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks>>
- childNodeConsumer) {
- InstructionEffect effect = NO_EFFECT;
- for (Instruction instruction : node.getNode().getInstructions()) {
- if (visitedInstructions.getAndIncrement() > analysis.maxNumberOfInstructions()) {
- builder.fail();
- return doBreak();
- }
- effect = analysis.analyze(instruction);
- if (!effect.isNoEffect()) {
- if (effect.isDesired()) {
- builder.addSatisfyingInstruction(instruction);
+ @Override
+ protected TraversalContinuation<Void, ResultStateWithPartialBlocks> process(
+ DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks> node,
+ Function<BasicBlock, DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks>>
+ childNodeConsumer) {
+ InstructionEffect effect = NO_EFFECT;
+ for (Instruction instruction : node.getNode().getInstructions()) {
+ if (visitedInstructions.getAndIncrement() > analysis.maxNumberOfInstructions()) {
+ return doBreak();
+ }
+ effect = analysis.analyze(instruction);
+ if (!effect.isNoEffect()) {
+ if (effect.isDesired()) {
+ builder.addSatisfyingInstruction(instruction);
+ }
+ break;
+ }
}
- break;
- }
- }
- if (effect.isNoEffect()) {
- List<BasicBlock> successors = analysis.getSuccessors(node.getNode());
- for (BasicBlock successor : successors) {
- DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks> childNode =
- childNodeConsumer.apply(successor);
- if (childNode.hasState()) {
- // If we see a block where the children have not been processed we cannot guarantee
- // all paths having the effect since - ex. we could have a non-terminating loop.
- builder.fail();
- return doBreak();
+ if (effect.isNoEffect()) {
+ List<BasicBlock> successors = analysis.getSuccessors(node.getNode());
+ for (BasicBlock successor : successors) {
+ DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks> childNode =
+ childNodeConsumer.apply(successor);
+ if (childNode.hasState()) {
+ // If we see a block where the children have not been processed we cannot
+ // guarantee all paths having the effect since - ex. we could have a
+ // non-terminating loop.
+ return doBreak();
+ }
+ }
}
+ node.setState(
+ new ResultStateWithPartialBlocks(effect.toResultState(), ImmutableList.of()));
+ return doContinue();
}
- }
- node.setState(new ResultStateWithPartialBlocks(effect.toResultState(), ImmutableList.of()));
- return doContinue();
- }
- @Override
- protected TraversalContinuation<?, ?> joiner(
- DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks> node,
- List<DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks>> childNodes) {
- ResultStateWithPartialBlocks resultState = node.getState();
- if (resultState.state.isNotComputed()) {
- resultState = resultState.joinChildren(childNodes);
- } else {
- assert resultState.state.isSatisfied() || resultState.state.isNotSatisfied();
- assert childNodes.isEmpty();
- }
- node.setState(resultState);
- if (node.getNode().isEntry()) {
- builder.setResult(resultState.state);
- builder.setFailingBlocksForPartialResults(resultState.failingBlocks);
- }
- return doContinue();
- }
- }.run(code.entryBlock());
+ @Override
+ protected TraversalContinuation<Void, ResultStateWithPartialBlocks> joiner(
+ DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks> node,
+ List<DFSNodeWithState<BasicBlock, ResultStateWithPartialBlocks>> childNodes) {
+ ResultStateWithPartialBlocks resultState = node.getState();
+ if (resultState.state.isNotComputed()) {
+ resultState = resultState.joinChildren(childNodes);
+ } else {
+ assert resultState.state.isSatisfied() || resultState.state.isNotSatisfied();
+ assert childNodes.isEmpty();
+ }
+ node.setState(resultState);
+ return doContinue(resultState);
+ }
+ }.run(code.entryBlock());
+
+ if (runResult.isBreak()) {
+ builder.fail();
+ } else {
+ ResultStateWithPartialBlocks resultState = runResult.asContinue().getValue();
+ builder
+ .setResult(resultState.state)
+ .setFailingBlocksForPartialResults(resultState.failingBlocks);
+ }
return builder.build();
}
diff --git a/src/main/java/com/android/tools/r8/utils/DepthFirstSearchWorkListBase.java b/src/main/java/com/android/tools/r8/utils/DepthFirstSearchWorkListBase.java
index c9a56da..c16b39f 100644
--- a/src/main/java/com/android/tools/r8/utils/DepthFirstSearchWorkListBase.java
+++ b/src/main/java/com/android/tools/r8/utils/DepthFirstSearchWorkListBase.java
@@ -13,12 +13,13 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
+import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
-public abstract class DepthFirstSearchWorkListBase<N, T extends DFSNodeImpl<N>> {
+public abstract class DepthFirstSearchWorkListBase<N, T extends DFSNodeImpl<N>, TB, TC> {
public interface DFSNode<N> {
N getNode();
@@ -103,36 +104,48 @@
}
private final ArrayDeque<T> workList = new ArrayDeque<>();
+ // This map is necessary ensure termination since we embed nodes into nodes with state.
+ private final Map<N, T> nodeToNodeWithStateMap = new IdentityHashMap<>();
abstract T createDfsNode(N node);
/** The initial processing of a node during forward search */
- abstract TraversalContinuation<?, ?> internalOnVisit(T node);
+ abstract TraversalContinuation<TB, TC> internalOnVisit(T node);
/** The joining of state during backtracking of the algorithm. */
- abstract TraversalContinuation<?, ?> internalOnJoin(T node);
+ abstract TraversalContinuation<TB, TC> internalOnJoin(T node);
+
+ abstract List<TC> getFinalStateForRoots(Collection<N> roots);
final T internalEnqueueNode(N value) {
- T dfsNode = createDfsNode(value);
+ T dfsNode = nodeToNodeWithStateMap.computeIfAbsent(value, this::createDfsNode);
if (dfsNode.isNotProcessed()) {
workList.addLast(dfsNode);
}
return dfsNode;
}
+ protected T getNodeStateForNode(N value) {
+ return nodeToNodeWithStateMap.get(value);
+ }
+
+ public final TraversalContinuation<TB, TC> run(N root) {
+ return run(Collections.singletonList(root)).map(Function.identity(), results -> results.get(0));
+ }
+
@SafeVarargs
- public final TraversalContinuation<?, ?> run(N... roots) {
+ public final TraversalContinuation<TB, List<TC>> run(N... roots) {
return run(Arrays.asList(roots));
}
- public final TraversalContinuation<?, ?> run(Collection<N> roots) {
+ public final TraversalContinuation<TB, List<TC>> run(Collection<N> roots) {
roots.forEach(this::internalEnqueueNode);
- TraversalContinuation<?, ?> continuation = TraversalContinuation.doContinue();
while (!workList.isEmpty()) {
T node = workList.removeLast();
if (node.isFinished()) {
continue;
}
+ TraversalContinuation<TB, TC> continuation;
if (node.isNotProcessed()) {
workList.addLast(node);
node.setWaiting();
@@ -143,14 +156,15 @@
node.setFinished();
}
if (continuation.shouldBreak()) {
- return continuation;
+ return TraversalContinuation.doBreak(continuation.asBreak().getValue());
}
}
- return continuation;
+
+ return TraversalContinuation.doContinue(getFinalStateForRoots(roots));
}
- public abstract static class DepthFirstSearchWorkList<N>
- extends DepthFirstSearchWorkListBase<N, DFSNodeImpl<N>> {
+ public abstract static class DepthFirstSearchWorkList<N, TB, TC>
+ extends DepthFirstSearchWorkListBase<N, DFSNodeImpl<N>, TB, TC> {
/**
* The initial processing of the node when visiting the first time during the depth first
@@ -161,7 +175,7 @@
* before but not finished there is a cycle.
* @return A value describing if the DFS algorithm should continue to run.
*/
- protected abstract TraversalContinuation<?, ?> process(
+ protected abstract TraversalContinuation<TB, TC> process(
DFSNode<N> node, Function<N, DFSNode<N>> childNodeConsumer);
@Override
@@ -170,18 +184,18 @@
}
@Override
- TraversalContinuation<?, ?> internalOnVisit(DFSNodeImpl<N> node) {
+ TraversalContinuation<TB, TC> internalOnVisit(DFSNodeImpl<N> node) {
return process(node, this::internalEnqueueNode);
}
@Override
- protected TraversalContinuation<?, ?> internalOnJoin(DFSNodeImpl<N> node) {
+ protected TraversalContinuation<TB, TC> internalOnJoin(DFSNodeImpl<N> node) {
return TraversalContinuation.doContinue();
}
}
- public abstract static class StatefulDepthFirstSearchWorkList<N, S>
- extends DepthFirstSearchWorkListBase<N, DFSNodeWithStateImpl<N, S>> {
+ public abstract static class StatefulDepthFirstSearchWorkList<N, S, TB>
+ extends DepthFirstSearchWorkListBase<N, DFSNodeWithStateImpl<N, S>, TB, S> {
private final Map<DFSNodeWithStateImpl<N, S>, List<DFSNodeWithState<N, S>>> childStateMap =
new IdentityHashMap<>();
@@ -195,7 +209,7 @@
* before but not finished there is a cycle.
* @return A value describing if the DFS algorithm should continue to run.
*/
- protected abstract TraversalContinuation<?, ?> process(
+ protected abstract TraversalContinuation<TB, S> process(
DFSNodeWithState<N, S> node, Function<N, DFSNodeWithState<N, S>> childNodeConsumer);
/**
@@ -205,7 +219,7 @@
* @param childStates The already computed child states.
* @return A value describing if the DFS algorithm should continue to run.
*/
- protected abstract TraversalContinuation<?, ?> joiner(
+ protected abstract TraversalContinuation<TB, S> joiner(
DFSNodeWithState<N, S> node, List<DFSNodeWithState<N, S>> childStates);
@Override
@@ -214,7 +228,7 @@
}
@Override
- TraversalContinuation<?, ?> internalOnVisit(DFSNodeWithStateImpl<N, S> node) {
+ TraversalContinuation<TB, S> internalOnVisit(DFSNodeWithStateImpl<N, S> node) {
List<DFSNodeWithState<N, S>> childStates = new ArrayList<>();
List<DFSNodeWithState<N, S>> removedChildStates = childStateMap.put(node, childStates);
assert removedChildStates == null;
@@ -228,7 +242,7 @@
}
@Override
- protected TraversalContinuation<?, ?> internalOnJoin(DFSNodeWithStateImpl<N, S> node) {
+ protected TraversalContinuation<TB, S> internalOnJoin(DFSNodeWithStateImpl<N, S> node) {
return joiner(
node,
childStateMap.computeIfAbsent(
@@ -238,5 +252,10 @@
return new ArrayList<>();
}));
}
+
+ @Override
+ List<S> getFinalStateForRoots(Collection<N> roots) {
+ return ListUtils.map(roots, root -> getNodeStateForNode(root).state);
+ }
}
}
diff --git a/src/main/java/com/android/tools/r8/utils/TraversalContinuation.java b/src/main/java/com/android/tools/r8/utils/TraversalContinuation.java
index f89727b..c33d020 100644
--- a/src/main/java/com/android/tools/r8/utils/TraversalContinuation.java
+++ b/src/main/java/com/android/tools/r8/utils/TraversalContinuation.java
@@ -26,6 +26,16 @@
return null;
}
+ public <TBx, TCx> TraversalContinuation<TBx, TCx> map(
+ Function<TB, TBx> mapBreak, Function<TC, TCx> mapContinue) {
+ if (isBreak()) {
+ return new Break<>(mapBreak.apply(asBreak().getValue()));
+ } else {
+ assert isContinue();
+ return new Continue<>(mapContinue.apply(asContinue().getValue()));
+ }
+ }
+
public static class Continue<TB, TC> extends TraversalContinuation<TB, TC> {
private static final TraversalContinuation.Continue<?, ?> CONTINUE_NO_VALUE =
new Continue<Object, Object>(null) {