Use traversal over predecessors/successors in dataflow analysis

This changes the dataflow analysis to use a traversal over predecessors/successors, instead of retrieving a collection of predecessors/successors and then iterating the given collection.

In extending the dataflow analysis to CfCode, this makes it possible to avoid needing to materialize the predecessor/successor collections.

Change-Id: Ide27c619d6a87a3f9a1f4651cd986255a2213259
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/framework/intraprocedural/ControlFlowGraph.java b/src/main/java/com/android/tools/r8/ir/analysis/framework/intraprocedural/ControlFlowGraph.java
index 5710f6e..a6150bc 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/framework/intraprocedural/ControlFlowGraph.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/framework/intraprocedural/ControlFlowGraph.java
@@ -5,38 +5,155 @@
 package com.android.tools.r8.ir.analysis.framework.intraprocedural;
 
 import com.android.tools.r8.utils.TraversalContinuation;
-import com.google.common.collect.Iterables;
-import java.util.Collection;
+import com.android.tools.r8.utils.TraversalUtils;
 import java.util.function.BiFunction;
+import java.util.function.Consumer;
+import java.util.function.Function;
 
 public interface ControlFlowGraph<Block, Instruction> {
 
-  Collection<Block> getPredecessors(Block block);
-
-  Collection<Block> getSuccessors(Block block);
-
   default boolean hasUniquePredecessor(Block block) {
-    return getPredecessors(block).size() == 1;
-  }
-
-  default Block getUniquePredecessor(Block block) {
-    assert hasUniquePredecessor(block);
-    return Iterables.getOnlyElement(getPredecessors(block));
+    return TraversalUtils.isSingleton(counter -> traversePredecessors(block, counter));
   }
 
   default boolean hasUniqueSuccessor(Block block) {
-    return getSuccessors(block).size() == 1;
+    return TraversalUtils.isSingleton(counter -> traverseSuccessors(block, counter));
   }
 
   default boolean hasUniqueSuccessorWithUniquePredecessor(Block block) {
-    return hasUniqueSuccessor(block) && getPredecessors(getUniqueSuccessor(block)).size() == 1;
+    return hasUniqueSuccessor(block) && hasUniquePredecessor(getUniqueSuccessor(block));
   }
 
   default Block getUniqueSuccessor(Block block) {
     assert hasUniqueSuccessor(block);
-    return Iterables.getOnlyElement(getSuccessors(block));
+    return TraversalUtils.getFirst(collector -> traverseSuccessors(block, collector));
   }
 
+  // Block traversal.
+
+  default <BT, CT> TraversalContinuation<BT, CT> traversePredecessors(
+      Block block, Function<? super Block, TraversalContinuation<BT, CT>> fn) {
+    return traversePredecessors(block, (predecessor, ignore) -> fn.apply(predecessor), null);
+  }
+
+  default <BT, CT> TraversalContinuation<BT, CT> traverseNormalPredecessors(
+      Block block, Function<? super Block, TraversalContinuation<BT, CT>> fn) {
+    return traverseNormalPredecessors(block, (predecessor, ignore) -> fn.apply(predecessor), null);
+  }
+
+  default <BT, CT> TraversalContinuation<BT, CT> traverseExceptionalPredecessors(
+      Block block, Function<? super Block, TraversalContinuation<BT, CT>> fn) {
+    return traverseExceptionalPredecessors(
+        block, (predecessor, ignore) -> fn.apply(predecessor), null);
+  }
+
+  default <BT, CT> TraversalContinuation<BT, CT> traverseSuccessors(
+      Block block, Function<? super Block, TraversalContinuation<BT, CT>> fn) {
+    return traverseSuccessors(block, (successor, ignore) -> fn.apply(successor), null);
+  }
+
+  default <BT, CT> TraversalContinuation<BT, CT> traverseNormalSuccessors(
+      Block block, Function<? super Block, TraversalContinuation<BT, CT>> fn) {
+    return traverseNormalSuccessors(block, (successor, ignore) -> fn.apply(successor), null);
+  }
+
+  default <BT, CT> TraversalContinuation<BT, CT> traverseExceptionalSuccessors(
+      Block block, Function<? super Block, TraversalContinuation<BT, CT>> fn) {
+    return traverseExceptionalSuccessors(block, (successor, ignore) -> fn.apply(successor), null);
+  }
+
+  // Block traversal with result.
+
+  default <BT, CT> TraversalContinuation<BT, CT> traversePredecessors(
+      Block block,
+      BiFunction<? super Block, ? super CT, TraversalContinuation<BT, CT>> fn,
+      CT initialValue) {
+    return traverseNormalPredecessors(block, fn, initialValue)
+        .ifContinueThen(
+            continuation ->
+                traverseExceptionalPredecessors(block, fn, continuation.getValueOrDefault(null)));
+  }
+
+  <BT, CT> TraversalContinuation<BT, CT> traverseNormalPredecessors(
+      Block block,
+      BiFunction<? super Block, ? super CT, TraversalContinuation<BT, CT>> fn,
+      CT initialValue);
+
+  <BT, CT> TraversalContinuation<BT, CT> traverseExceptionalPredecessors(
+      Block block,
+      BiFunction<? super Block, ? super CT, TraversalContinuation<BT, CT>> fn,
+      CT initialValue);
+
+  default <BT, CT> TraversalContinuation<BT, CT> traverseSuccessors(
+      Block block,
+      BiFunction<? super Block, ? super CT, TraversalContinuation<BT, CT>> fn,
+      CT initialValue) {
+    return traverseNormalSuccessors(block, fn, initialValue)
+        .ifContinueThen(
+            continuation ->
+                traverseExceptionalSuccessors(block, fn, continuation.getValueOrDefault(null)));
+  }
+
+  <BT, CT> TraversalContinuation<BT, CT> traverseNormalSuccessors(
+      Block block,
+      BiFunction<? super Block, ? super CT, TraversalContinuation<BT, CT>> fn,
+      CT initialValue);
+
+  <BT, CT> TraversalContinuation<BT, CT> traverseExceptionalSuccessors(
+      Block block,
+      BiFunction<? super Block, ? super CT, TraversalContinuation<BT, CT>> fn,
+      CT initialValue);
+
+  // Block iteration.
+
+  default void forEachPredecessor(Block block, Consumer<Block> consumer) {
+    forEachNormalPredecessor(block, consumer);
+    forEachExceptionalPredecessor(block, consumer);
+  }
+
+  default void forEachNormalPredecessor(Block block, Consumer<Block> consumer) {
+    traverseNormalPredecessors(
+        block,
+        predecessor -> {
+          consumer.accept(predecessor);
+          return TraversalContinuation.doContinue();
+        });
+  }
+
+  default void forEachExceptionalPredecessor(Block block, Consumer<Block> consumer) {
+    traverseExceptionalPredecessors(
+        block,
+        exceptionalPredecessor -> {
+          consumer.accept(exceptionalPredecessor);
+          return TraversalContinuation.doContinue();
+        });
+  }
+
+  default void forEachSuccessor(Block block, Consumer<Block> consumer) {
+    forEachNormalSuccessor(block, consumer);
+    forEachExceptionalSuccessor(block, consumer);
+  }
+
+  default void forEachNormalSuccessor(Block block, Consumer<Block> consumer) {
+    traverseNormalSuccessors(
+        block,
+        successor -> {
+          consumer.accept(successor);
+          return TraversalContinuation.doContinue();
+        });
+  }
+
+  default void forEachExceptionalSuccessor(Block block, Consumer<Block> consumer) {
+    traverseExceptionalSuccessors(
+        block,
+        exceptionalSuccessor -> {
+          consumer.accept(exceptionalSuccessor);
+          return TraversalContinuation.doContinue();
+        });
+  }
+
+  // Instruction traversal.
+
   <BT, CT> TraversalContinuation<BT, CT> traverseInstructions(
       Block block, BiFunction<Instruction, CT, TraversalContinuation<BT, CT>> fn, CT initialValue);
 }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/framework/intraprocedural/IRControlFlowGraph.java b/src/main/java/com/android/tools/r8/ir/analysis/framework/intraprocedural/IRControlFlowGraph.java
