Allow resolving method on class from method proto and name

Change-Id: Icc6a6bcf503123d995d91be514b25018b3a314bd
diff --git a/src/main/java/com/android/tools/r8/graph/AppInfoWithClassHierarchy.java b/src/main/java/com/android/tools/r8/graph/AppInfoWithClassHierarchy.java
index e6db3e1..30a1c04 100644
--- a/src/main/java/com/android/tools/r8/graph/AppInfoWithClassHierarchy.java
+++ b/src/main/java/com/android/tools/r8/graph/AppInfoWithClassHierarchy.java
@@ -688,15 +688,21 @@
   }
 
   public MethodResolutionResult resolveMethodOnClass(DexMethod method, DexClass clazz) {
+    return resolveMethodOnClass(clazz, method.getProto(), method.getName());
+  }
+
+  public MethodResolutionResult resolveMethodOnClass(
+      DexClass clazz, DexProto methodProto, DexString methodName) {
     assert checkIfObsolete();
     assert !clazz.isInterface();
     // Step 2:
-    MethodResolutionResult result = resolveMethodOnClassStep2(clazz, method, clazz);
+    MethodResolutionResult result =
+        resolveMethodOnClassStep2(clazz, methodProto, methodName, clazz);
     if (result != null) {
       return result;
     }
     // Finally Step 3:
-    return resolveMethodStep3(clazz, method);
+    return resolveMethodStep3(clazz, methodProto, methodName);
   }
 
   /**
@@ -705,16 +711,19 @@
    * 5.4.3.3 of the JVM Spec</a>.
    */
   private MethodResolutionResult resolveMethodOnClassStep2(
-      DexClass clazz, DexMethod method, DexClass initialResolutionHolder) {
+      DexClass clazz,
+      DexProto methodProto,
+      DexString methodName,
+      DexClass initialResolutionHolder) {
     // Pt. 1: Signature polymorphic method check.
     // See also <a href="https://docs.oracle.com/javase/specs/jvms/se8/html/jvms-2.html#jvms-2.9">
     // Section 2.9 of the JVM Spec</a>.
-    DexEncodedMethod result = clazz.lookupSignaturePolymorphicMethod(method.name, dexItemFactory());
+    DexEncodedMethod result = clazz.lookupSignaturePolymorphicMethod(methodName, dexItemFactory());
     if (result != null) {
       return new SingleResolutionResult(initialResolutionHolder, clazz, result);
     }
     // Pt 2: Find a method that matches the descriptor.
-    result = clazz.lookupMethod(method);
+    result = clazz.lookupMethod(methodProto, methodName);
     if (result != null) {
       // If the resolved method is private, then it can only be accessed if the symbolic reference
       // that initiated the resolution was the type at which the method resolved on. If that is not
@@ -730,7 +739,8 @@
     if (clazz.superType != null) {
       DexClass superClass = definitionFor(clazz.superType);
       if (superClass != null) {
-        return resolveMethodOnClassStep2(superClass, method, initialResolutionHolder);
+        return resolveMethodOnClassStep2(
+            superClass, methodProto, methodName, initialResolutionHolder);
       }
     }
     return null;
@@ -742,35 +752,45 @@
    * 5.4.3.3 of the JVM Spec</a>. As this is the same for interfaces and classes, we share one
    * implementation.
    */
-  private MethodResolutionResult resolveMethodStep3(DexClass clazz, DexMethod method) {
+  private MethodResolutionResult resolveMethodStep3(
+      DexClass clazz, DexProto methodProto, DexString methodName) {
     MaximallySpecificMethodsBuilder builder = new MaximallySpecificMethodsBuilder();
-    resolveMethodStep3Helper(method, clazz, builder);
+    resolveMethodStep3Helper(methodProto, methodName, clazz, builder);
     return builder.resolve(clazz);
   }
 
   // Non-private lookup (ie, not resolution) to find interface targets.
   DexClassAndMethod lookupMaximallySpecificTarget(DexClass clazz, DexMethod method) {
     MaximallySpecificMethodsBuilder builder = new MaximallySpecificMethodsBuilder();
-    resolveMethodStep3Helper(method, clazz, builder);
+    resolveMethodStep3Helper(method.getProto(), method.getName(), clazz, builder);
     return builder.lookup();
   }
 
   // Non-private lookup (ie, not resolution) to find interface targets.
   DexClassAndMethod lookupMaximallySpecificTarget(LambdaDescriptor lambda, DexMethod method) {
     MaximallySpecificMethodsBuilder builder = new MaximallySpecificMethodsBuilder();
-    resolveMethodStep3Helper(method, dexItemFactory().objectType, lambda.interfaces, builder);
+    resolveMethodStep3Helper(
+        method.getProto(),
+        method.getName(),
+        dexItemFactory().objectType,
+        lambda.interfaces,
+        builder);
     return builder.lookup();
   }
 
   /** Helper method that builds the set of maximally specific methods. */
   private void resolveMethodStep3Helper(
-      DexMethod method, DexClass clazz, MaximallySpecificMethodsBuilder builder) {
+      DexProto methodProto,
+      DexString methodName,
+      DexClass clazz,
+      MaximallySpecificMethodsBuilder builder) {
     resolveMethodStep3Helper(
-        method, clazz.superType, Arrays.asList(clazz.interfaces.values), builder);
+        methodProto, methodName, clazz.superType, Arrays.asList(clazz.interfaces.values), builder);
   }
 
   private void resolveMethodStep3Helper(
-      DexMethod method,
+      DexProto methodProto,
+      DexString methodName,
       DexType superType,
       List<DexType> interfaces,
       MaximallySpecificMethodsBuilder builder) {
@@ -781,20 +801,20 @@
         continue;
       }
       assert definition.isInterface();
-      DexEncodedMethod result = definition.lookupMethod(method);
+      DexEncodedMethod result = definition.lookupMethod(methodProto, methodName);
       if (isMaximallySpecificCandidate(result)) {
         // The candidate is added and doing so will prohibit shadowed methods from being in the set.
         builder.addCandidate(definition, result, this);
       } else {
         // Look at the super-interfaces of this class and keep searching.
-        resolveMethodStep3Helper(method, definition, builder);
+        resolveMethodStep3Helper(methodProto, methodName, definition, builder);
       }
     }
     // Now look at indirect super interfaces.
     if (superType != null) {
       DexClass superClass = definitionFor(superType);
       if (superClass != null) {
-        resolveMethodStep3Helper(method, superClass, builder);
+        resolveMethodStep3Helper(methodProto, methodName, superClass, builder);
       }
     }
   }
