Add a hashmap method collection backing.

Change-Id: Id79823315f77b566f0a114a54a93162c314ee08c
diff --git a/src/main/java/com/android/tools/r8/GenerateLintFiles.java b/src/main/java/com/android/tools/r8/GenerateLintFiles.java
index 830f6b6..a6d704b 100644
--- a/src/main/java/com/android/tools/r8/GenerateLintFiles.java
+++ b/src/main/java/com/android/tools/r8/GenerateLintFiles.java
@@ -145,7 +145,7 @@
               ParameterAnnotationsList.empty(),
               code,
               50);
-      if (method.accessFlags.isStatic()) {
+      if (method.isStatic() || method.isDirectMethod()) {
         directMethods.add(throwingMethod);
       } else {
         virtualMethods.add(throwingMethod);
diff --git a/src/main/java/com/android/tools/r8/graph/DexClass.java b/src/main/java/com/android/tools/r8/graph/DexClass.java
index 6c74124..85810ed 100644
--- a/src/main/java/com/android/tools/r8/graph/DexClass.java
+++ b/src/main/java/com/android/tools/r8/graph/DexClass.java
@@ -85,9 +85,7 @@
     this.type = type;
     setStaticFields(staticFields);
     setInstanceFields(instanceFields);
-    this.methodCollection = new MethodCollection(this);
-    setDirectMethods(directMethods);
-    setVirtualMethods(virtualMethods);
+    this.methodCollection = new MethodCollection(this, directMethods, virtualMethods);
     this.nestHost = nestHost;
     this.nestMembers = nestMembers;
     assert nestMembers != null;
diff --git a/src/main/java/com/android/tools/r8/graph/MethodArrayBacking.java b/src/main/java/com/android/tools/r8/graph/MethodArrayBacking.java
index 7f87c94..4eda679 100644
--- a/src/main/java/com/android/tools/r8/graph/MethodArrayBacking.java
+++ b/src/main/java/com/android/tools/r8/graph/MethodArrayBacking.java
@@ -20,16 +20,6 @@
   private DexEncodedMethod[] directMethods = DexEncodedMethod.EMPTY_ARRAY;
   private DexEncodedMethod[] virtualMethods = DexEncodedMethod.EMPTY_ARRAY;
 
-  private boolean belongsToDirectPool(DexEncodedMethod method) {
-    return method.accessFlags.isStatic()
-        || method.accessFlags.isPrivate()
-        || method.accessFlags.isConstructor();
-  }
-
-  private boolean belongsToVirtualPool(DexEncodedMethod method) {
-    return !belongsToDirectPool(method);
-  }
-
   private boolean verifyNoDuplicateMethods() {
     Set<DexMethod> unique = Sets.newIdentityHashSet();
     forEachMethod(
diff --git a/src/main/java/com/android/tools/r8/graph/MethodCollection.java b/src/main/java/com/android/tools/r8/graph/MethodCollection.java
index 54b212f..269b36b 100644
--- a/src/main/java/com/android/tools/r8/graph/MethodCollection.java
+++ b/src/main/java/com/android/tools/r8/graph/MethodCollection.java
@@ -15,12 +15,29 @@
 
 public class MethodCollection {
 
+  // Threshold between using an array and a map for the backing store.
+  // Compiling R8 plus library shows classes with up to 30 methods account for about 95% of classes.
+  private static final int ARRAY_BACKING_THRESHOLD = 30;
+
   private final DexClass holder;
-  private final MethodCollectionBacking backing = new MethodArrayBacking();
+  private final MethodCollectionBacking backing;
   private DexEncodedMethod cachedClassInitializer = DexEncodedMethod.SENTINEL;
 
-  public MethodCollection(DexClass holder) {
+  public MethodCollection(
+      DexClass holder, DexEncodedMethod[] directMethods, DexEncodedMethod[] virtualMethods) {
     this.holder = holder;
+    if (directMethods.length + virtualMethods.length > ARRAY_BACKING_THRESHOLD) {
+      backing = new MethodMapBacking();
+    } else {
+      backing = new MethodArrayBacking();
+    }
+    backing.setDirectMethods(directMethods);
+    backing.setVirtualMethods(virtualMethods);
+  }
+
+  private void resetCaches() {
+    resetDirectMethodCaches();
+    resetVirtualMethodCaches();
   }
 
   private void resetDirectMethodCaches() {
@@ -31,11 +48,6 @@
     // Nothing to do.
   }
 
-  private void resetCaches() {
-    resetDirectMethodCaches();
-    resetVirtualMethodCaches();
-  }
-
   public int size() {
     return backing.size();
   }
diff --git a/src/main/java/com/android/tools/r8/graph/MethodCollectionBacking.java b/src/main/java/com/android/tools/r8/graph/MethodCollectionBacking.java
index 09de861..427f022 100644
--- a/src/main/java/com/android/tools/r8/graph/MethodCollectionBacking.java
+++ b/src/main/java/com/android/tools/r8/graph/MethodCollectionBacking.java
@@ -17,6 +17,16 @@
 
   abstract boolean verify();
 
+  boolean belongsToDirectPool(DexEncodedMethod method) {
+    return method.accessFlags.isStatic()
+        || method.accessFlags.isPrivate()
+        || method.accessFlags.isConstructor();
+  }
+
+  boolean belongsToVirtualPool(DexEncodedMethod method) {
+    return !belongsToDirectPool(method);
+  }
+
   // Collection methods.
 
   abstract int size();
diff --git a/src/main/java/com/android/tools/r8/graph/MethodMapBacking.java b/src/main/java/com/android/tools/r8/graph/MethodMapBacking.java
new file mode 100644
index 0000000..8afffc6
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/graph/MethodMapBacking.java
@@ -0,0 +1,306 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.graph;
+
+import com.android.tools.r8.utils.Box;
+import com.android.tools.r8.utils.MethodSignatureEquivalence;
+import com.android.tools.r8.utils.TraversalContinuation;
+import com.google.common.base.Equivalence.Wrapper;
+import it.unimi.dsi.fastutil.objects.Object2ReferenceLinkedOpenHashMap;
+import it.unimi.dsi.fastutil.objects.Object2ReferenceMap;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.function.Function;
+import java.util.function.Predicate;
+
+public class MethodMapBacking extends MethodCollectionBacking {
+
+  private Object2ReferenceMap<Wrapper<DexMethod>, DexEncodedMethod> methodMap;
+
+  public MethodMapBacking() {
+    this.methodMap = createMap();
+  }
+
+  private static Object2ReferenceMap<Wrapper<DexMethod>, DexEncodedMethod> createMap() {
+    // Maintain a linked map so the output order remains a deterministic function of the input.
+    return new Object2ReferenceLinkedOpenHashMap<>();
+  }
+
+  private static Object2ReferenceMap<Wrapper<DexMethod>, DexEncodedMethod> createMap(int capacity) {
+    // Maintain a linked map so the output order remains a deterministic function of the input.
+    return new Object2ReferenceLinkedOpenHashMap<>(capacity);
+  }
+
+  private Wrapper<DexMethod> wrap(DexMethod method) {
+    return MethodSignatureEquivalence.get().wrap(method);
+  }
+
+  private void replace(Wrapper<DexMethod> existingKey, DexEncodedMethod method) {
+    if (existingKey.get().match(method)) {
+      methodMap.put(existingKey, method);
+    } else {
+      methodMap.remove(existingKey);
+      methodMap.put(wrap(method.method), method);
+    }
+  }
+
+  private void rehash() {
+    Object2ReferenceMap<Wrapper<DexMethod>, DexEncodedMethod> newMap = createMap(methodMap.size());
+    for (DexEncodedMethod method : methodMap.values()) {
+      newMap.put(wrap(method.method), method);
+    }
+    methodMap = newMap;
+  }
+
+  @Override
+  boolean verify() {
+    methodMap.forEach(
+        (key, method) -> {
+          assert key.get().match(method);
+        });
+    return true;
+  }
+
+  @Override
+  int size() {
+    return methodMap.size();
+  }
+
+  @Override
+  TraversalContinuation traverse(Function<DexEncodedMethod, TraversalContinuation> fn) {
+    for (Entry<Wrapper<DexMethod>, DexEncodedMethod> entry : methodMap.object2ReferenceEntrySet()) {
+      TraversalContinuation result = fn.apply(entry.getValue());
+      if (result.shouldBreak()) {
+        return result;
+      }
+    }
+    return TraversalContinuation.CONTINUE;
+  }
+
+  @Override
+  Iterable<DexEncodedMethod> methods() {
+    return methodMap.values();
+  }
+
+  @Override
+  List<DexEncodedMethod> directMethods() {
+    List<DexEncodedMethod> methods = new ArrayList<>(size());
+    forEachMethod(
+        method -> {
+          if (belongsToDirectPool(method)) {
+            methods.add(method);
+          }
+        });
+    return methods;
+  }
+
+  @Override
+  List<DexEncodedMethod> virtualMethods() {
+    List<DexEncodedMethod> methods = new ArrayList<>(size());
+    forEachMethod(
+        method -> {
+          if (belongsToVirtualPool(method)) {
+            methods.add(method);
+          }
+        });
+    return methods;
+  }
+
+  @Override
+  DexEncodedMethod getMethod(DexMethod method) {
+    return methodMap.get(wrap(method));
+  }
+
+  private DexEncodedMethod getMethod(Predicate<DexEncodedMethod> predicate) {
+    Box<DexEncodedMethod> found = new Box<>();
+    traverse(
+        method -> {
+          if (predicate.test(method)) {
+            found.set(method);
+            return TraversalContinuation.BREAK;
+          }
+          return TraversalContinuation.CONTINUE;
+        });
+    return found.get();
+  }
+
+  @Override
+  DexEncodedMethod getDirectMethod(DexMethod method) {
+    DexEncodedMethod definition = getMethod(method);
+    return definition != null && belongsToDirectPool(definition) ? definition : null;
+  }
+
+  @Override
+  DexEncodedMethod getDirectMethod(Predicate<DexEncodedMethod> predicate) {
+    Predicate<DexEncodedMethod> isDirect = this::belongsToDirectPool;
+    return getMethod(isDirect.and(predicate));
+  }
+
+  @Override
+  DexEncodedMethod getVirtualMethod(DexMethod method) {
+    DexEncodedMethod definition = getMethod(method);
+    return definition != null && belongsToVirtualPool(definition) ? definition : null;
+  }
+
+  @Override
+  DexEncodedMethod getVirtualMethod(Predicate<DexEncodedMethod> predicate) {
+    Predicate<DexEncodedMethod> isVirtual = this::belongsToVirtualPool;
+    return getMethod(isVirtual.and(predicate));
+  }
+
+  @Override
+  void addMethod(DexEncodedMethod method) {
+    Wrapper<DexMethod> key = wrap(method.method);
+    DexEncodedMethod old = methodMap.put(key, method);
+    assert old == null;
+  }
+
+  @Override
+  void addDirectMethod(DexEncodedMethod method) {
+    assert belongsToDirectPool(method);
+    addMethod(method);
+  }
+
+  @Override
+  void addVirtualMethod(DexEncodedMethod method) {
+    assert belongsToVirtualPool(method);
+    addMethod(method);
+  }
+
+  @Override
+  void addDirectMethods(Collection<DexEncodedMethod> methods) {
+    for (DexEncodedMethod method : methods) {
+      addDirectMethod(method);
+    }
+  }
+
+  @Override
+  void addVirtualMethods(Collection<DexEncodedMethod> methods) {
+    for (DexEncodedMethod method : methods) {
+      addVirtualMethod(method);
+    }
+  }
+
+  @Override
+  void removeDirectMethod(DexMethod method) {
+    methodMap.remove(wrap(method));
+  }
+
+  @Override
+  void setDirectMethods(DexEncodedMethod[] methods) {
+    if ((methods == null || methods.length == 0) && methodMap.isEmpty()) {
+      return;
+    }
+    if (methods == null) {
+      methods = DexEncodedMethod.EMPTY_ARRAY;
+    }
+    Object2ReferenceMap<Wrapper<DexMethod>, DexEncodedMethod> newMap =
+        createMap(size() + methods.length);
+    forEachMethod(
+        method -> {
+          if (belongsToVirtualPool(method)) {
+            newMap.put(wrap(method.method), method);
+          }
+        });
+    for (DexEncodedMethod method : methods) {
+      assert belongsToDirectPool(method);
+      newMap.put(wrap(method.method), method);
+    }
+    methodMap = newMap;
+  }
+
+  @Override
+  void setVirtualMethods(DexEncodedMethod[] methods) {
+    if ((methods == null || methods.length == 0) && methodMap.isEmpty()) {
+      return;
+    }
+    if (methods == null) {
+      methods = DexEncodedMethod.EMPTY_ARRAY;
+    }
+    Object2ReferenceMap<Wrapper<DexMethod>, DexEncodedMethod> newMap =
+        createMap(size() + methods.length);
+    forEachMethod(
+        method -> {
+          if (belongsToDirectPool(method)) {
+            newMap.put(wrap(method.method), method);
+          }
+        });
+    for (DexEncodedMethod method : methods) {
+      assert belongsToVirtualPool(method);
+      newMap.put(wrap(method.method), method);
+    }
+    methodMap = newMap;
+  }
+
+  @Override
+  void replaceMethods(Function<DexEncodedMethod, DexEncodedMethod> replacement) {
+    boolean rehash = false;
+    for (Object2ReferenceMap.Entry<Wrapper<DexMethod>, DexEncodedMethod> entry :
+        methodMap.object2ReferenceEntrySet()) {
+      DexEncodedMethod newMethod = replacement.apply(entry.getValue());
+      if (newMethod != entry.getValue()) {
+        rehash = rehash || newMethod.method != entry.getKey().get();
+        entry.setValue(newMethod);
+      }
+    }
+    if (rehash) {
+      rehash();
+    }
+  }
+
+  @Override
+  void replaceDirectMethods(Function<DexEncodedMethod, DexEncodedMethod> replacement) {
+    replaceMethods(method -> belongsToDirectPool(method) ? replacement.apply(method) : method);
+  }
+
+  @Override
+  void replaceVirtualMethods(Function<DexEncodedMethod, DexEncodedMethod> replacement) {
+    replaceMethods(method -> belongsToVirtualPool(method) ? replacement.apply(method) : method);
+  }
+
+  @Override
+  DexEncodedMethod replaceDirectMethod(
+      DexMethod method, Function<DexEncodedMethod, DexEncodedMethod> replacement) {
+    Wrapper<DexMethod> key = wrap(method);
+    DexEncodedMethod existing = methodMap.get(key);
+    if (existing == null || belongsToVirtualPool(existing)) {
+      return null;
+    }
+    DexEncodedMethod newMethod = replacement.apply(existing);
+    assert belongsToDirectPool(newMethod);
+    replace(key, newMethod);
+    return newMethod;
+  }
+
+  @Override
+  DexEncodedMethod replaceDirectMethodWithVirtualMethod(
+      DexMethod method, Function<DexEncodedMethod, DexEncodedMethod> replacement) {
+    Wrapper<DexMethod> key = wrap(method);
+    DexEncodedMethod existing = methodMap.get(key);
+    if (existing == null || belongsToVirtualPool(existing)) {
+      return null;
+    }
+    DexEncodedMethod newMethod = replacement.apply(existing);
+    assert belongsToVirtualPool(newMethod);
+    replace(key, newMethod);
+    return newMethod;
+  }
+
+  @Override
+  void virtualizeMethods(Set<DexEncodedMethod> privateInstanceMethods) {
+    // This is a no-op as the virtualizer has modified the encoded method bits.
+    assert verifyVirtualizedMethods(privateInstanceMethods);
+  }
+
+  private boolean verifyVirtualizedMethods(Set<DexEncodedMethod> methods) {
+    for (DexEncodedMethod method : methods) {
+      assert belongsToVirtualPool(method);
+      assert methodMap.get(wrap(method.method)) == method;
+    }
+    return true;
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/TestCompileResult.java b/src/test/java/com/android/tools/r8/TestCompileResult.java
index 87cc650..7f777ed 100644
--- a/src/test/java/com/android/tools/r8/TestCompileResult.java
+++ b/src/test/java/com/android/tools/r8/TestCompileResult.java
@@ -208,10 +208,18 @@
     return self();
   }
 
+  public CR disableVerifer() {
+    assert getBackend() == Backend.CF;
+    if (!vmArguments.contains("-noverify")) {
+      vmArguments.add("-noverify");
+    }
+    return self();
+  }
+
   public CR enableRuntimeAssertions() {
     assert getBackend() == Backend.CF;
-    if (!this.vmArguments.contains("-ea")) {
-      this.vmArguments.add("-ea");
+    if (!vmArguments.contains("-ea")) {
+      vmArguments.add("-ea");
     }
     return self();
   }
diff --git a/src/test/java/com/android/tools/r8/inspection/InspectionApiTest.java b/src/test/java/com/android/tools/r8/inspection/InspectionApiTest.java
index 5c5a6e3..9f413d5 100644
--- a/src/test/java/com/android/tools/r8/inspection/InspectionApiTest.java
+++ b/src/test/java/com/android/tools/r8/inspection/InspectionApiTest.java
@@ -4,6 +4,7 @@
 package com.android.tools.r8.inspection;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
 
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
@@ -14,6 +15,9 @@
 import com.android.tools.r8.references.MethodReference;
 import com.android.tools.r8.references.Reference;
 import com.android.tools.r8.utils.StringUtils;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -49,7 +53,7 @@
         .apply(b -> b.getBuilder().addOutputInspection(this::inspection))
         .run(parameters.getRuntime(), TestClass.class)
         .assertSuccessWithOutput(EXPECTED);
-    assertFound();
+    assertFound(false);
   }
 
   @Test
@@ -61,37 +65,47 @@
         .apply(b -> b.getBuilder().addOutputInspection(this::inspection))
         .run(parameters.getRuntime(), TestClass.class)
         .assertSuccessWithOutput(EXPECTED);
-    assertFound();
+    assertFound(true);
   }
 
   ClassReference foundClass = null;
   FieldReference foundField = null;
-  MethodReference foundMethod = null;
+  Set<MethodReference> foundMethods = new HashSet<>();
 
   private void inspection(Inspector inspector) {
     inspector.forEachClass(
         classInspector -> {
+          assertNull(foundClass);
           foundClass = classInspector.getClassReference();
           classInspector.forEachField(
               fieldInspector -> {
+                assertNull(foundField);
                 foundField = fieldInspector.getFieldReference();
               });
           classInspector.forEachMethod(
               methodInspector -> {
-                // Ignore clinit (which is removed in R8).
-                if (!methodInspector.getMethodReference().getMethodName().equals("<clinit>")) {
-                  foundMethod = methodInspector.getMethodReference();
-                }
+                foundMethods.add(methodInspector.getMethodReference());
               });
         });
   }
 
-  private void assertFound() throws Exception {
+  private void assertFound(boolean isR8) throws Exception {
     assertEquals(Reference.classFromClass(TestClass.class), foundClass);
     assertEquals(Reference.fieldFromField(TestClass.class.getDeclaredField("foo")), foundField);
-    assertEquals(
-        Reference.methodFromMethod(TestClass.class.getDeclaredMethod("main", String[].class)),
-        foundMethod);
+
+    Set<MethodReference> expectedMethods = new HashSet<>();
+    expectedMethods.add(
+        Reference.methodFromMethod(TestClass.class.getDeclaredMethod("main", String[].class)));
+    expectedMethods.add(Reference.methodFromMethod(TestClass.class.getDeclaredConstructor()));
+    if (!isR8) {
+      expectedMethods.add(
+          Reference.method(
+              Reference.classFromClass(TestClass.class),
+              "<clinit>",
+              Collections.emptyList(),
+              null));
+    }
+    assertEquals(expectedMethods, foundMethods);
   }
 
   static class TestClass {