Cleanup VirtualRootMethodsAnalysis

Change-Id: I5af1efd08caa4d04be598c70907819ba426f5f0b
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/VirtualRootMethodsAnalysis.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/VirtualRootMethodsAnalysis.java
index 422b122..dd7eac12 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/VirtualRootMethodsAnalysis.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/codescanner/VirtualRootMethodsAnalysis.java
@@ -8,17 +8,16 @@
 
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexMethod;
-import com.android.tools.r8.graph.DexMethodSignature;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.optimize.argumentpropagation.ArgumentPropagatorCodeScanner;
 import com.android.tools.r8.optimize.argumentpropagation.utils.DepthFirstTopDownClassHierarchyTraversal;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.collections.DexMethodSignatureMap;
 import com.android.tools.r8.utils.collections.ProgramMethodSet;
 import com.google.common.collect.Sets;
 import java.util.Collection;
-import java.util.HashMap;
 import java.util.IdentityHashMap;
 import java.util.List;
 import java.util.Map;
@@ -45,13 +44,13 @@
     }
 
     VirtualRootMethod(ProgramMethod root, VirtualRootMethod parent) {
+      assert root != null;
       this.parent = parent;
       this.root = root;
     }
 
-    @SuppressWarnings("ReferenceEquality")
     void addOverride(ProgramMethod override) {
-      assert override != root;
+      assert override.getDefinition() != root.getDefinition();
       assert override.getMethodSignature().equals(root.getMethodSignature());
       overrides.add(override);
       if (hasParent()) {
@@ -103,7 +102,7 @@
     }
   }
 
-  private final Map<DexProgramClass, Map<DexMethodSignature, VirtualRootMethod>>
+  private final Map<DexProgramClass, DexMethodSignatureMap<VirtualRootMethod>>
       virtualRootMethodsPerClass = new IdentityHashMap<>();
 
   private final Set<DexMethod> monomorphicVirtualMethods = Sets.newIdentityHashSet();
@@ -138,17 +137,18 @@
 
   @Override
   public void visit(DexProgramClass clazz) {
-    Map<DexMethodSignature, VirtualRootMethod> state = computeVirtualRootMethodsState(clazz);
+    DexMethodSignatureMap<VirtualRootMethod> state = computeVirtualRootMethodsState(clazz);
     virtualRootMethodsPerClass.put(clazz, state);
   }
 
-  private Map<DexMethodSignature, VirtualRootMethod> computeVirtualRootMethodsState(
+  private DexMethodSignatureMap<VirtualRootMethod> computeVirtualRootMethodsState(
       DexProgramClass clazz) {
-    Map<DexMethodSignature, VirtualRootMethod> virtualRootMethodsForClass = new HashMap<>();
+    DexMethodSignatureMap<VirtualRootMethod> virtualRootMethodsForClass =
+        DexMethodSignatureMap.create();
     immediateSubtypingInfo.forEachImmediateProgramSuperClass(
         clazz,
         superclass -> {
-          Map<DexMethodSignature, VirtualRootMethod> virtualRootMethodsForSuperclass =
+          DexMethodSignatureMap<VirtualRootMethod> virtualRootMethodsForSuperclass =
               virtualRootMethodsPerClass.get(superclass);
           virtualRootMethodsForSuperclass.forEach(
               (signature, info) ->
@@ -157,11 +157,10 @@
         });
     clazz.forEachProgramVirtualMethod(
         method -> {
-          DexMethodSignature signature = method.getMethodSignature();
-          if (virtualRootMethodsForClass.containsKey(signature)) {
-            virtualRootMethodsForClass.get(signature).getParent().addOverride(method);
+          if (virtualRootMethodsForClass.containsKey(method)) {
+            virtualRootMethodsForClass.get(method).getParent().addOverride(method);
           } else {
-            virtualRootMethodsForClass.put(signature, new VirtualRootMethod(method));
+            virtualRootMethodsForClass.put(method, new VirtualRootMethod(method));
           }
         });
     return virtualRootMethodsForClass;
@@ -170,7 +169,7 @@
   @Override
   public void prune(DexProgramClass clazz) {
     // Record the overrides for each virtual method that is rooted at this class.
-    Map<DexMethodSignature, VirtualRootMethod> virtualRootMethodsForClass =
+    DexMethodSignatureMap<VirtualRootMethod> virtualRootMethodsForClass =
         virtualRootMethodsPerClass.remove(clazz);
     clazz.forEachProgramVirtualMethod(
         rootCandidate -> {
diff --git a/src/main/java/com/android/tools/r8/utils/collections/DexMethodSignatureMap.java b/src/main/java/com/android/tools/r8/utils/collections/DexMethodSignatureMap.java
index 8aa569a..65881a8 100644
--- a/src/main/java/com/android/tools/r8/utils/collections/DexMethodSignatureMap.java
+++ b/src/main/java/com/android/tools/r8/utils/collections/DexMethodSignatureMap.java
@@ -52,7 +52,11 @@
   }
 
   public T put(DexEncodedMethod method, T value) {
-    return put(method.getReference(), value);
+    return put(method.getSignature(), value);
+  }
+
+  public T put(DexClassAndMethod method, T value) {
+    return put(method.getMethodSignature(), value);
   }
 
   @Override
@@ -163,6 +167,10 @@
     return containsKey(method.getSignature());
   }
 
+  public boolean containsKey(DexClassAndMethod method) {
+    return containsKey(method.getMethodSignature());
+  }
+
   @Override
   public boolean containsKey(Object o) {
     return backing.containsKey(o);
@@ -182,6 +190,10 @@
     return get(method.getSignature());
   }
 
+  public T get(DexClassAndMethod method) {
+    return get(method.getMethodSignature());
+  }
+
   public boolean containsKey(DexMethodSignature signature) {
     return backing.containsKey(signature);
   }