Use disjoint sets data structure in minifier

Change-Id: I8a65000cf2a6c7cc0266dee34acdd6f09379aa7d
diff --git a/src/main/java/com/android/tools/r8/naming/InterfaceMethodNameMinifier.java b/src/main/java/com/android/tools/r8/naming/InterfaceMethodNameMinifier.java
index f91dec1..6deecef 100644
--- a/src/main/java/com/android/tools/r8/naming/InterfaceMethodNameMinifier.java
+++ b/src/main/java/com/android/tools/r8/naming/InterfaceMethodNameMinifier.java
@@ -11,6 +11,7 @@
 import com.android.tools.r8.graph.DexString;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.DisjointSets;
 import com.android.tools.r8.utils.MethodJavaSignatureEquivalence;
 import com.android.tools.r8.utils.MethodSignatureEquivalence;
 import com.android.tools.r8.utils.Timing;
@@ -396,8 +397,9 @@
     // Union-find structure to keep track of methods that must be renamed together.
     // Note that if the input does not use multi-interface lambdas unificationParent will remain
     // empty.
-    Map<Wrapper<DexMethod>, Wrapper<DexMethod>> unificationParent = new HashMap<>();
     timing.begin("Union-find");
+    DisjointSets<Wrapper<DexMethod>> unification = new DisjointSets<>();
+
     liveCallSites.forEach(
         callSite -> {
           Set<Wrapper<DexMethod>> callSiteMethods = new HashSet<>();
@@ -418,32 +420,25 @@
           if (callSiteMethods.size() > 1) {
             // Implemented interfaces have different protos. Unify them.
             Wrapper<DexMethod> mainKey = callSiteMethods.iterator().next();
-            mainKey = unificationParent.getOrDefault(mainKey, mainKey);
+            Wrapper<DexMethod> representative = unification.findOrMakeSet(mainKey);
             for (Wrapper<DexMethod> key : callSiteMethods) {
-              unificationParent.put(key, mainKey);
+              unification.unionWithMakeSet(representative, key);
             }
           }
         });
-    Map<Wrapper<DexMethod>, Set<Wrapper<DexMethod>>> unification = new HashMap<>();
-    for (Wrapper<DexMethod> key : unificationParent.keySet()) {
-      // Find root with path-compression.
-      Wrapper<DexMethod> root = unificationParent.get(key);
-      while (unificationParent.get(root) != root) {
-        Wrapper<DexMethod> k = unificationParent.get(unificationParent.get(root));
-        unificationParent.put(root, k);
-        root = k;
-      }
-      unification.computeIfAbsent(root, k -> new HashSet<>()).add(key);
-    }
+
     timing.end();
 
     // We now have roots for all unions. Add all of the states for the groups to the method state
     // for the unions to allow consistent naming across different protos.
     timing.begin("States for union");
-    for (Wrapper<DexMethod> wrapped : unification.keySet()) {
+    Map<Wrapper<DexMethod>, Set<Wrapper<DexMethod>>> unions = unification.collectSets();
+
+    for (Wrapper<DexMethod> wrapped : unions.keySet()) {
       InterfaceMethodGroupState groupState = globalStateMap.get(wrapped);
       assert groupState != null;
-      for (Wrapper<DexMethod> groupedMethod : unification.get(wrapped)) {
+
+      for (Wrapper<DexMethod> groupedMethod : unions.get(wrapped)) {
         DexMethod method = groupedMethod.get();
         assert method != null;
         groupState.appendMethodGroupState(globalStateMap.get(groupedMethod));
@@ -457,7 +452,7 @@
     // referenced in many places.
     List<Wrapper<DexMethod>> interfaceMethodGroups =
         globalStateMap.keySet().stream()
-            .filter(wrapper -> unificationParent.getOrDefault(wrapper, wrapper).equals(wrapper))
+            .filter(unification::isRepresentativeOrNotPresent)
             .sorted(
                 appView
                     .options()