new file mode 100644
index 0000000..aae27f7
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/analysis/framework/intraprocedural/IRControlFlowGraph.java
@@ -0,0 +1,31 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.analysis.framework.intraprocedural;
+
+import com.android.tools.r8.ir.code.BasicBlock;
+import com.android.tools.r8.ir.code.Instruction;
+
+public interface IRControlFlowGraph extends ControlFlowGraph<BasicBlock, Instruction> {
+
+  @Override
+  default boolean hasUniquePredecessor(BasicBlock block) {
+    return block.hasUniquePredecessor();
+  }
+
+  @Override
+  default boolean hasUniqueSuccessor(BasicBlock block) {
+    return block.hasUniqueSuccessor();
+  }
+
+  @Override
+  default boolean hasUniqueSuccessorWithUniquePredecessor(BasicBlock block) {
+    return block.hasUniqueSuccessorWithUniquePredecessor();
+  }
+
+  @Override
+  default BasicBlock getUniqueSuccessor(BasicBlock block) {
+    return block.getUniqueSuccessor();
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/framework/intraprocedural/IntraProceduralDataflowAnalysisBase.java b/src/main/java/com/android/tools/r8/ir/analysis/framework/intraprocedural/IntraProceduralDataflowAnalysisBase.java
index 969197f..f904da3 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/framework/intraprocedural/IntraProceduralDataflowAnalysisBase.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/framework/intraprocedural/IntraProceduralDataflowAnalysisBase.java
@@ -8,6 +8,7 @@
 import com.android.tools.r8.ir.analysis.framework.intraprocedural.DataflowAnalysisResult.SuccessfulDataflowAnalysisResult;
 import com.android.tools.r8.utils.Timing;
 import com.android.tools.r8.utils.TraversalContinuation;
+import com.android.tools.r8.utils.TraversalUtils;
 import com.android.tools.r8.utils.WorkList;
 import java.util.IdentityHashMap;
 import java.util.Map;
@@ -100,7 +101,7 @@
       // Update the block exit state, and re-enqueue all successor blocks if the abstract state
       // changed.
       if (setBlockExitState(end, state)) {
-        worklist.addAllIgnoringSeenSet(cfg.getSuccessors(end));
+        cfg.forEachSuccessor(end, worklist::addIgnoringSeenSet);
       }
 
       // Add the computed exit state to the entry state of each successor that satisfies the
@@ -114,14 +115,17 @@
     if (shouldCacheBlockEntryStateFor(block)) {
       return blockEntryStatesCache.getOrDefault(block, bottom);
     }
-    StateType result = bottom;
-    for (Block predecessor : cfg.getPredecessors(block)) {
-      StateType edgeState =
-          transfer.computeBlockEntryState(
-              block, predecessor, blockExitStates.getOrDefault(predecessor, bottom));
-      result = result.join(edgeState);
-    }
-    return result;
+    TraversalContinuation<?, StateType> traversalContinuation =
+        cfg.traversePredecessors(
+            block,
+            (predecessor, entryState) -> {
+              StateType edgeState =
+                  transfer.computeBlockEntryState(
+                      block, predecessor, blockExitStates.getOrDefault(predecessor, bottom));
+              return TraversalContinuation.doContinue(entryState.join(edgeState));
+            },
+            bottom);
+    return traversalContinuation.asContinue().getValue();
   }
 
   boolean setBlockExitState(Block block, StateType state) {
@@ -132,16 +136,18 @@
   }
 
   void updateBlockEntryStateCacheForSuccessors(Block block, StateType state) {
-    for (Block successor : cfg.getSuccessors(block)) {
-      if (shouldCacheBlockEntryStateFor(successor)) {
-        StateType edgeState = transfer.computeBlockEntryState(successor, block, state);
-        StateType previous = blockEntryStatesCache.getOrDefault(successor, bottom);
-        blockEntryStatesCache.put(successor, previous.join(edgeState));
-      }
-    }
+    cfg.forEachSuccessor(
+        block,
+        successor -> {
+          if (shouldCacheBlockEntryStateFor(successor)) {
+            StateType edgeState = transfer.computeBlockEntryState(successor, block, state);
+            StateType previous = blockEntryStatesCache.getOrDefault(successor, bottom);
+            blockEntryStatesCache.put(successor, previous.join(edgeState));
+          }
+        });
   }
 
   boolean shouldCacheBlockEntryStateFor(Block block) {
-    return cfg.getPredecessors(block).size() > 2;
+    return TraversalUtils.isSizeGreaterThan(counter -> cfg.traversePredecessors(block, counter), 2);
   }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/framework/intraprocedural/IntraproceduralDataflowAnalysis.java b/src/main/java/com/android/tools/r8/ir/analysis/framework/intraprocedural/IntraproceduralDataflowAnalysis.java
index 13a9c42..789a2cb 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/framework/intraprocedural/IntraproceduralDataflowAnalysis.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/framework/intraprocedural/IntraproceduralDataflowAnalysis.java
@@ -17,4 +17,9 @@
       AbstractTransferFunction<BasicBlock, Instruction, StateType> transfer) {
     super(bottom, code, transfer);
   }