@@ -860,7 +880,7 @@
     }
     // Step 3: Look for maximally-specific superinterface methods or any interface definition.
     //         This is the same for classes and interfaces.
-    return resolveMethodStep3(definition, desc);
+    return resolveMethodStep3(definition, desc.getProto(), desc.getName());
   }
 
   /**
diff --git a/src/main/java/com/android/tools/r8/graph/DexClass.java b/src/main/java/com/android/tools/r8/graph/DexClass.java
index e4d00aa..d3a63f6 100644
--- a/src/main/java/com/android/tools/r8/graph/DexClass.java
+++ b/src/main/java/com/android/tools/r8/graph/DexClass.java
@@ -588,6 +588,10 @@
     return methodCollection.getMethod(method);
   }
 
+  public DexEncodedMethod lookupMethod(DexProto methodProto, DexString methodName) {
+    return methodCollection.getMethod(methodProto, methodName);
+  }
+
   /** Find method in this class matching {@param method}. */
   public DexEncodedMethod lookupMethod(Predicate<DexEncodedMethod> predicate) {
     return methodCollection.getMethod(predicate);
diff --git a/src/main/java/com/android/tools/r8/graph/DexMethod.java b/src/main/java/com/android/tools/r8/graph/DexMethod.java
index 5a109e7..c5cb0ff 100644
--- a/src/main/java/com/android/tools/r8/graph/DexMethod.java
+++ b/src/main/java/com/android/tools/r8/graph/DexMethod.java
@@ -228,7 +228,15 @@
 
   @Override
   public boolean match(DexMethod method) {
-    return method.name == name && method.proto == proto;
+    return match(method.getProto(), method.getName());
+  }
+
+  public boolean match(DexMethodSignature method) {
+    return match(method.getProto(), method.getName());
+  }
+
+  public boolean match(DexProto methodProto, DexString methodName) {
+    return proto == methodProto && name == methodName;
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/graph/DexMethodSignature.java b/src/main/java/com/android/tools/r8/graph/DexMethodSignature.java
index 8b2fcc4..74d464f 100644
--- a/src/main/java/com/android/tools/r8/graph/DexMethodSignature.java
+++ b/src/main/java/com/android/tools/r8/graph/DexMethodSignature.java
@@ -4,9 +4,13 @@
 
 package com.android.tools.r8.graph;
 
+import com.android.tools.r8.utils.structural.StructuralItem;
+import com.android.tools.r8.utils.structural.StructuralMapping;
+import com.android.tools.r8.utils.structural.StructuralSpecification;
 import java.util.Objects;
 
-public class DexMethodSignature {
+public class DexMethodSignature implements StructuralItem<DexMethodSignature> {
+
   private final DexProto proto;
   private final DexString name;
 
@@ -21,14 +25,31 @@
     this.name = name;
   }
 
-  public DexProto getProto() {
-    return proto;
+  public int getArity() {
+    return proto.getArity();
   }
 
   public DexString getName() {
     return name;
   }
 
+  public DexProto getProto() {
+    return proto;
+  }
+
+  public DexType getReturnType() {
+    return proto.returnType;
+  }
+
+  @Override
+  public StructuralMapping<DexMethodSignature> getStructuralMapping() {
+    return DexMethodSignature::specify;
+  }
+
+  private static void specify(StructuralSpecification<DexMethodSignature, ?> spec) {
+    spec.withItem(DexMethodSignature::getName).withItem(DexMethodSignature::getProto);
+  }
+
   public DexMethodSignature withName(DexString name) {
     return new DexMethodSignature(proto, name);
   }
@@ -58,12 +79,9 @@
     return Objects.hash(proto, name);
   }
 
-  public DexType getReturnType() {
-    return proto.returnType;
-  }
-
-  public int getArity() {
-    return proto.getArity();
+  @Override
+  public DexMethodSignature self() {
+    return this;
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/graph/MethodArrayBacking.java b/src/main/java/com/android/tools/r8/graph/MethodArrayBacking.java
index 7e67c3b..6930912 100644
--- a/src/main/java/com/android/tools/r8/graph/MethodArrayBacking.java
+++ b/src/main/java/com/android/tools/r8/graph/MethodArrayBacking.java
@@ -249,9 +249,21 @@
   }
 
   @Override
-  DexEncodedMethod getMethod(DexMethod method) {
-    DexEncodedMethod result = getDirectMethod(method);
-    return result == null ? getVirtualMethod(method) : result;
+  DexEncodedMethod getMethod(DexProto methodProto, DexString methodName) {
+    DexEncodedMethod directMethod = internalGetMethod(methodProto, methodName, directMethods);
+    return directMethod == null
+        ? internalGetMethod(methodProto, methodName, virtualMethods)
+        : directMethod;
+  }
+
+  private static DexEncodedMethod internalGetMethod(
+      DexProto methodProto, DexString methodName, DexEncodedMethod[] methods) {
+    for (DexEncodedMethod method : methods) {
+      if (method.getReference().match(methodProto, methodName)) {
+        return method;
+      }
+    }
+    return null;
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/graph/MethodCollection.java b/src/main/java/com/android/tools/r8/graph/MethodCollection.java
index 3509b8c..3dc2b58 100644
--- a/src/main/java/com/android/tools/r8/graph/MethodCollection.java
+++ b/src/main/java/com/android/tools/r8/graph/MethodCollection.java
@@ -160,7 +160,11 @@
   }
 
   public DexEncodedMethod getMethod(DexMethod method) {
-    return backing.getMethod(method);
+    return backing.getMethod(method.getProto(), method.getName());
+  }
+
+  public DexEncodedMethod getMethod(DexProto methodProto, DexString methodName) {
+    return backing.getMethod(methodProto, methodName);
   }
 
   public DexEncodedMethod getMethod(Predicate<DexEncodedMethod> predicate) {
diff --git a/src/main/java/com/android/tools/r8/graph/MethodCollectionBacking.java b/src/main/java/com/android/tools/r8/graph/MethodCollectionBacking.java
index 301352a..90fbf1f 100644
--- a/src/main/java/com/android/tools/r8/graph/MethodCollectionBacking.java
+++ b/src/main/java/com/android/tools/r8/graph/MethodCollectionBacking.java
@@ -68,7 +68,7 @@
 
   // Lookup methods.
 
-  abstract DexEncodedMethod getMethod(DexMethod method);
+  abstract DexEncodedMethod getMethod(DexProto methodProto, DexString methodName);
 
   abstract DexEncodedMethod getDirectMethod(DexMethod method);
 
diff --git a/src/main/java/com/android/tools/r8/graph/MethodMapBacking.java b/src/main/java/com/android/tools/r8/graph/MethodMapBacking.java
index a7d3c45..8f705d4 100644
--- a/src/main/java/com/android/tools/r8/graph/MethodMapBacking.java
+++ b/src/main/java/com/android/tools/r8/graph/MethodMapBacking.java
@@ -5,67 +5,59 @@
 
 import com.android.tools.r8.utils.Box;
 import com.android.tools.r8.utils.IteratorUtils;
-import com.android.tools.r8.utils.MethodSignatureEquivalence;
 import com.android.tools.r8.utils.TraversalContinuation;
-import com.google.common.base.Equivalence.Wrapper;
 import com.google.common.collect.Lists;
-import it.unimi.dsi.fastutil.objects.Object2ReferenceLinkedOpenHashMap;
-import it.unimi.dsi.fastutil.objects.Object2ReferenceMap;
-import it.unimi.dsi.fastutil.objects.Object2ReferenceRBTreeMap;
 import java.util.ArrayList;
 import java.util.Collection;
-import java.util.Comparator;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Set;
+import java.util.TreeMap;
 import java.util.function.Function;
 import java.util.function.Predicate;
 
 public class MethodMapBacking extends MethodCollectionBacking {
 
-  private Object2ReferenceMap<Wrapper<DexMethod>, DexEncodedMethod> methodMap;
+  private Map<DexMethodSignature, DexEncodedMethod> methodMap;
 
   public MethodMapBacking() {
     this(createMap());
   }
 
-  private MethodMapBacking(Object2ReferenceMap<Wrapper<DexMethod>, DexEncodedMethod> methodMap) {
+  private MethodMapBacking(Map<DexMethodSignature, DexEncodedMethod> methodMap) {
     this.methodMap = methodMap;
   }
 
   public static MethodMapBacking createSorted() {
-    Comparator<Wrapper<DexMethod>> comparator = Comparator.comparing(Wrapper::get);
-    return new MethodMapBacking(new Object2ReferenceRBTreeMap<>(comparator));
+    return new MethodMapBacking(new TreeMap<>());
   }
 
-  private static Object2ReferenceMap<Wrapper<DexMethod>, DexEncodedMethod> createMap() {
+  private static Map<DexMethodSignature, DexEncodedMethod> createMap() {
     // Maintain a linked map so the output order remains a deterministic function of the input.
-    return new Object2ReferenceLinkedOpenHashMap<>();
+    return new HashMap<>();
   }
 
-  private static Object2ReferenceMap<Wrapper<DexMethod>, DexEncodedMethod> createMap(int capacity) {
+  private static Map<DexMethodSignature, DexEncodedMethod> createMap(int capacity) {
     // Maintain a linked map so the output order remains a deterministic function of the input.
-    return new Object2ReferenceLinkedOpenHashMap<>(capacity);
+    return new HashMap<>(capacity);
   }
 
-  private Wrapper<DexMethod> wrap(DexMethod method) {
-    return MethodSignatureEquivalence.get().wrap(method);
-  }
-
-  private void replace(Wrapper<DexMethod> existingKey, DexEncodedMethod method) {
-    if (existingKey.get().match(method)) {
+  private void replace(DexMethodSignature existingKey, DexEncodedMethod method) {
+    if (method.getReference().match(existingKey)) {
       methodMap.put(existingKey, method);
     } else {
       methodMap.remove(existingKey);
-      methodMap.put(wrap(method.getReference()), method);
+      methodMap.put(method.getSignature(), method);
     }
   }
 
   @Override
   boolean verify() {
     methodMap.forEach(
-        (key, method) -> {
-          assert key.get().match(method);
+        (signature, method) -> {
+          assert method.getReference().match(signature);
         });
     return true;
   }
@@ -97,7 +89,7 @@
 
   @Override
   TraversalContinuation traverse(Function<DexEncodedMethod, TraversalContinuation> fn) {
-    for (Entry<Wrapper<DexMethod>, DexEncodedMethod> entry : methodMap.object2ReferenceEntrySet()) {
+    for (Entry<DexMethodSignature, DexEncodedMethod> entry : methodMap.entrySet()) {
       TraversalContinuation result = fn.apply(entry.getValue());
       if (result.shouldBreak()) {
         return result;
@@ -122,8 +114,8 @@
   }
 
   @Override
-  DexEncodedMethod getMethod(DexMethod method) {
-    return methodMap.get(wrap(method));
+  DexEncodedMethod getMethod(DexProto methodProto, DexString methodName) {
+    return methodMap.get(new DexMethodSignature(methodProto, methodName));
   }
 
   private DexEncodedMethod getMethod(Predicate<DexEncodedMethod> predicate) {
@@ -141,7 +133,7 @@
 
   @Override
   DexEncodedMethod getDirectMethod(DexMethod method) {
-    DexEncodedMethod definition = getMethod(method);
+    DexEncodedMethod definition = getMethod(method.getProto(), method.getName());
     return definition != null && belongsToDirectPool(definition) ? definition : null;
   }
 
@@ -153,7 +145,7 @@
 
   @Override
   DexEncodedMethod getVirtualMethod(DexMethod method) {
-    DexEncodedMethod definition = getMethod(method);
+    DexEncodedMethod definition = getMethod(method.getProto(), method.getName());
     return definition != null && belongsToVirtualPool(definition) ? definition : null;
   }
 
@@ -165,7 +157,7 @@
 
   @Override
   void addMethod(DexEncodedMethod method) {
-    Wrapper<DexMethod> key = wrap(method.getReference());
+    DexMethodSignature key = method.getSignature();
     DexEncodedMethod old = methodMap.put(key, method);
     assert old == null;
   }
@@ -208,12 +200,12 @@
 
   @Override
   DexEncodedMethod removeMethod(DexMethod method) {
-    return methodMap.remove(wrap(method));
+    return methodMap.remove(method.getSignature());
   }
 
   @Override
   void removeMethods(Set<DexEncodedMethod> methods) {
-    methods.forEach(method -> methodMap.remove(wrap(method.getReference())));
+    methods.forEach(method -> methodMap.remove(method.getSignature()));
   }
 
   @Override
@@ -224,17 +216,16 @@
     if (methods == null) {
       methods = DexEncodedMethod.EMPTY_ARRAY;
     }
-    Object2ReferenceMap<Wrapper<DexMethod>, DexEncodedMethod> newMap =
-        createMap(size() + methods.length);
+    Map<DexMethodSignature, DexEncodedMethod> newMap = createMap(size() + methods.length);
     forEachMethod(
         method -> {
           if (belongsToVirtualPool(method)) {
-            newMap.put(wrap(method.getReference()), method);
+            newMap.put(method.getSignature(), method);
           }
         });
     for (DexEncodedMethod method : methods) {
       assert belongsToDirectPool(method);
-      newMap.put(wrap(method.getReference()), method);
+      newMap.put(method.getSignature(), method);
     }
     methodMap = newMap;
   }
@@ -247,17 +238,16 @@
     if (methods == null) {
       methods = DexEncodedMethod.EMPTY_ARRAY;
     }
-    Object2ReferenceMap<Wrapper<DexMethod>, DexEncodedMethod> newMap =
-        createMap(size() + methods.length);
+    Map<DexMethodSignature, DexEncodedMethod> newMap = createMap(size() + methods.length);
     forEachMethod(
         method -> {
           if (belongsToDirectPool(method)) {
-            newMap.put(wrap(method.getReference()), method);
+            newMap.put(method.getSignature(), method);
           }
         });
     for (DexEncodedMethod method : methods) {
       assert belongsToVirtualPool(method);
-      newMap.put(wrap(method.getReference()), method);
+      newMap.put(method.getSignature(), method);
     }
     methodMap = newMap;
   }
@@ -325,7 +315,7 @@
       DexMethod method,
       Function<DexEncodedMethod, DexEncodedMethod> replacement,
       Predicate<DexEncodedMethod> predicate) {
-    Wrapper<DexMethod> key = wrap(method);
+    DexMethodSignature key = method.getSignature();
     DexEncodedMethod existing = methodMap.get(key);
     if (existing == null || !predicate.test(existing)) {
       return null;
@@ -339,7 +329,7 @@
   @Override
   DexEncodedMethod replaceDirectMethodWithVirtualMethod(
       DexMethod method, Function<DexEncodedMethod, DexEncodedMethod> replacement) {
-    Wrapper<DexMethod> key = wrap(method);
+    DexMethodSignature key = method.getSignature();
     DexEncodedMethod existing = methodMap.get(key);
     if (existing == null || belongsToVirtualPool(existing)) {
       return null;
@@ -359,7 +349,7 @@
   private boolean verifyVirtualizedMethods(Set<DexEncodedMethod> methods) {
     for (DexEncodedMethod method : methods) {
       assert belongsToVirtualPool(method);
-      assert methodMap.get(wrap(method.getReference())) == method;
+      assert methodMap.get(method.getSignature()) == method;
     }
     return true;
   }