Add disjoint sets data structure

Change-Id: Ibb13fe7a7f54a6e7d521cb09ee0f5c39e84821b6
diff --git a/src/main/java/com/android/tools/r8/utils/DisjointSets.java b/src/main/java/com/android/tools/r8/utils/DisjointSets.java
new file mode 100644
index 0000000..695bf62
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/utils/DisjointSets.java
@@ -0,0 +1,163 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.utils;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Disjoint sets of instances of type T. Each of the sets will be represented by one of the
+ * instances in the set.
+ */
+public class DisjointSets<T> {
+  // Forrest represented in a set. Each element maps to its parent. A root of a tree in
+  // the forest maps to itself. Each tree in the forrest represent a set in the disjoint sets.
+  private final Map<T, T> parent = new HashMap<>();
+
+  /**
+   * Create a new set containing only <code>element</code>.
+   *
+   * <p>The <code>element</code> must not be present in any existing sets.
+   */
+  public T makeSet(T element) {
+    assert !parent.containsKey(element);
+    parent.put(element, element);
+    assert findSet(element) == element;
+    return element;
+  }
+
+  /**
+   * Returns the representative for the set containing <code>element</code>.
+   *
+   * <p>Returns null if <code>element</code> is not in any sets.
+   */
+  public T findSet(T element) {
+    T candidate = parent.get(element);
+    if (candidate == null) {
+      return null;
+    }
+    T candidateParent = parent.get(candidate);
+    if (candidate == candidateParent) {
+      return candidate;
+    }
+    // If not at the representative recurse and compress path.
+    T representative = findSet(candidateParent);
+    parent.put(element, representative);
+    return representative;
+  }
+
+  /**
+   * Check if <code>element</code> is the representative for a set or not present at all.
+   *
+   * <p>Returns <code>true</code> if that is the case.
+   */
+  public boolean isRepresentativeOrNotPresent(T element) {
+    T representative = findSet(element);
+    return representative == null || representative.equals(element);
+  }
+
+  /**
+   * Returns the set containing <code>element</code>.
+   *
+   * <p>Returns null if <code>element</code> is not in any sets.
+   */
+  public Set<T> collectSet(T element) {
+    T representative = findSet(element);
+    if (representative == null) {
+      return null;
+    }
+    HashSet<T> result = new HashSet<>();
+    for (T t : parent.keySet()) {
+      // Find root with path-compression.
+      if (findSet(t).equals(representative)) {
+        result.add(t);
+      }
+    }
+    return result;
+  }
+
+  /**
+   * Returns the representative for the set containing <code>element</code> while creating a set
+   * containing <code>element</code> if <code>element</code> is not present in any of the existing
+   * sets.
+   *
+   * <p>Returns the representative for the set containing <code>element</code>.
+   */
+  public T findOrMakeSet(T element) {
+    T representative = findSet(element);
+    return representative != null ? representative : makeSet(element);
+  }
+
+  /**
+   * Union the two sets represented by <code>representative1</code> and <code>representative2</code>
+   * .
+   *
+   * <p>Both <code>representative1</code> and <code>representative2</code> must be actual
+   * representatives of a set, and not just any member of a set.
+   *
+   * <p>Returns the representative for the union set.
+   */
+  public T union(T representative1, T representative2) {
+    // The two representatives must be roots in different trees.
+    assert representative1 != null;
+    assert representative2 != null;
+    if (representative1 == representative2) {
+      return representative1;
+    }
+    assert parent.get(representative1) == representative1;
+    assert parent.get(representative2) == representative2;
+    // Join the trees.
+    parent.put(representative2, representative1);
+    assert findSet(representative1) == representative1;
+    assert findSet(representative2) == representative1;
+    return representative1;
+  }
+
+  /**
+   * Union the two sets containing by <code>element1</code> and <code>element2</code> while creating
+   * a set containing <code>element1</code> and <code>element2</code> if one of them is not present
+   * in any of the existing sets.
+   *
+   * <p>Returns the representative for the union set.
+   */
+  public T unionWithMakeSet(T element1, T element2) {
+    if (element1.toString().contains("Enum") || element2.toString().contains("Enum")) {
+      System.out.println();
+    }
+    if (element1 == element2) {
+      return findOrMakeSet(element1);
+    }
+    return union(findOrMakeSet(element1), findOrMakeSet(element2));
+  }
+
+  /** Returns the sets currently represented. */
+  public Map<T, Set<T>> collectSets() {
+    Map<T, Set<T>> unification = new HashMap<>();
+    for (T element : parent.keySet()) {
+      // Find root with path-compression.
+      T representative = findSet(element);
+      unification.computeIfAbsent(representative, k -> new HashSet<>()).add(element);
+    }
+    return unification;
+  }
+
+  @Override
+  public String toString() {
+    Map<T, Set<T>> sets = collectSets();
+    StringBuilder sb =
+        new StringBuilder()
+            .append("Number of sets: ")
+            .append(sets.keySet().size())
+            .append(System.lineSeparator());
+    sets.forEach(
+        (representative, set) -> {
+          sb.append("Representative: ").append(representative).append(System.lineSeparator());
+          set.forEach(v -> sb.append("    ").append(v).append(System.lineSeparator()));
+        });
+    return sb.toString();
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/utils/DisjointSetsTest.java b/src/test/java/com/android/tools/r8/utils/DisjointSetsTest.java
new file mode 100644
index 0000000..5afab7e
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/utils/DisjointSetsTest.java
@@ -0,0 +1,90 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.utils;
+
+import static org.junit.Assert.assertEquals;
+
+import java.util.Map;
+import java.util.Set;
+import org.junit.Test;
+
+public class DisjointSetsTest {
+
+  private DisjointSets<Integer> initTestSet(int size) {
+    DisjointSets<Integer> ds = new DisjointSets<>();
+    for (int i = 0; i < size; i++) {
+      ds.makeSet(i);
+    }
+    return ds;
+  }
+
+  public void runOddEvenTest(int size) {
+    DisjointSets<Integer> ds = initTestSet(size);
+
+    assertEquals(size, ds.collectSets().size());
+
+    for (int i = 2; i < size; i++) {
+      if (i % 2 == 0) {
+        ds.union(ds.findSet(0), ds.findSet(i));
+      } else {
+        ds.union(ds.findSet(1), ds.findSet(i));
+      }
+    }
+
+    Map<Integer, Set<Integer>> sets = ds.collectSets();
+    assertEquals(2, sets.size());
+    int elements = 0;
+    for (Integer representative : sets.keySet()) {
+      int oddOrEven = representative % 2;
+      Set<Integer> set = sets.get(representative);
+      set.forEach(s -> assertEquals(oddOrEven, s % 2));
+      elements += set.size();
+    }
+    assertEquals(size, elements);
+
+    for (int i = 2; i < size; i++) {
+      Set<Integer> set = ds.collectSet(i);
+      if (i % 2 == 0) {
+        set.forEach(s -> assertEquals(0, s % 2));
+        assertEquals(size / 2 + size % 2, set.size());
+      } else {
+        set.forEach(s -> assertEquals(1, s % 2));
+        assertEquals(size / 2, set.size());
+      }
+    }
+    assertEquals(size, ds.collectSet(0).size() + ds.collectSet(1).size());
+
+    ds.union(ds.findSet(size - 2), ds.findSet(size - 1));
+    assertEquals(1, ds.collectSets().size());
+  }
+
+  @Test
+  public void testOddEven() {
+    runOddEvenTest(2);
+    runOddEvenTest(3);
+    runOddEvenTest(4);
+    runOddEvenTest(10);
+    runOddEvenTest(100);
+    runOddEvenTest(1000);
+  }
+
+  public void runUnionAllTest(int size) {
+    DisjointSets<Integer> ds = initTestSet(size);
+    for (int i = 1; i < size; i++) {
+      ds.union(ds.findSet(i - 1), ds.findSet(i));
+      assertEquals(size - i, ds.collectSets().size());
+    }
+  }
+
+  @Test
+  public void unionAllTest() {
+    runUnionAllTest(2);
+    runUnionAllTest(3);
+    runUnionAllTest(4);
+    runUnionAllTest(10);
+    runUnionAllTest(100);
+    runUnionAllTest(1000);
+  }
+}