Improve performance of collision detection

The collision detection works by collecting all the targeted method signatures in the program, and then checking before each merge operation, if that merge operation could lead to a collision, by considering all the signatures that have been collected.

The current implementation only considers the targeted method signatures that references at least one program type. This CL adds two optimizations: it considers only those method signatures that references at least one merge or mergee candidate, and it also only considers methods that are potentially overloaded.

On GMS Core, this reduces the running time of the collision detector from approx. 25 seconds to 2 seconds.

Change-Id: I8395fa01b61b98c937fe53d2a09b3b13f81b14f2
diff --git a/src/main/java/com/android/tools/r8/graph/DexProto.java b/src/main/java/com/android/tools/r8/graph/DexProto.java
index 17f6efa..5c05605 100644
--- a/src/main/java/com/android/tools/r8/graph/DexProto.java
+++ b/src/main/java/com/android/tools/r8/graph/DexProto.java
@@ -8,6 +8,8 @@
 
 public class DexProto extends IndexedDexItem implements PresortedComparable<DexProto> {
 
+  public static final DexProto SENTINEL = new DexProto(null, null, null);
+
   public final DexString shorty;
   public final DexType returnType;
   public final DexTypeList parameters;
diff --git a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
index 3dc03f8..9acb95c 100644
--- a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
@@ -39,6 +39,7 @@
 import com.google.common.collect.ImmutableSet;
 import it.unimi.dsi.fastutil.ints.Int2IntMap;
 import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
+import it.unimi.dsi.fastutil.objects.Reference2BooleanOpenHashMap;
 import it.unimi.dsi.fastutil.objects.Reference2IntMap;
 import it.unimi.dsi.fastutil.objects.Reference2IntOpenHashMap;
 import java.util.ArrayDeque;
@@ -55,7 +56,6 @@
 import java.util.Map;
 import java.util.Set;
 import java.util.function.Predicate;
-import java.util.stream.Collectors;
 
 /**
  * Merges Supertypes with a single implementation into their single subtype.
@@ -143,12 +143,18 @@
   private final Timing timing;
   private Collection<DexMethod> invokes;
 
+  // Set of merge candidates.
+  private final Set<DexProgramClass> mergeCandidates = new HashSet<>();
+
   // Map from source class to target class.
   private final Map<DexType, DexType> mergedClasses = new HashMap<>();
 
   // Map from target class to the super classes that have been merged into the target class.
   private final Map<DexType, Set<DexType>> mergedClassesInverse = new HashMap<>();
 
+  // Set of types that must not be merged into their subtype.
+  private final Set<DexType> pinnedTypes = new HashSet<>();
+
   // The resulting graph lense that should be used after class merging.
   private final VerticalClassMergerGraphLense.Builder renamedMembersLense;
 
@@ -162,26 +168,36 @@
     this.graphLense = graphLense;
     this.renamedMembersLense = VerticalClassMergerGraphLense.builder(appInfo);
     this.timing = timing;
+
+    Iterable<DexProgramClass> classes = application.classesWithDeterministicOrder();
+    initializePinnedTypes(classes); // Must be initialized prior to mergeCandidates.
+    initializeMergeCandidates(classes);
+  }
+
+  private void initializeMergeCandidates(Iterable<DexProgramClass> classes) {
+    for (DexProgramClass clazz : classes) {
+      if (isMergeCandidate(clazz, pinnedTypes) && isStillMergeCandidate(clazz)) {
+        mergeCandidates.add(clazz);
+      }
+    }
   }
 
   // Returns a set of types that must not be merged into other types.
-  private Set<DexType> getPinnedTypes(Iterable<DexProgramClass> classes) {
-    Set<DexType> pinnedTypes = new HashSet<>();
-
+  private void initializePinnedTypes(Iterable<DexProgramClass> classes) {
     // For all pinned fields, also pin the type of the field (because changing the type of the field
     // implicitly changes the signature of the field). Similarly, for all pinned methods, also pin
     // the return type and the parameter types of the method.
-    extractPinnedItems(appInfo.pinnedItems, pinnedTypes, AbortReason.PINNED_SOURCE);
+    extractPinnedItems(appInfo.pinnedItems, AbortReason.PINNED_SOURCE);
 
     // TODO(christofferqa): Remove the invariant that the graph lense should not modify any
     // methods from the sets alwaysInline and noSideEffects (see use of assertNotModified).
-    extractPinnedItems(appInfo.alwaysInline, pinnedTypes, AbortReason.ALWAYS_INLINE);
-    extractPinnedItems(appInfo.noSideEffects.keySet(), pinnedTypes, AbortReason.NO_SIDE_EFFECTS);
+    extractPinnedItems(appInfo.alwaysInline, AbortReason.ALWAYS_INLINE);
+    extractPinnedItems(appInfo.noSideEffects.keySet(), AbortReason.NO_SIDE_EFFECTS);
 
     for (DexProgramClass clazz : classes) {
       for (DexEncodedMethod method : clazz.methods()) {
         if (method.accessFlags.isNative()) {
-          markTypeAsPinned(clazz.type, pinnedTypes, AbortReason.NATIVE_METHOD);
+          markTypeAsPinned(clazz.type, AbortReason.NATIVE_METHOD);
         }
       }
     }
@@ -206,37 +222,35 @@
         pinnedTypes.add(signature.holder);
       }
     }
-    return pinnedTypes;
   }
 
-  private void extractPinnedItems(
-      Iterable<DexItem> items, Set<DexType> pinnedTypes, AbortReason reason) {
+  private void extractPinnedItems(Iterable<DexItem> items, AbortReason reason) {
     for (DexItem item : items) {
       if (item instanceof DexType || item instanceof DexClass) {
         DexType type = item instanceof DexType ? (DexType) item : ((DexClass) item).type;
-        markTypeAsPinned(type, pinnedTypes, reason);
+        markTypeAsPinned(type, reason);
       } else if (item instanceof DexField || item instanceof DexEncodedField) {
         // Pin the holder and the type of the field.
         DexField field =
             item instanceof DexField ? (DexField) item : ((DexEncodedField) item).field;
-        markTypeAsPinned(field.clazz, pinnedTypes, reason);
-        markTypeAsPinned(field.type, pinnedTypes, reason);
+        markTypeAsPinned(field.clazz, reason);
+        markTypeAsPinned(field.type, reason);
       } else if (item instanceof DexMethod || item instanceof DexEncodedMethod) {
         // Pin the holder, the return type and the parameter types of the method. If we were to
         // merge any of these types into their sub classes, then we would implicitly change the
         // signature of this method.
         DexMethod method =
             item instanceof DexMethod ? (DexMethod) item : ((DexEncodedMethod) item).method;
-        markTypeAsPinned(method.holder, pinnedTypes, reason);
-        markTypeAsPinned(method.proto.returnType, pinnedTypes, reason);
+        markTypeAsPinned(method.holder, reason);
+        markTypeAsPinned(method.proto.returnType, reason);
         for (DexType parameterType : method.proto.parameters.values) {
-          markTypeAsPinned(parameterType, pinnedTypes, reason);
+          markTypeAsPinned(parameterType, reason);
         }
       }
     }
   }
 
-  private void markTypeAsPinned(DexType type, Set<DexType> pinnedTypes, AbortReason reason) {
+  private void markTypeAsPinned(DexType type, AbortReason reason) {
     if (appInfo.isPinned(type)) {
       // We check for the case where the type is pinned according to appInfo.isPinned,
       // so we only need to add it here if it is not the case.
@@ -255,6 +269,8 @@
     }
   }
 
+  // Returns true if [clazz] is a merge candidate. Note that the result of the checks in this
+  // method do not change in response to any class merges.
   private boolean isMergeCandidate(DexProgramClass clazz, Set<DexType> pinnedTypes) {
     if (appInfo.instantiatedTypes.contains(clazz.type)
         || appInfo.instantiatedLambdas.contains(clazz.type)
@@ -262,14 +278,8 @@
         || pinnedTypes.contains(clazz.type)) {
       return false;
     }
-    if (mergedClassesInverse.containsKey(clazz.type)) {
-      // Do not allow merging the resulting class into its subclass.
-      // TODO(christofferqa): Get rid of this limitation.
-      if (Log.ENABLED) {
-        AbortReason.ALREADY_MERGED.printLogMessageForClass(clazz);
-      }
-      return false;
-    }
+    // Note that the property "singleSubtype == null" cannot change during merging, since we visit
+    // classes in a top-down order.
     DexType singleSubtype = clazz.type.getSingleSubtype();
     if (singleSubtype == null) {
       // TODO(christofferqa): Even if [clazz] has multiple subtypes, we could still merge it into
@@ -286,6 +296,7 @@
       if (appInfo.isPinned(method.method)) {
         return false;
       }
+      // TODO(christofferqa): We should always be able to force inline initializers.
       if (method.isInstanceInitializer() && disallowInlining(method)) {
         // Cannot guarantee that markForceInline() will work.
         if (Log.ENABLED) {
@@ -294,13 +305,64 @@
         return false;
       }
     }
-    DexClass targetClass = appInfo.definitionFor(singleSubtype);
+    if (clazz.getEnclosingMethod() != null || !clazz.getInnerClasses().isEmpty()) {
+      // TODO(herhut): Consider supporting merging of enclosing-method and inner-class attributes.
+      if (Log.ENABLED) {
+        AbortReason.UNSUPPORTED_ATTRIBUTES.printLogMessageForClass(clazz);
+      }
+      return false;
+    }
+    return true;
+  }
+
+  // Returns true if [clazz] is a merge candidate. Note that the result of the checks in this
+  // method may change in response to class merges. Therefore, this method should always be called
+  // before merging [clazz] into its subtype.
+  private boolean isStillMergeCandidate(DexProgramClass clazz) {
+    assert isMergeCandidate(clazz, pinnedTypes);
+    if (mergedClassesInverse.containsKey(clazz.type)) {
+      // Do not allow merging the resulting class into its subclass.
+      // TODO(christofferqa): Get rid of this limitation.
+      if (Log.ENABLED) {
+        AbortReason.ALREADY_MERGED.printLogMessageForClass(clazz);
+      }
+      return false;
+    }
+    DexClass targetClass = appInfo.definitionFor(clazz.type.getSingleSubtype());
+    if (clazz.hasClassInitializer() && targetClass.hasClassInitializer()) {
+      // TODO(herhut): Handle class initializers.
+      if (Log.ENABLED) {
+        AbortReason.STATIC_INITIALIZERS.printLogMessageForClass(clazz);
+      }
+      return false;
+    }
+    if (targetClass.getEnclosingMethod() != null || !targetClass.getInnerClasses().isEmpty()) {
+      // TODO(herhut): Consider supporting merging of enclosing-method and inner-class attributes.
+      if (Log.ENABLED) {
+        AbortReason.UNSUPPORTED_ATTRIBUTES.printLogMessageForClass(clazz);
+      }
+      return false;
+    }
     if (mergeMayLeadToIllegalAccesses(clazz, targetClass)) {
       if (Log.ENABLED) {
         AbortReason.ILLEGAL_ACCESS.printLogMessageForClass(clazz);
       }
       return false;
     }
+    if (methodResolutionMayChange(clazz, targetClass)) {
+      if (Log.ENABLED) {
+        AbortReason.RESOLUTION_FOR_METHODS_MAY_CHANGE.printLogMessageForClass(clazz);
+      }
+      return false;
+    }
+    // Field resolution first considers the direct interfaces of [targetClass] before it proceeds
+    // to the super class.
+    if (fieldResolutionMayChange(clazz, targetClass)) {
+      if (Log.ENABLED) {
+        AbortReason.RESOLUTION_FOR_FIELDS_MAY_CHANGE.printLogMessageForClass(clazz);
+      }
+      return false;
+    }
     return true;
   }
 
@@ -349,49 +411,101 @@
     return false;
   }
 
-  private void addProgramMethods(Set<Wrapper<DexMethod>> set, DexMethod method,
-      Equivalence<DexMethod> equivalence) {
-    DexClass definition = appInfo.definitionFor(method.holder);
-    if (definition != null && definition.isProgramClass()) {
-      set.add(equivalence.wrap(method));
-    }
-  }
-
   private Collection<DexMethod> getInvokes() {
     if (invokes == null) {
-      // Collect all reachable methods that are not within a library class. Those defined on
-      // library classes are known not to have program classes in their signature.
-      // Also filter methods that only use types from library classes in their signatures. We
-      // know that those won't conflict.
-      Set<Wrapper<DexMethod>> filteredInvokes = new HashSet<>();
-      Equivalence<DexMethod> equivalence = MethodSignatureEquivalence.get();
-      appInfo.targetedMethods.forEach(m -> addProgramMethods(filteredInvokes, m, equivalence));
-      invokes = filteredInvokes.stream().map(Wrapper::get).filter(this::removeNonProgram)
-          .collect(Collectors.toList());
+      invokes = new OverloadedMethodSignaturesRetriever().get();
     }
     return invokes;
   }
 
-  private boolean isProgramClass(DexType type) {
-    if (type.isArrayType()) {
-      type = type.toBaseType(appInfo.dexItemFactory);
-    }
-    if (type.isClassType()) {
-      DexClass clazz = appInfo.definitionFor(type);
-      if (clazz != null && clazz.isProgramClass()) {
-        return true;
-      }
-    }
-    return false;
-  }
+  // Collects all potentially overloaded method signatures that reference at least one type that
+  // may be the source or target of a merge operation.
+  private class OverloadedMethodSignaturesRetriever {
+    private final Reference2BooleanOpenHashMap<DexProto> cache =
+        new Reference2BooleanOpenHashMap<>();
+    private final Equivalence<DexMethod> equivalence = MethodSignatureEquivalence.get();
+    private final Set<DexType> mergeeCandidates = new HashSet<>();
 
-  private boolean removeNonProgram(DexMethod dexMethod) {
-    for (DexType type : dexMethod.proto.parameters.values) {
-      if (isProgramClass(type)) {
-        return true;
+    public OverloadedMethodSignaturesRetriever() {
+      for (DexProgramClass mergeCandidate : mergeCandidates) {
+        mergeeCandidates.add(mergeCandidate.type.getSingleSubtype());
       }
     }
-    return isProgramClass(dexMethod.proto.returnType);
+
+    public Collection<DexMethod> get() {
+      Map<DexString, DexProto> overloadingInfo = new HashMap<>();
+
+      // Find all signatures that may reference a type that could be the source or target of a
+      // merge operation.
+      Set<Wrapper<DexMethod>> filteredSignatures = new HashSet<>();
+      for (DexMethod signature : appInfo.targetedMethods) {
+        DexClass definition = appInfo.definitionFor(signature.holder);
+        if (definition != null
+            && definition.isProgramClass()
+            && protoMayReferenceMergedSourceOrTarget(signature.proto)) {
+          filteredSignatures.add(equivalence.wrap(signature));
+
+          // Record that we have seen a method named [signature.name] with the proto
+          // [signature.proto]. If at some point, we find a method with the same name, but a
+          // different proto, it could be the case that a method with the given name is overloaded.
+          DexProto existing =
+              overloadingInfo.computeIfAbsent(signature.name, key -> signature.proto);
+          if (!existing.equals(signature.proto)) {
+            // Mark that this signature is overloaded by mapping it to SENTINEL.
+            overloadingInfo.put(signature.name, DexProto.SENTINEL);
+          }
+        }
+      }
+
+      List<DexMethod> result = new ArrayList<>();
+      for (Wrapper<DexMethod> wrappedSignature : filteredSignatures) {
+        DexMethod signature = wrappedSignature.get();
+
+        // Ignore those method names that are definitely not overloaded since they cannot lead to
+        // any collisions.
+        if (overloadingInfo.get(signature.name) == DexProto.SENTINEL) {
+          result.add(signature);
+        }
+      }
+      return result;
+    }
+
+    private boolean protoMayReferenceMergedSourceOrTarget(DexProto proto) {
+      boolean result;
+      if (cache.containsKey(proto)) {
+        result = cache.getBoolean(proto);
+      } else {
+        result = false;
+        if (typeMayReferenceMergedSourceOrTarget(proto.returnType)) {
+          result = true;
+        } else {
+          for (DexType type : proto.parameters.values) {
+            if (typeMayReferenceMergedSourceOrTarget(type)) {
+              result = true;
+              break;
+            }
+          }
+        }
+        cache.put(proto, result);
+      }
+      return result;
+    }
+
+    private boolean typeMayReferenceMergedSourceOrTarget(DexType type) {
+      if (type.isArrayType()) {
+        type = type.toBaseType(appInfo.dexItemFactory);
+      }
+      if (type.isClassType()) {
+        if (mergeeCandidates.contains(type)) {
+          return true;
+        }
+        DexClass clazz = appInfo.definitionFor(type);
+        if (clazz != null && clazz.isProgramClass()) {
+          return mergeCandidates.contains(clazz.asProgramClass());
+        }
+      }
+      return false;
+    }
   }
 
   public GraphLense run() {
@@ -408,19 +522,21 @@
     return result;
   }
 
-  private void addAncestorsToWorklist(
+  private void addCandidateAncestorsToWorklist(
       DexProgramClass clazz, Deque<DexProgramClass> worklist, Set<DexProgramClass> seenBefore) {
     if (seenBefore.contains(clazz)) {
       return;
     }
 
-    worklist.addFirst(clazz);
+    if (mergeCandidates.contains(clazz)) {
+      worklist.addFirst(clazz);
+    }
 
     // Add super classes to worklist.
     if (clazz.superType != null) {
       DexClass definition = appInfo.definitionFor(clazz.superType);
       if (definition != null && definition.isProgramClass()) {
-        addAncestorsToWorklist(definition.asProgramClass(), worklist, seenBefore);
+        addCandidateAncestorsToWorklist(definition.asProgramClass(), worklist, seenBefore);
       }
     }
 
@@ -428,71 +544,60 @@
     for (DexType interfaceType : clazz.interfaces.values) {
       DexClass definition = appInfo.definitionFor(interfaceType);
       if (definition != null && definition.isProgramClass()) {
-        addAncestorsToWorklist(definition.asProgramClass(), worklist, seenBefore);
+        addCandidateAncestorsToWorklist(definition.asProgramClass(), worklist, seenBefore);
       }
     }
   }
 
   private GraphLense mergeClasses(GraphLense graphLense) {
-    Iterable<DexProgramClass> classes = application.classesWithDeterministicOrder();
     Deque<DexProgramClass> worklist = new ArrayDeque<>();
     Set<DexProgramClass> seenBefore = new HashSet<>();
 
     int numberOfMerges = 0;
 
-    // Types that are pinned (in addition to those where appInfo.isPinned returns true).
-    Set<DexType> pinnedTypes = getPinnedTypes(classes);
-
-    Iterator<DexProgramClass> classIterator = classes.iterator();
+    Iterator<DexProgramClass> candidatesIterator = mergeCandidates.iterator();
 
     // Visit the program classes in a top-down order according to the class hierarchy.
-    while (classIterator.hasNext() || !worklist.isEmpty()) {
+    while (candidatesIterator.hasNext() || !worklist.isEmpty()) {
       if (worklist.isEmpty()) {
         // Add the ancestors of this class (including the class itself) to the worklist in such a
         // way that all super types of the class come before the class itself.
-        addAncestorsToWorklist(classIterator.next(), worklist, seenBefore);
+        addCandidateAncestorsToWorklist(candidatesIterator.next(), worklist, seenBefore);
         if (worklist.isEmpty()) {
           continue;
         }
       }
 
       DexProgramClass clazz = worklist.removeFirst();
-      if (!seenBefore.add(clazz) || !isMergeCandidate(clazz, pinnedTypes)) {
+      assert isMergeCandidate(clazz, pinnedTypes);
+      if (!seenBefore.add(clazz)) {
         continue;
       }
 
-      DexClass targetClass = appInfo.definitionFor(clazz.type.getSingleSubtype());
+      DexProgramClass targetClass =
+          appInfo.definitionFor(clazz.type.getSingleSubtype()).asProgramClass();
       assert !mergedClasses.containsKey(targetClass.type);
-      if (clazz.hasClassInitializer() && targetClass.hasClassInitializer()) {
-        // TODO(herhut): Handle class initializers.
-        if (Log.ENABLED) {
-          AbortReason.STATIC_INITIALIZERS.printLogMessageForClass(clazz);
+
+      boolean clazzOrTargetClassHasBeenMerged =
+          mergedClassesInverse.containsKey(clazz.type)
+              || mergedClassesInverse.containsKey(targetClass.type);
+      if (clazzOrTargetClassHasBeenMerged) {
+        if (!isStillMergeCandidate(clazz)) {
+          continue;
         }
-        continue;
+      } else {
+        assert isStillMergeCandidate(clazz);
       }
-      if (methodResolutionMayChange(clazz, targetClass)) {
-        if (Log.ENABLED) {
-          AbortReason.RESOLUTION_FOR_METHODS_MAY_CHANGE.printLogMessageForClass(clazz);
-        }
-        continue;
-      }
-      // Field resolution first considers the direct interfaces of [targetClass] before it proceeds
-      // to the super class.
-      if (fieldResolutionMayChange(clazz, targetClass)) {
-        if (Log.ENABLED) {
-          AbortReason.RESOLUTION_FOR_FIELDS_MAY_CHANGE.printLogMessageForClass(clazz);
-        }
-        continue;
-      }
+
       // Guard against the case where we have two methods that may get the same signature
       // if we replace types. This is rare, so we approximate and err on the safe side here.
-      if (new CollisionDetector(clazz.type, targetClass.type, getInvokes(), mergedClasses)
-          .mayCollide()) {
+      if (new CollisionDetector(clazz.type, targetClass.type).mayCollide()) {
         if (Log.ENABLED) {
           AbortReason.CONFLICT.printLogMessageForClass(clazz);
         }
         continue;
       }
+
       ClassMerger merger = new ClassMerger(clazz, targetClass);
       boolean merged = merger.merge();
       if (merged) {
@@ -610,14 +715,6 @@
     }
 
     public boolean merge() {
-      if (source.getEnclosingMethod() != null || !source.getInnerClasses().isEmpty()
-          || target.getEnclosingMethod() != null || !target.getInnerClasses().isEmpty()) {
-        // TODO(herhut): Consider supporting merging of inner-class attributes.
-        if (Log.ENABLED) {
-          AbortReason.UNSUPPORTED_ATTRIBUTES.printLogMessageForClass(source);
-        }
-        return false;
-      }
       // Merge the class [clazz] into [targetClass] by adding all methods to
       // targetClass that are not currently contained.
       // Step 1: Merge methods
@@ -1157,22 +1254,18 @@
 
   private class CollisionDetector {
 
-    private static final int NOT_FOUND = 1 << (Integer.SIZE - 1);
+    private static final int NOT_FOUND = Integer.MIN_VALUE;
 
     // TODO(herhut): Maybe cache seenPositions for target classes.
     private final Map<DexString, Int2IntMap> seenPositions = new IdentityHashMap<>();
     private final Reference2IntMap<DexProto> targetProtoCache;
     private final Reference2IntMap<DexProto> sourceProtoCache;
     private final DexType source, target;
-    private final Collection<DexMethod> invokes;
-    private final Map<DexType, DexType> substituions;
+    private final Collection<DexMethod> invokes = getInvokes();
 
-    private CollisionDetector(DexType source, DexType target, Collection<DexMethod> invokes,
-        Map<DexType, DexType> substitutions) {
+    private CollisionDetector(DexType source, DexType target) {
       this.source = source;
       this.target = target;
-      this.invokes = invokes;
-      this.substituions = substitutions;
       this.targetProtoCache = new Reference2IntOpenHashMap<>(invokes.size() / 2);
       this.targetProtoCache.defaultReturnValue(NOT_FOUND);
       this.sourceProtoCache = new Reference2IntOpenHashMap<>(invokes.size() / 2);
@@ -1181,7 +1274,7 @@
 
     boolean mayCollide() {
       timing.begin("collision detection");
-      fillSeenPositions(invokes);
+      fillSeenPositions();
       boolean result = false;
       // If the type is not used in methods at all, there cannot be any conflict.
       if (!seenPositions.isEmpty()) {
@@ -1192,8 +1285,7 @@
             int previous = positionsMap.get(arity);
             if (previous != NOT_FOUND) {
               assert previous != 0;
-              int positions =
-                  computePositionsFor(method.proto, source, sourceProtoCache, substituions);
+              int positions = computePositionsFor(method.proto, source, sourceProtoCache);
               if ((positions & previous) != 0) {
                 result = true;
                 break;
@@ -1206,11 +1298,11 @@
       return result;
     }
 
-    private void fillSeenPositions(Collection<DexMethod> invokes) {
+    private void fillSeenPositions() {
       for (DexMethod method : invokes) {
         DexType[] parameters = method.proto.parameters.values;
         int arity = parameters.length;
-        int positions = computePositionsFor(method.proto, target, targetProtoCache, substituions);
+        int positions = computePositionsFor(method.proto, target, targetProtoCache);
         if (positions != 0) {
           Int2IntMap positionsMap =
               seenPositions.computeIfAbsent(method.name, k -> {
@@ -1230,8 +1322,10 @@
 
     }
 
-    private int computePositionsFor(DexProto proto, DexType type,
-        Reference2IntMap<DexProto> cache, Map<DexType, DexType> substitutions) {
+    // Given a method signature and a type, this method computes a bit vector that denotes the
+    // positions at which the given type is used in the method signature.
+    private int computePositionsFor(
+        DexProto proto, DexType type, Reference2IntMap<DexProto> cache) {
       int result = cache.getInt(proto);
       if (result != NOT_FOUND) {
         return result;
@@ -1240,13 +1334,8 @@
       int bitsUsed = 0;
       int accumulator = 0;
       for (DexType aType : proto.parameters.values) {
-        if (substitutions != null) {
-          // Substitute the type with the already merged class to estimate what it will
-          // look like.
-          while (substitutions.containsKey(aType)) {
-            aType = substitutions.get(aType);
-          }
-        }
+        // Substitute the type with the already merged class to estimate what it will look like.
+        aType = mergedClasses.getOrDefault(aType, aType);
         accumulator <<= 1;
         bitsUsed++;
         if (aType == type) {
@@ -1260,12 +1349,7 @@
         }
       }
       // We also take the return type into account for potential conflicts.
-      DexType returnType = proto.returnType;
-      if (substitutions != null) {
-        while (substitutions.containsKey(returnType)) {
-          returnType = substitutions.get(returnType);
-        }
-      }
+      DexType returnType = mergedClasses.getOrDefault(proto.returnType, proto.returnType);
       accumulator <<= 1;
       if (returnType == type) {
         accumulator |= 1;