[KeepAnno] Simplify rule extractor a bit.

Change-Id: I5d45882454a6ac23bbb78d1b062d08bc654ace99
diff --git a/src/keepanno/java/com/android/tools/r8/keepanno/ast/KeepBindingReference.java b/src/keepanno/java/com/android/tools/r8/keepanno/ast/KeepBindingReference.java
index a5e9eda..481a6f2 100644
--- a/src/keepanno/java/com/android/tools/r8/keepanno/ast/KeepBindingReference.java
+++ b/src/keepanno/java/com/android/tools/r8/keepanno/ast/KeepBindingReference.java
@@ -6,6 +6,8 @@
 
 import com.android.tools.r8.keepanno.ast.KeepBindings.KeepBindingSymbol;
 import java.util.Objects;
+import java.util.function.Consumer;
+import java.util.function.Function;
 
 public abstract class KeepBindingReference {
 
@@ -47,6 +49,21 @@
     return null;
   }
 
+  public final <T> T apply(
+      Function<KeepClassBindingReference, T> onClass,
+      Function<KeepMemberBindingReference, T> onMember) {
+    if (isClassType()) {
+      return onClass.apply(asClassBindingReference());
+    }
+    assert isMemberType();
+    return onMember.apply(asMemberBindingReference());
+  }
+
+  public final void match(
+      Consumer<KeepClassBindingReference> onClass, Consumer<KeepMemberBindingReference> onMember) {
+    apply(AstUtils.toVoidFunction(onClass), AstUtils.toVoidFunction(onMember));
+  }
+
   @Override
   public String toString() {
     return name.toString();
diff --git a/src/keepanno/java/com/android/tools/r8/keepanno/ast/KeepBindings.java b/src/keepanno/java/com/android/tools/r8/keepanno/ast/KeepBindings.java
index 19e4d5f..4df258d 100644
--- a/src/keepanno/java/com/android/tools/r8/keepanno/ast/KeepBindings.java
+++ b/src/keepanno/java/com/android/tools/r8/keepanno/ast/KeepBindings.java
@@ -3,8 +3,6 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.keepanno.ast;
 
-import java.util.Arrays;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.IdentityHashMap;
@@ -39,6 +37,34 @@
     return bindings.get(bindingReference);
   }
 
+  public KeepClassItemPattern getClassItem(KeepClassBindingReference reference) {
+    KeepBindingSymbol symbol = reference.getName();
+    Binding binding = get(symbol);
+    if (binding == null) {
+      throw new KeepEdgeException("Unbound binding for reference '" + symbol + "'");
+    }
+    KeepItemPattern item = binding.getItem();
+    if (!item.isClassItemPattern()) {
+      throw new KeepEdgeException(
+          "Attempt to get class item from non-class binding '" + symbol + "'");
+    }
+    return item.asClassItemPattern();
+  }
+
+  public KeepMemberItemPattern getMemberItem(KeepMemberBindingReference reference) {
+    KeepBindingSymbol symbol = reference.getName();
+    Binding binding = get(symbol);
+    if (binding == null) {
+      throw new KeepEdgeException("Unbound binding for reference '" + symbol + "'");
+    }
+    KeepItemPattern item = binding.getItem();
+    if (!item.isMemberItemPattern()) {
+      throw new KeepEdgeException(
+          "Attempt to get member item from non-member binding '" + symbol + "'");
+    }
+    return item.asMemberItemPattern();
+  }
+
   public int size() {
     return bindings.size();
   }
diff --git a/src/keepanno/java/com/android/tools/r8/keepanno/keeprules/KeepRuleExtractor.java b/src/keepanno/java/com/android/tools/r8/keepanno/keeprules/KeepRuleExtractor.java
index 8c12901..585eee6 100644
--- a/src/keepanno/java/com/android/tools/r8/keepanno/keeprules/KeepRuleExtractor.java
+++ b/src/keepanno/java/com/android/tools/r8/keepanno/keeprules/KeepRuleExtractor.java
@@ -16,7 +16,6 @@
 import com.android.tools.r8.keepanno.ast.KeepConstraints;
 import com.android.tools.r8.keepanno.ast.KeepDeclaration;
 import com.android.tools.r8.keepanno.ast.KeepEdge;
-import com.android.tools.r8.keepanno.ast.KeepEdgeException;
 import com.android.tools.r8.keepanno.ast.KeepEdgeMetaInfo;
 import com.android.tools.r8.keepanno.ast.KeepFieldAccessPattern;
 import com.android.tools.r8.keepanno.ast.KeepFieldPattern;
@@ -35,11 +34,8 @@
 import com.android.tools.r8.keepanno.keeprules.PgRule.PgKeepAttributeRule;
 import com.android.tools.r8.keepanno.keeprules.PgRule.PgUnconditionalRule;
 import com.android.tools.r8.keepanno.keeprules.PgRule.TargetKeepKind;
-import java.util.ArrayDeque;
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.Collections;
-import java.util.Deque;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -578,38 +574,8 @@
 
   private static KeepBindingSymbol getClassItemBindingReference(
       KeepBindingReference itemReference, KeepBindings bindings) {
-    KeepBindingSymbol classReference = null;
-    for (KeepBindingSymbol reference : getTransitiveBindingReferences(itemReference, bindings)) {
-      if (bindings.get(reference).getItem().isClassItemPattern()) {
-        if (classReference != null) {
-          throw new KeepEdgeException("Unexpected reference to multiple class bindings");
-        }
-        classReference = reference;
-      }
-    }
-    return classReference;
-  }
-
-  private static Set<KeepBindingSymbol> getTransitiveBindingReferences(
-      KeepBindingReference itemReference, KeepBindings bindings) {
-    Set<KeepBindingSymbol> symbols = new HashSet<>(2);
-    Deque<KeepBindingReference> worklist = new ArrayDeque<>();
-    worklist.addAll(getBindingReference(itemReference));
-    while (!worklist.isEmpty()) {
-      KeepBindingReference bindingReference = worklist.pop();
-      if (symbols.add(bindingReference.getName())) {
-        worklist.addAll(getBindingReference(bindings.get(bindingReference).getItem()));
-      }
-    }
-    return symbols;
-  }
-
-  private static Collection<KeepBindingReference> getBindingReference(
-      KeepBindingReference itemReference) {
-    return Collections.singletonList(itemReference);
-  }
-
-  private static Collection<KeepBindingReference> getBindingReference(KeepItemPattern itemPattern) {
-    return itemPattern.getBindingReferences();
+    return itemReference.apply(
+        KeepBindingReference::getName,
+        member -> bindings.getMemberItem(member).getClassReference().getName());
   }
 }