+
+  @Override
+  boolean shouldCacheBlockEntryStateFor(BasicBlock block) {
+    return block.getPredecessors().size() > 2;
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/code/BasicBlock.java b/src/main/java/com/android/tools/r8/ir/code/BasicBlock.java
index b664c13..0d1bf75 100644
--- a/src/main/java/com/android/tools/r8/ir/code/BasicBlock.java
+++ b/src/main/java/com/android/tools/r8/ir/code/BasicBlock.java
@@ -25,6 +25,7 @@
 import com.android.tools.r8.utils.ListUtils;
 import com.android.tools.r8.utils.StringUtils;
 import com.android.tools.r8.utils.StringUtils.BraceType;
+import com.android.tools.r8.utils.TraversalContinuation;
 import com.google.common.base.Equivalence;
 import com.google.common.base.Equivalence.Wrapper;
 import com.google.common.collect.ImmutableList;
@@ -48,6 +49,7 @@
 import java.util.NoSuchElementException;
 import java.util.Set;
 import java.util.WeakHashMap;
+import java.util.function.BiFunction;
 import java.util.function.Consumer;
 import java.util.function.Function;
 
@@ -119,7 +121,7 @@
 
   public enum ThrowingInfo {
     NO_THROW,
-    CAN_THROW;
+    CAN_THROW
   }
 
   public enum EdgeType {
@@ -189,6 +191,73 @@
   // Map of registers to current SSA value. Used during SSA numbering and cleared once filled.
   private Map<Integer, Value> currentDefinitions = new HashMap<>();
 
+  public <BT, CT> TraversalContinuation<BT, CT> traverseNormalPredecessors(
+      BiFunction<? super BasicBlock, ? super CT, TraversalContinuation<BT, CT>> fn,
+      CT initialValue) {
+    TraversalContinuation<BT, CT> traversalContinuation =
+        TraversalContinuation.doContinue(initialValue);
+    for (BasicBlock predecessor : getPredecessors()) {
+      if (predecessor.hasCatchSuccessor(this)) {
+        continue;
+      }
+      traversalContinuation =
+          fn.apply(predecessor, traversalContinuation.asContinue().getValueOrDefault(null));
+      if (traversalContinuation.isBreak()) {
+        break;
+      }
+    }
+    return traversalContinuation;
+  }
+
+  public <BT, CT> TraversalContinuation<BT, CT> traverseNormalSuccessors(
+      BiFunction<? super BasicBlock, ? super CT, TraversalContinuation<BT, CT>> fn,
+      CT initialValue) {
+    TraversalContinuation<BT, CT> traversalContinuation =
+        TraversalContinuation.doContinue(initialValue);
+    for (int i = successors.size() - numberOfNormalSuccessors(); i < successors.size(); i++) {
+      traversalContinuation =
+          fn.apply(successors.get(i), traversalContinuation.asContinue().getValueOrDefault(null));
+      if (traversalContinuation.isBreak()) {
+        break;
+      }
+    }
+    return traversalContinuation;
+  }
+
+  public <BT, CT> TraversalContinuation<BT, CT> traverseExceptionalPredecessors(
+      BiFunction<? super BasicBlock, ? super CT, TraversalContinuation<BT, CT>> fn,
+      CT initialValue) {
+    TraversalContinuation<BT, CT> traversalContinuation =
+        TraversalContinuation.doContinue(initialValue);
+    for (BasicBlock predecessor : getPredecessors()) {
+      if (!predecessor.hasCatchSuccessor(this)) {
+        continue;
+      }
+      traversalContinuation =
+          fn.apply(predecessor, traversalContinuation.asContinue().getValueOrDefault(null));
+      if (traversalContinuation.isBreak()) {
+        break;
+      }
+    }
+    return traversalContinuation;
+  }
+
+  public <BT, CT> TraversalContinuation<BT, CT> traverseExceptionalSuccessors(
+      BiFunction<? super BasicBlock, ? super CT, TraversalContinuation<BT, CT>> fn,
+      CT initialValue) {
+    int numberOfExceptionalSuccessors = numberOfExceptionalSuccessors();
+    TraversalContinuation<BT, CT> traversalContinuation =
+        TraversalContinuation.doContinue(initialValue);
+    for (int i = 0; i < numberOfExceptionalSuccessors; i++) {
+      traversalContinuation =
+          fn.apply(successors.get(i), traversalContinuation.asContinue().getValueOrDefault(null));
+      if (traversalContinuation.isBreak()) {
+        break;
+      }
+    }
+    return traversalContinuation;
+  }
+
   public void addControlFlowEdgesMayChangeListener(BasicBlockChangeListener listener) {
     if (onControlFlowEdgesMayChangeListeners == null) {
       // WeakSet to allow the listeners to be garbage collected.
diff --git a/src/main/java/com/android/tools/r8/ir/code/IRCode.java b/src/main/java/com/android/tools/r8/ir/code/IRCode.java
index 53fa359..bf8f7ad 100644
--- a/src/main/java/com/android/tools/r8/ir/code/IRCode.java
+++ b/src/main/java/com/android/tools/r8/ir/code/IRCode.java
@@ -16,7 +16,7 @@
 import com.android.tools.r8.graph.classmerging.MergedClassesCollection;
 import com.android.tools.r8.ir.analysis.TypeChecker;
 import com.android.tools.r8.ir.analysis.VerifyTypesHelper;
-import com.android.tools.r8.ir.analysis.framework.intraprocedural.ControlFlowGraph;
+import com.android.tools.r8.ir.analysis.framework.intraprocedural.IRControlFlowGraph;
 import com.android.tools.r8.ir.analysis.type.ClassTypeElement;
 import com.android.tools.r8.ir.analysis.type.Nullability;
 import com.android.tools.r8.ir.analysis.type.TypeElement;
@@ -64,7 +64,7 @@
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
-public class IRCode implements ControlFlowGraph<BasicBlock, Instruction>, ValueFactory {
+public class IRCode implements IRControlFlowGraph, ValueFactory {
 
   private static final int MAX_MARKING_COLOR = 0x40000000;
 
@@ -1352,17 +1352,47 @@
     return blocks;
   }
 
-  @Override
   public Collection<BasicBlock> getPredecessors(BasicBlock block) {
     return block.getPredecessors();
   }
 
-  @Override
   public Collection<BasicBlock> getSuccessors(BasicBlock block) {
     return block.getSuccessors();
   }
 
   @Override
+  public <BT, CT> TraversalContinuation<BT, CT> traverseNormalPredecessors(
+      BasicBlock block,
+      BiFunction<? super BasicBlock, ? super CT, TraversalContinuation<BT, CT>> fn,
+      CT initialValue) {
+    return block.traverseNormalPredecessors(fn, initialValue);
+  }
+
+  @Override
+  public <BT, CT> TraversalContinuation<BT, CT> traverseNormalSuccessors(
+      BasicBlock block,
+      BiFunction<? super BasicBlock, ? super CT, TraversalContinuation<BT, CT>> fn,
+      CT initialValue) {
+    return block.traverseNormalSuccessors(fn, initialValue);
+  }
+
+  @Override
+  public <BT, CT> TraversalContinuation<BT, CT> traverseExceptionalPredecessors(
+      BasicBlock block,
+      BiFunction<? super BasicBlock, ? super CT, TraversalContinuation<BT, CT>> fn,
+      CT initialValue) {
+    return block.traverseExceptionalPredecessors(fn, initialValue);
+  }
+
+  @Override
+  public <BT, CT> TraversalContinuation<BT, CT> traverseExceptionalSuccessors(
+      BasicBlock block,
+      BiFunction<? super BasicBlock, ? super CT, TraversalContinuation<BT, CT>> fn,
+      CT initialValue) {
+    return block.traverseExceptionalSuccessors(fn, initialValue);
+  }
+
+  @Override
   public <BT, CT> TraversalContinuation<BT, CT> traverseInstructions(
       BasicBlock block,
       BiFunction<Instruction, CT, TraversalContinuation<BT, CT>> fn,
diff --git a/src/main/java/com/android/tools/r8/utils/IntBox.java b/src/main/java/com/android/tools/r8/utils/IntBox.java
index 9a31ebd..79684cd 100644
--- a/src/main/java/com/android/tools/r8/utils/IntBox.java
+++ b/src/main/java/com/android/tools/r8/utils/IntBox.java
@@ -53,6 +53,15 @@
     value += i;
   }
 
+  public int incrementAndGet() {
+    return incrementAndGet(1);
+  }
+
+  public int incrementAndGet(int i) {
+    increment(i);
+    return get();
+  }
+
   public void set(int value) {
     this.value = value;
   }
diff --git a/src/main/java/com/android/tools/r8/utils/TraversalContinuation.java b/src/main/java/com/android/tools/r8/utils/TraversalContinuation.java
index 19f89d8..f89727b 100644
--- a/src/main/java/com/android/tools/r8/utils/TraversalContinuation.java
+++ b/src/main/java/com/android/tools/r8/utils/TraversalContinuation.java
@@ -4,6 +4,7 @@
 package com.android.tools.r8.utils;
 
 import com.android.tools.r8.errors.Unreachable;
+import java.util.function.Function;
 
 /** Two value continuation value to indicate the continuation of a loop/traversal. */
 /* This class is used for building up api class member traversals. */
@@ -26,13 +27,18 @@
   }
 
   public static class Continue<TB, TC> extends TraversalContinuation<TB, TC> {
-    private static final TraversalContinuation<?, ?> CONTINUE_NO_VALUE =
+    private static final TraversalContinuation.Continue<?, ?> CONTINUE_NO_VALUE =
         new Continue<Object, Object>(null) {
           @Override
           public Object getValue() {
             return new Unreachable(
                 "Invalid attempt at getting a value from a no-value continue state.");
           }
+
+          @Override
+          public Object getValueOrDefault(Object defaultValue) {
+            return defaultValue;
+          }
         };
 
     private final TC value;
@@ -45,6 +51,10 @@
       return value;
     }
 
+    public TC getValueOrDefault(TC defaultValue) {
+      return value;
+    }
+
     @Override
     public boolean isContinue() {
       return true;
@@ -57,13 +67,18 @@
   }
 
   public static class Break<TB, TC> extends TraversalContinuation<TB, TC> {
-    private static final TraversalContinuation<?, ?> BREAK_NO_VALUE =
+    private static final TraversalContinuation.Break<?, ?> BREAK_NO_VALUE =
         new Break<Object, Object>(null) {
           @Override
           public Object getValue() {
             return new Unreachable(
                 "Invalid attempt at getting a value from a no-value break state.");
           }
+
+          @Override
+          public Object getValueOrDefault(Object defaultValue) {
+            return defaultValue;
+          }
         };
 
     private final TB value;
@@ -76,6 +91,10 @@
       return value;
     }
 
+    public TB getValueOrDefault(TB defaultValue) {
+      return value;
+    }
+
     @Override
     public boolean isBreak() {
       return true;
@@ -87,29 +106,34 @@
     }
   }
 
-  public static TraversalContinuation<?, ?> breakIf(boolean condition) {
+  public static <TB, TC> TraversalContinuation<TB, TC> breakIf(boolean condition) {
     return continueIf(!condition);
   }
 
-  public static TraversalContinuation<?, ?> continueIf(boolean condition) {
+  public static <TB, TC> TraversalContinuation<TB, TC> continueIf(boolean condition) {
     return condition ? doContinue() : doBreak();
   }
 
-  @SuppressWarnings("unchecked")
-  public static <TB, TC> TraversalContinuation<TB, TC> doContinue() {
-    return (TraversalContinuation<TB, TC>) Continue.CONTINUE_NO_VALUE;
+  public TraversalContinuation<TB, TC> ifContinueThen(
+      Function<TraversalContinuation.Continue<TB, TC>, TraversalContinuation<TB, TC>> fn) {
+    return isContinue() ? fn.apply(asContinue()) : this;
   }
 
-  public static <TB, TC> TraversalContinuation<TB, TC> doContinue(TC value) {
+  @SuppressWarnings("unchecked")
+  public static <TB, TC> TraversalContinuation.Continue<TB, TC> doContinue() {
+    return (TraversalContinuation.Continue<TB, TC>) Continue.CONTINUE_NO_VALUE;
+  }
+
+  public static <TB, TC> TraversalContinuation.Continue<TB, TC> doContinue(TC value) {
     return new Continue<>(value);
   }
 
   @SuppressWarnings("unchecked")
-  public static <TB, TC> TraversalContinuation<TB, TC> doBreak() {
-    return (TraversalContinuation<TB, TC>) Break.BREAK_NO_VALUE;
+  public static <TB, TC> TraversalContinuation.Break<TB, TC> doBreak() {
+    return (TraversalContinuation.Break<TB, TC>) Break.BREAK_NO_VALUE;
   }
 
-  public static <TB, TC> TraversalContinuation<TB, TC> doBreak(TB value) {
+  public static <TB, TC> TraversalContinuation.Break<TB, TC> doBreak(TB value) {
     return new Break<>(value);
   }
 
diff --git a/src/main/java/com/android/tools/r8/utils/TraversalUtils.java b/src/main/java/com/android/tools/r8/utils/TraversalUtils.java
new file mode 100644
index 0000000..c24769a
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/utils/TraversalUtils.java
@@ -0,0 +1,40 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.utils;
+
+import static com.android.tools.r8.utils.FunctionUtils.ignoreArgument;
+
+import java.util.function.Consumer;
+import java.util.function.Function;
+
+public class TraversalUtils {
+
+  public static <BT, CT> BT getFirst(
+      Function<Function<BT, TraversalContinuation<BT, CT>>, TraversalContinuation<BT, CT>>
+          traversal) {
+    return traversal.apply(TraversalContinuation::doBreak).asBreak().getValue();
+  }
+
+  public static <BT, CT> boolean isSingleton(
+      Consumer<Function<CT, TraversalContinuation<BT, CT>>> traversal) {
+    return isSizeExactly(traversal, 1);
+  }
+
+  public static <BT, CT> boolean isSizeExactly(
+      Consumer<Function<CT, TraversalContinuation<BT, CT>>> traversal, int value) {
+    IntBox counter = new IntBox();
+    traversal.accept(
+        ignoreArgument(() -> TraversalContinuation.breakIf(counter.incrementAndGet() > value)));
+    return counter.get() == value;
+  }
+
+  public static <BT, CT> boolean isSizeGreaterThan(
+      Consumer<Function<CT, TraversalContinuation<BT, CT>>> traversal, int value) {
+    IntBox counter = new IntBox();
+    traversal.accept(
+        ignoreArgument(() -> TraversalContinuation.breakIf(counter.incrementAndGet() > value)));
+    return counter.get() > value;
+  }
+}