Support for creating ProgramMethodSet with initial capacity

Change-Id: I262414fbe8cbbbdf01e086223031e04e37610174
diff --git a/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMethodSet.java b/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMethodSet.java
index ad7c3fe..3832cce 100644
--- a/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMethodSet.java
+++ b/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMethodSet.java
@@ -11,49 +11,25 @@
 import java.util.LinkedHashMap;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.function.Supplier;
 
-public class DexClassAndMethodSet extends DexClassAndMethodSetBase<DexClassAndMethod> {
+public abstract class DexClassAndMethodSet extends DexClassAndMethodSetBase<DexClassAndMethod> {
 
-  private static final DexClassAndMethodSet EMPTY = new DexClassAndMethodSet(ImmutableMap::of);
+  private static final DexClassAndMethodSet EMPTY = new EmptyDexClassAndMethodSet();
 
-  protected DexClassAndMethodSet(
-      Supplier<? extends Map<DexMethod, DexClassAndMethod>> backingFactory) {
-    super(backingFactory);
-  }
-
-  protected DexClassAndMethodSet(
-      Supplier<? extends Map<DexMethod, DexClassAndMethod>> backingFactory,
-      Map<DexMethod, DexClassAndMethod> backing) {
-    super(backingFactory, backing);
+  DexClassAndMethodSet() {
+    super();
   }
 
   public static DexClassAndMethodSet create() {
-    return new DexClassAndMethodSet(IdentityHashMap::new);
-  }
-
-  public static DexClassAndMethodSet create(int capacity) {
-    return new DexClassAndMethodSet(IdentityHashMap::new, new IdentityHashMap<>(capacity));
-  }
-
-  public static DexClassAndMethodSet create(DexClassAndMethod element) {
-    DexClassAndMethodSet result = create();
-    result.add(element);
-    return result;
-  }
-
-  public static DexClassAndMethodSet create(DexClassAndMethodSet methodSet) {
-    DexClassAndMethodSet newMethodSet = create();
-    newMethodSet.addAll(methodSet);
-    return newMethodSet;
+    return new IdentityDexClassAndMethodSet();
   }
 
   public static DexClassAndMethodSet createConcurrent() {
-    return new DexClassAndMethodSet(ConcurrentHashMap::new);
+    return new ConcurrentDexClassAndMethodSet();
   }
 
   public static DexClassAndMethodSet createLinked() {
-    return new DexClassAndMethodSet(LinkedHashMap::new);
+    return new LinkedDexClassAndMethodSet();
   }
 
   public static DexClassAndMethodSet empty() {
@@ -63,4 +39,56 @@
   public void addAll(DexClassAndMethodSet methods) {
     backing.putAll(methods.backing);
   }
+
+  private static class ConcurrentDexClassAndMethodSet extends DexClassAndMethodSet {
+
+    @Override
+    Map<DexMethod, DexClassAndMethod> createBacking() {
+      return new ConcurrentHashMap<>();
+    }
+
+    @Override
+    Map<DexMethod, DexClassAndMethod> createBacking(int capacity) {
+      return new ConcurrentHashMap<>(capacity);
+    }
+  }
+
+  private static class EmptyDexClassAndMethodSet extends DexClassAndMethodSet {
+
+    @Override
+    Map<DexMethod, DexClassAndMethod> createBacking() {
+      return ImmutableMap.of();
+    }
+
+    @Override
+    Map<DexMethod, DexClassAndMethod> createBacking(int capacity) {
+      return ImmutableMap.of();
+    }
+  }
+
+  private static class IdentityDexClassAndMethodSet extends DexClassAndMethodSet {
+
+    @Override
+    Map<DexMethod, DexClassAndMethod> createBacking() {
+      return new IdentityHashMap<>();
+    }
+
+    @Override
+    Map<DexMethod, DexClassAndMethod> createBacking(int capacity) {
+      return new IdentityHashMap<>(capacity);
+    }
+  }
+
+  private static class LinkedDexClassAndMethodSet extends DexClassAndMethodSet {
+
+    @Override
+    Map<DexMethod, DexClassAndMethod> createBacking() {
+      return new LinkedHashMap<>();
+    }
+
+    @Override
+    Map<DexMethod, DexClassAndMethod> createBacking(int capacity) {
+      return new LinkedHashMap<>(capacity);
+    }
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMethodSetBase.java b/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMethodSetBase.java
index 0787662..007aa51 100644
--- a/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMethodSetBase.java
+++ b/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMethodSetBase.java
@@ -16,25 +16,29 @@
 import java.util.Set;
 import java.util.function.IntFunction;
 import java.util.function.Predicate;
-import java.util.function.Supplier;
 import java.util.stream.Stream;
 
 public abstract class DexClassAndMethodSetBase<T extends DexClassAndMethod>
     implements Collection<T> {
 
-  protected final Map<DexMethod, T> backing;
-  protected final Supplier<? extends Map<DexMethod, T>> backingFactory;
+  Map<DexMethod, T> backing;
 
-  protected DexClassAndMethodSetBase(Supplier<? extends Map<DexMethod, T>> backingFactory) {
-    this(backingFactory, backingFactory.get());
+  DexClassAndMethodSetBase() {
+    this.backing = createBacking();
   }
 
-  protected DexClassAndMethodSetBase(
-      Supplier<? extends Map<DexMethod, T>> backingFactory, Map<DexMethod, T> backing) {
+  DexClassAndMethodSetBase(Map<DexMethod, T> backing) {
     this.backing = backing;
-    this.backingFactory = backingFactory;
   }
 
+  DexClassAndMethodSetBase(int capacity) {
+    this.backing = createBacking(capacity);
+  }
+
+  abstract Map<DexMethod, T> createBacking();
+
+  abstract Map<DexMethod, T> createBacking(int capacity);
+
   @Override
   public boolean add(T method) {
     T existing = backing.put(method.getReference(), method);
@@ -171,4 +175,12 @@
     forEach(method -> definitions.add(method.getDefinition()));
     return definitions;
   }
+
+  public void trimCapacityIfSizeLessThan(int expectedSize) {
+    if (size() < expectedSize) {
+      Map<DexMethod, T> newBacking = createBacking(size());
+      newBacking.putAll(backing);
+      backing = newBacking;
+    }
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/utils/collections/LinkedProgramMethodSet.java b/src/main/java/com/android/tools/r8/utils/collections/LinkedProgramMethodSet.java
index b323e7e..258ee4f 100644
--- a/src/main/java/com/android/tools/r8/utils/collections/LinkedProgramMethodSet.java
+++ b/src/main/java/com/android/tools/r8/utils/collections/LinkedProgramMethodSet.java
@@ -4,23 +4,28 @@
 
 package com.android.tools.r8.utils.collections;
 
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.ProgramMethod;
 import java.util.LinkedHashMap;
+import java.util.Map;
 
 public class LinkedProgramMethodSet extends ProgramMethodSet {
 
   LinkedProgramMethodSet() {
-    super(LinkedProgramMethodSet::createBacking, createBacking());
+    super();
   }
 
   LinkedProgramMethodSet(int capacity) {
-    super(LinkedProgramMethodSet::createBacking, createBacking(capacity));
+    super(capacity);
   }
 
-  private static <K, V> LinkedHashMap<K, V> createBacking() {
+  @Override
+  Map<DexMethod, ProgramMethod> createBacking() {
     return new LinkedHashMap<>();
   }
 
-  private static <K, V> LinkedHashMap<K, V> createBacking(int capacity) {
+  @Override
+  Map<DexMethod, ProgramMethod> createBacking(int capacity) {
     return new LinkedHashMap<>(capacity);
   }
 }
diff --git a/src/main/java/com/android/tools/r8/utils/collections/ProgramMethodSet.java b/src/main/java/com/android/tools/r8/utils/collections/ProgramMethodSet.java
index 57ad688..b2ece0b 100644
--- a/src/main/java/com/android/tools/r8/utils/collections/ProgramMethodSet.java
+++ b/src/main/java/com/android/tools/r8/utils/collections/ProgramMethodSet.java
@@ -15,32 +15,33 @@
 import java.util.IdentityHashMap;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.function.Supplier;
 
-public class ProgramMethodSet extends DexClassAndMethodSetBase<ProgramMethod> {
+public abstract class ProgramMethodSet extends DexClassAndMethodSetBase<ProgramMethod> {
 
-  private static final ProgramMethodSet EMPTY = new ProgramMethodSet(ImmutableMap::of);
+  private static final ProgramMethodSet EMPTY = new EmptyProgramMethodSet();
 
-  protected ProgramMethodSet(Supplier<? extends Map<DexMethod, ProgramMethod>> backingFactory) {
-    super(backingFactory);
+  ProgramMethodSet() {
+    super();
   }
 
-  protected ProgramMethodSet(
-      Supplier<? extends Map<DexMethod, ProgramMethod>> backingFactory,
-      Map<DexMethod, ProgramMethod> backing) {
-    super(backingFactory, backing);
+  ProgramMethodSet(Map<DexMethod, ProgramMethod> backing) {
+    super(backing);
+  }
+
+  ProgramMethodSet(int capacity) {
+    super(capacity);
   }
 
   public static ProgramMethodSet create() {
-    return new ProgramMethodSet(IdentityHashMap::new);
+    return new IdentityProgramMethodSet();
   }
 
   public static ProgramMethodSet create(int capacity) {
-    return new ProgramMethodSet(IdentityHashMap::new, new IdentityHashMap<>(capacity));
+    return new IdentityProgramMethodSet(capacity);
   }
 
   public static ProgramMethodSet create(ProgramMethod element) {
-    ProgramMethodSet result = create();
+    ProgramMethodSet result = create(1);
     result.add(element);
     return result;
   }
@@ -52,13 +53,13 @@
   }
 
   public static ProgramMethodSet create(ProgramMethodSet methodSet) {
-    ProgramMethodSet newMethodSet = create();
+    ProgramMethodSet newMethodSet = create(methodSet.size());
     newMethodSet.addAll(methodSet);
     return newMethodSet;
   }
 
   public static ProgramMethodSet createConcurrent() {
-    return new ProgramMethodSet(ConcurrentHashMap::new);
+    return new ConcurrentProgramMethodSet();
   }
 
   public static LinkedProgramMethodSet createLinked() {
@@ -82,14 +83,63 @@
   }
 
   public ProgramMethodSet rewrittenWithLens(DexDefinitionSupplier definitions, GraphLens lens) {
-    ProgramMethodSet rewritten = new ProgramMethodSet(backingFactory);
+    ProgramMethodSet rewritten = ProgramMethodSet.create(size());
     forEach(
         method -> {
           ProgramMethod newMethod = lens.mapProgramMethod(method, definitions);
           if (newMethod != null) {
+            assert !newMethod.getDefinition().isObsolete();
             rewritten.add(newMethod);
           }
         });
+    rewritten.trimCapacityIfSizeLessThan(size());
     return rewritten;
   }
+
+  private static class ConcurrentProgramMethodSet extends ProgramMethodSet {
+
+    @Override
+    Map<DexMethod, ProgramMethod> createBacking() {
+      return new ConcurrentHashMap<>();
+    }
+
+    @Override
+    Map<DexMethod, ProgramMethod> createBacking(int capacity) {
+      return new ConcurrentHashMap<>(capacity);
+    }
+  }
+
+  private static class EmptyProgramMethodSet extends ProgramMethodSet {
+
+    @Override
+    Map<DexMethod, ProgramMethod> createBacking() {
+      return ImmutableMap.of();
+    }
+
+    @Override
+    Map<DexMethod, ProgramMethod> createBacking(int capacity) {
+      return ImmutableMap.of();
+    }
+  }
+
+  private static class IdentityProgramMethodSet extends ProgramMethodSet {
+
+    IdentityProgramMethodSet() {
+      super();
+    }
+
+    IdentityProgramMethodSet(int capacity) {
+      super(capacity);
+    }
+
+    @Override
+    Map<DexMethod, ProgramMethod> createBacking() {
+      return new IdentityHashMap<>();
+    }
+
+    @Override
+    Map<DexMethod, ProgramMethod> createBacking(int capacity) {
+      return new IdentityHashMap<>(capacity);
+    }
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/utils/collections/SortedProgramMethodSet.java b/src/main/java/com/android/tools/r8/utils/collections/SortedProgramMethodSet.java
index 3a17dc5..aa24dde 100644
--- a/src/main/java/com/android/tools/r8/utils/collections/SortedProgramMethodSet.java
+++ b/src/main/java/com/android/tools/r8/utils/collections/SortedProgramMethodSet.java
@@ -9,28 +9,25 @@
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.graph.lens.GraphLens;
-import com.android.tools.r8.utils.ComparatorUtils;
 import com.android.tools.r8.utils.ForEachable;
-import com.android.tools.r8.utils.ForEachableUtils;
+import java.util.Collections;
 import java.util.Comparator;
+import java.util.Map;
 import java.util.Set;
-import java.util.SortedMap;
 import java.util.TreeMap;
 import java.util.TreeSet;
 import java.util.concurrent.ConcurrentSkipListMap;
-import java.util.function.Supplier;
 
-public class SortedProgramMethodSet extends ProgramMethodSet {
+public abstract class SortedProgramMethodSet extends ProgramMethodSet {
 
-  private static final SortedProgramMethodSet EMPTY =
-      new SortedProgramMethodSet(() -> new TreeMap<>(ComparatorUtils.unreachableComparator()));
+  private static final SortedProgramMethodSet EMPTY = new EmptySortedProgramMethodSet();
 
-  private SortedProgramMethodSet(Supplier<SortedMap<DexMethod, ProgramMethod>> backingFactory) {
-    super(backingFactory);
+  private SortedProgramMethodSet() {
+    super();
   }
 
   public static SortedProgramMethodSet create() {
-    return create(ForEachableUtils.empty());
+    return new TreeSortedProgramMethodSet();
   }
 
   public static SortedProgramMethodSet create(ProgramMethod method) {
@@ -40,14 +37,13 @@
   }
 
   public static SortedProgramMethodSet create(ForEachable<ProgramMethod> methods) {
-    SortedProgramMethodSet result =
-        new SortedProgramMethodSet(() -> new TreeMap<>(DexMethod::compareTo));
+    SortedProgramMethodSet result = create();
     methods.forEach(result::add);
     return result;
   }
 
   public static SortedProgramMethodSet createConcurrent() {
-    return new SortedProgramMethodSet(() -> new ConcurrentSkipListMap<>(DexMethod::compareTo));
+    return new ConcurrentSortedProgramMethodSet();
   }
 
   public static SortedProgramMethodSet empty() {
@@ -55,6 +51,11 @@
   }
 
   @Override
+  Map<DexMethod, ProgramMethod> createBacking(int capacity) {
+    return createBacking();
+  }
+
+  @Override
   public SortedProgramMethodSet rewrittenWithLens(
       DexDefinitionSupplier definitions, GraphLens lens) {
     return create(
@@ -63,10 +64,33 @@
 
   @Override
   public Set<DexEncodedMethod> toDefinitionSet() {
-    Comparator<DexEncodedMethod> comparator =
-        (x, y) -> x.getReference().compareTo(y.getReference());
+    Comparator<DexEncodedMethod> comparator = Comparator.comparing(DexEncodedMethod::getReference);
     Set<DexEncodedMethod> definitions = new TreeSet<>(comparator);
     forEach(method -> definitions.add(method.getDefinition()));
     return definitions;
   }
+
+  private static class ConcurrentSortedProgramMethodSet extends SortedProgramMethodSet {
+
+    @Override
+    Map<DexMethod, ProgramMethod> createBacking() {
+      return new ConcurrentSkipListMap<>(DexMethod::compareTo);
+    }
+  }
+
+  private static class EmptySortedProgramMethodSet extends SortedProgramMethodSet {
+
+    @Override
+    Map<DexMethod, ProgramMethod> createBacking() {
+      return Collections.emptyMap();
+    }
+  }
+
+  private static class TreeSortedProgramMethodSet extends SortedProgramMethodSet {
+
+    @Override
+    Map<DexMethod, ProgramMethod> createBacking() {
+      return new TreeMap<>(DexMethod::compareTo);
+    }
+  }
 }