Compute classes to merge up front

Change-Id: Iac79064bb64bff8dfcc5e510de2647643fb9cb74
diff --git a/src/main/java/com/android/tools/r8/verticalclassmerging/CollisionDetector.java b/src/main/java/com/android/tools/r8/verticalclassmerging/CollisionDetector.java
deleted file mode 100644
index 961f719..0000000
--- a/src/main/java/com/android/tools/r8/verticalclassmerging/CollisionDetector.java
+++ /dev/null
@@ -1,138 +0,0 @@
-// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
-// for details. All rights reserved. Use of this source code is governed by a
-// BSD-style license that can be found in the LICENSE file.
-package com.android.tools.r8.verticalclassmerging;
-
-import com.android.tools.r8.graph.AppView;
-import com.android.tools.r8.graph.DexItemFactory;
-import com.android.tools.r8.graph.DexMethod;
-import com.android.tools.r8.graph.DexProto;
-import com.android.tools.r8.graph.DexString;
-import com.android.tools.r8.graph.DexType;
-import com.android.tools.r8.shaking.AppInfoWithLiveness;
-import com.android.tools.r8.utils.Timing;
-import it.unimi.dsi.fastutil.ints.Int2IntMap;
-import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
-import it.unimi.dsi.fastutil.objects.Reference2IntMap;
-import it.unimi.dsi.fastutil.objects.Reference2IntOpenHashMap;
-import java.util.Collection;
-import java.util.IdentityHashMap;
-import java.util.Map;
-
-class CollisionDetector {
-
-  private static final int NOT_FOUND = Integer.MIN_VALUE;
-
-  private final DexItemFactory dexItemFactory;
-  private final Collection<DexMethod> invokes;
-
-  private final DexType source;
-  private final Reference2IntMap<DexProto> sourceProtoCache;
-
-  private final DexType target;
-  private final Reference2IntMap<DexProto> targetProtoCache;
-
-  private final Map<DexString, Int2IntMap> seenPositions = new IdentityHashMap<>();
-
-  CollisionDetector(
-      AppView<AppInfoWithLiveness> appView,
-      Collection<DexMethod> invokes,
-      DexType source,
-      DexType target) {
-    this.dexItemFactory = appView.dexItemFactory();
-    this.invokes = invokes;
-    this.source = source;
-    this.sourceProtoCache = new Reference2IntOpenHashMap<>(invokes.size() / 2);
-    this.sourceProtoCache.defaultReturnValue(NOT_FOUND);
-    this.target = target;
-    this.targetProtoCache = new Reference2IntOpenHashMap<>(invokes.size() / 2);
-    this.targetProtoCache.defaultReturnValue(NOT_FOUND);
-  }
-
-  boolean mayCollide(Timing timing) {
-    timing.begin("collision detection");
-    fillSeenPositions();
-    boolean result = false;
-    // If the type is not used in methods at all, there cannot be any conflict.
-    if (!seenPositions.isEmpty()) {
-      for (DexMethod method : invokes) {
-        Int2IntMap positionsMap = seenPositions.get(method.getName());
-        if (positionsMap != null) {
-          int arity = method.getArity();
-          int previous = positionsMap.get(arity);
-          if (previous != NOT_FOUND) {
-            assert previous != 0;
-            int positions = computePositionsFor(method.getProto(), source, sourceProtoCache);
-            if ((positions & previous) != 0) {
-              result = true;
-              break;
-            }
-          }
-        }
-      }
-    }
-    timing.end();
-    return result;
-  }
-
-  private void fillSeenPositions() {
-    for (DexMethod method : invokes) {
-      int arity = method.getArity();
-      int positions = computePositionsFor(method.getProto(), target, targetProtoCache);
-      if (positions != 0) {
-        Int2IntMap positionsMap =
-            seenPositions.computeIfAbsent(
-                method.getName(),
-                k -> {
-                  Int2IntMap result = new Int2IntOpenHashMap();
-                  result.defaultReturnValue(NOT_FOUND);
-                  return result;
-                });
-        int value = 0;
-        int previous = positionsMap.get(arity);
-        if (previous != NOT_FOUND) {
-          value = previous;
-        }
-        value |= positions;
-        positionsMap.put(arity, value);
-      }
-    }
-  }
-
-  // Given a method signature and a type, this method computes a bit vector that denotes the
-  // positions at which the given type is used in the method signature.
-  private int computePositionsFor(DexProto proto, DexType type, Reference2IntMap<DexProto> cache) {
-    int result = cache.getInt(proto);
-    if (result != NOT_FOUND) {
-      return result;
-    }
-    result = 0;
-    int bitsUsed = 0;
-    int accumulator = 0;
-    for (DexType parameterBaseType : proto.getParameterBaseTypes(dexItemFactory)) {
-      // Substitute the type with the already merged class to estimate what it will look like.
-      DexType mappedType = parameterBaseType;
-      accumulator <<= 1;
-      bitsUsed++;
-      if (mappedType.isIdenticalTo(type)) {
-        accumulator |= 1;
-      }
-      // Handle overflow on 31 bit boundary.
-      if (bitsUsed == Integer.SIZE - 1) {
-        result |= accumulator;
-        accumulator = 0;
-        bitsUsed = 0;
-      }
-    }
-    // We also take the return type into account for potential conflicts.
-    DexType returnBaseType = proto.getReturnType().toBaseType(dexItemFactory);
-    DexType mappedReturnType = returnBaseType;
-    accumulator <<= 1;
-    if (mappedReturnType.isIdenticalTo(type)) {
-      accumulator |= 1;
-    }
-    result |= accumulator;
-    cache.put(proto, result);
-    return result;
-  }
-}
diff --git a/src/main/java/com/android/tools/r8/verticalclassmerging/ConnectedComponentVerticalClassMerger.java b/src/main/java/com/android/tools/r8/verticalclassmerging/ConnectedComponentVerticalClassMerger.java
index a823a1a..f984383 100644
--- a/src/main/java/com/android/tools/r8/verticalclassmerging/ConnectedComponentVerticalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/verticalclassmerging/ConnectedComponentVerticalClassMerger.java
@@ -3,43 +3,21 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.verticalclassmerging;
 
-import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
-
 import com.android.tools.r8.graph.AppView;
-import com.android.tools.r8.graph.DexClass;
-import com.android.tools.r8.graph.DexEncodedMethod;
-import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
-import com.android.tools.r8.graph.DexProto;
-import com.android.tools.r8.graph.DexString;
-import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
-import com.android.tools.r8.graph.TopDownClassHierarchyTraversal;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
-import com.android.tools.r8.shaking.MainDexInfo;
 import com.android.tools.r8.utils.ListUtils;
-import com.android.tools.r8.utils.MethodSignatureEquivalence;
-import com.android.tools.r8.utils.Timing;
-import com.google.common.base.Equivalence;
-import com.google.common.base.Equivalence.Wrapper;
-import com.google.common.collect.Iterables;
-import it.unimi.dsi.fastutil.objects.Reference2BooleanOpenHashMap;
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.Comparator;
-import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
-import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
 
 public class ConnectedComponentVerticalClassMerger {
 
   private final AppView<AppInfoWithLiveness> appView;
-  private final MainDexInfo mainDexInfo;
-
-  private Collection<DexMethod> invokes;
+  private final Set<DexProgramClass> classesToMerge;
 
   // The resulting graph lens that should be used after class merging.
   private final VerticalClassMergerGraphLens.Builder lensBuilder;
@@ -50,190 +28,46 @@
   private final VerticallyMergedClasses.Builder verticallyMergedClassesBuilder =
       VerticallyMergedClasses.builder();
 
-  ConnectedComponentVerticalClassMerger(AppView<AppInfoWithLiveness> appView) {
+  ConnectedComponentVerticalClassMerger(
+      AppView<AppInfoWithLiveness> appView, Set<DexProgramClass> classesToMerge) {
     this.appView = appView;
-    this.mainDexInfo = appView.appInfo().getMainDexInfo();
+    this.classesToMerge = classesToMerge;
     this.lensBuilder = new VerticalClassMergerGraphLens.Builder(appView);
   }
 
-  public VerticalClassMergerResult.Builder run(
-      Set<DexProgramClass> connectedComponent,
-      ImmediateProgramSubtypingInfo immediateSubtypingInfo,
-      Set<DexProgramClass> pinnedClasses,
-      Timing timing)
+  public boolean isEmpty() {
+    return classesToMerge.isEmpty();
+  }
+
+  public VerticalClassMergerResult.Builder run(ImmediateProgramSubtypingInfo immediateSubtypingInfo)
       throws ExecutionException {
-    // Visit the program classes in a top-down order according to the class hierarchy.
-    VerticalClassMergerPolicyExecutor policyExecutor =
-        new VerticalClassMergerPolicyExecutor(
-            appView, pinnedClasses, verticallyMergedClassesBuilder);
-    Set<DexProgramClass> mergeCandidates =
-        policyExecutor.run(connectedComponent, immediateSubtypingInfo);
-    List<DexProgramClass> mergeCandidatesSorted =
-        ListUtils.sort(mergeCandidates, Comparator.comparing(DexProgramClass::getType));
-    TopDownClassHierarchyTraversal.forProgramClasses(appView)
-        .visit(
-            mergeCandidatesSorted,
-            clazz ->
-                mergeClassIfPossible(
-                    clazz, immediateSubtypingInfo, mergeCandidates, policyExecutor, timing));
+    List<DexProgramClass> classesToMergeSorted =
+        ListUtils.sort(classesToMerge, Comparator.comparing(DexProgramClass::getType));
+    for (DexProgramClass clazz : classesToMergeSorted) {
+      mergeClassIfPossible(clazz, immediateSubtypingInfo);
+    }
     return VerticalClassMergerResult.builder(
         lensBuilder, synthesizedBridges, verticallyMergedClassesBuilder);
   }
 
   private void mergeClassIfPossible(
-      DexProgramClass clazz,
-      ImmediateProgramSubtypingInfo immediateSubtypingInfo,
-      Set<DexProgramClass> mergeCandidates,
-      VerticalClassMergerPolicyExecutor policyExecutor,
-      Timing timing)
+      DexProgramClass sourceClass, ImmediateProgramSubtypingInfo immediateSubtypingInfo)
       throws ExecutionException {
-    if (!mergeCandidates.contains(clazz)) {
-      return;
-    }
-    List<DexProgramClass> subclasses = immediateSubtypingInfo.getSubclasses(clazz);
-    if (subclasses.size() != 1) {
-      return;
-    }
+    List<DexProgramClass> subclasses = immediateSubtypingInfo.getSubclasses(sourceClass);
+    assert subclasses.size() == 1;
     DexProgramClass targetClass = ListUtils.first(subclasses);
-    assert !verticallyMergedClassesBuilder.isMergeSource(targetClass);
-    if (verticallyMergedClassesBuilder.isMergeTarget(clazz)) {
+    if (verticallyMergedClassesBuilder.isMergeSource(targetClass)
+        || verticallyMergedClassesBuilder.isMergeTarget(sourceClass)) {
       return;
     }
-    if (verticallyMergedClassesBuilder.isMergeTarget(targetClass)) {
-      if (!policyExecutor.isStillMergeCandidate(clazz, targetClass)) {
-        return;
-      }
-    } else {
-      assert policyExecutor.isStillMergeCandidate(clazz, targetClass);
-    }
-
-    // Guard against the case where we have two methods that may get the same signature
-    // if we replace types. This is rare, so we approximate and err on the safe side here.
-    CollisionDetector collisionDetector =
-        new CollisionDetector(
-            appView,
-            getInvokes(immediateSubtypingInfo, mergeCandidates),
-            clazz.getType(),
-            targetClass.getType());
-    if (collisionDetector.mayCollide(timing)) {
-      return;
-    }
-
-    // Check with main dex classes to see if we are allowed to merge.
-    if (!mainDexInfo.canMerge(clazz, targetClass, appView.getSyntheticItems())) {
-      return;
-    }
-
     ClassMerger merger =
-        new ClassMerger(appView, lensBuilder, verticallyMergedClassesBuilder, clazz, targetClass);
+        new ClassMerger(
+            appView, lensBuilder, verticallyMergedClassesBuilder, sourceClass, targetClass);
     if (merger.merge()) {
-      verticallyMergedClassesBuilder.add(clazz, targetClass);
+      verticallyMergedClassesBuilder.add(sourceClass, targetClass);
       // Commit the changes to the graph lens.
       lensBuilder.merge(merger.getRenamings());
       synthesizedBridges.addAll(merger.getSynthesizedBridges());
     }
   }
-
-  private Collection<DexMethod> getInvokes(
-      ImmediateProgramSubtypingInfo immediateSubtypingInfo, Set<DexProgramClass> mergeCandidates) {
-    if (invokes == null) {
-      invokes =
-          new OverloadedMethodSignaturesRetriever(immediateSubtypingInfo, mergeCandidates)
-              .get(mergeCandidates);
-    }
-    return invokes;
-  }
-
-  // Collects all potentially overloaded method signatures that reference at least one type that
-  // may be the source or target of a merge operation.
-  private class OverloadedMethodSignaturesRetriever {
-    private final Reference2BooleanOpenHashMap<DexProto> cache =
-        new Reference2BooleanOpenHashMap<>();
-    private final Equivalence<DexMethod> equivalence = MethodSignatureEquivalence.get();
-    private final Set<DexType> mergeeCandidates = new HashSet<>();
-
-    public OverloadedMethodSignaturesRetriever(
-        ImmediateProgramSubtypingInfo immediateSubtypingInfo,
-        Set<DexProgramClass> mergeCandidates) {
-      for (DexProgramClass mergeCandidate : mergeCandidates) {
-        List<DexProgramClass> subclasses = immediateSubtypingInfo.getSubclasses(mergeCandidate);
-        if (subclasses.size() == 1) {
-          mergeeCandidates.add(ListUtils.first(subclasses).getType());
-        }
-      }
-    }
-
-    public Collection<DexMethod> get(Set<DexProgramClass> mergeCandidates) {
-      Map<DexString, DexProto> overloadingInfo = new HashMap<>();
-
-      // Find all signatures that may reference a type that could be the source or target of a
-      // merge operation.
-      Set<Wrapper<DexMethod>> filteredSignatures = new HashSet<>();
-      for (DexProgramClass clazz : appView.appInfo().classes()) {
-        for (DexEncodedMethod encodedMethod : clazz.methods()) {
-          DexMethod method = encodedMethod.getReference();
-          DexClass definition = appView.definitionFor(method.getHolderType());
-          if (definition != null
-              && definition.isProgramClass()
-              && protoMayReferenceMergedSourceOrTarget(method.getProto(), mergeCandidates)) {
-            filteredSignatures.add(equivalence.wrap(method));
-
-            // Record that we have seen a method named [signature.name] with the proto
-            // [signature.proto]. If at some point, we find a method with the same name, but a
-            // different proto, it could be the case that a method with the given name is
-            // overloaded.
-            DexProto existing =
-                overloadingInfo.computeIfAbsent(method.getName(), key -> method.getProto());
-            if (existing.isNotIdenticalTo(DexProto.SENTINEL)
-                && !existing.equals(method.getProto())) {
-              // Mark that this signature is overloaded by mapping it to SENTINEL.
-              overloadingInfo.put(method.getName(), DexProto.SENTINEL);
-            }
-          }
-        }
-      }
-
-      List<DexMethod> result = new ArrayList<>();
-      for (Wrapper<DexMethod> wrappedSignature : filteredSignatures) {
-        DexMethod signature = wrappedSignature.get();
-
-        // Ignore those method names that are definitely not overloaded since they cannot lead to
-        // any collisions.
-        if (overloadingInfo.get(signature.getName()).isIdenticalTo(DexProto.SENTINEL)) {
-          result.add(signature);
-        }
-      }
-      return result;
-    }
-
-    private boolean protoMayReferenceMergedSourceOrTarget(
-        DexProto proto, Set<DexProgramClass> mergeCandidates) {
-      boolean result;
-      if (cache.containsKey(proto)) {
-        result = cache.getBoolean(proto);
-      } else {
-        result =
-            Iterables.any(
-                proto.getTypes(),
-                type -> typeMayReferenceMergedSourceOrTarget(type, mergeCandidates));
-        cache.put(proto, result);
-      }
-      return result;
-    }
-
-    private boolean typeMayReferenceMergedSourceOrTarget(
-        DexType type, Set<DexProgramClass> mergeCandidates) {
-      type = type.toBaseType(appView.dexItemFactory());
-      if (type.isClassType()) {
-        if (mergeeCandidates.contains(type)) {
-          return true;
-        }
-        DexProgramClass clazz = asProgramClassOrNull(appView.definitionFor(type));
-        if (clazz != null) {
-          return mergeCandidates.contains(clazz.asProgramClass());
-        }
-      }
-      return false;
-    }
-  }
 }
diff --git a/src/main/java/com/android/tools/r8/verticalclassmerging/VerticalClassMerger.java b/src/main/java/com/android/tools/r8/verticalclassmerging/VerticalClassMerger.java
index ef70058..b70bac7 100644
--- a/src/main/java/com/android/tools/r8/verticalclassmerging/VerticalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/verticalclassmerging/VerticalClassMerger.java
@@ -159,16 +159,28 @@
     timing.begin("Setup");
     ImmediateProgramSubtypingInfo immediateSubtypingInfo =
         ImmediateProgramSubtypingInfo.create(appView);
+
+    // Compute the disjoint class hierarchies for parallel processing.
     List<Set<DexProgramClass>> connectedComponents =
         new ProgramClassesBidirectedGraph(appView, immediateSubtypingInfo)
             .computeStronglyConnectedComponents();
-    Set<DexProgramClass> pinnedClasses = getPinnedClasses();
+
+    // Remove singleton class hierarchies as they are not subject to vertical class merging.
+    Set<DexProgramClass> singletonComponents = Sets.newIdentityHashSet();
+    connectedComponents.removeIf(
+        connectedComponent -> {
+          if (connectedComponent.size() == 1) {
+            singletonComponents.addAll(connectedComponent);
+            return true;
+          }
+          return false;
+        });
     timing.end();
 
     // Apply class merging concurrently in disjoint class hierarchies.
     VerticalClassMergerResult verticalClassMergerResult =
         mergeClassesInConnectedComponents(
-            connectedComponents, immediateSubtypingInfo, pinnedClasses, executorService, timing);
+            connectedComponents, immediateSubtypingInfo, executorService, timing);
     appView.setVerticallyMergedClasses(verticalClassMergerResult.getVerticallyMergedClasses());
     if (verticalClassMergerResult.isEmpty()) {
       return;
@@ -188,21 +200,65 @@
   private VerticalClassMergerResult mergeClassesInConnectedComponents(
       List<Set<DexProgramClass>> connectedComponents,
       ImmediateProgramSubtypingInfo immediateSubtypingInfo,
-      Set<DexProgramClass> pinnedClasses,
       ExecutorService executorService,
       Timing timing)
       throws ExecutionException {
-    VerticalClassMergerResult.Builder verticalClassMergerResult =
-        VerticalClassMergerResult.builder(appView);
-    TimingMerger merger = timing.beginMerger("Merge classes", executorService);
+    Collection<ConnectedComponentVerticalClassMerger> connectedComponentMergers =
+        getConnectedComponentMergers(
+            connectedComponents, immediateSubtypingInfo, executorService, timing);
+    return applyConnectedComponentMergers(
+        connectedComponentMergers, immediateSubtypingInfo, executorService, timing);
+  }
+
+  private Collection<ConnectedComponentVerticalClassMerger> getConnectedComponentMergers(
+      List<Set<DexProgramClass>> connectedComponents,
+      ImmediateProgramSubtypingInfo immediateSubtypingInfo,
+      ExecutorService executorService,
+      Timing timing)
+      throws ExecutionException {
+    TimingMerger merger = timing.beginMerger("Compute classes to merge", executorService);
+    List<ConnectedComponentVerticalClassMerger> connectedComponentMergers =
+        new ArrayList<>(connectedComponents.size());
+    Set<DexProgramClass> pinnedClasses = getPinnedClasses();
     Collection<Timing> timings =
         ThreadUtils.processItemsWithResults(
             connectedComponents,
             connectedComponent -> {
+              Timing threadTiming = Timing.create("Compute classes to merge in component", options);
+              ConnectedComponentVerticalClassMerger connectedComponentMerger =
+                  new VerticalClassMergerPolicyExecutor(appView, pinnedClasses)
+                      .run(connectedComponent, immediateSubtypingInfo);
+              if (!connectedComponentMerger.isEmpty()) {
+                synchronized (connectedComponentMergers) {
+                  connectedComponentMergers.add(connectedComponentMerger);
+                }
+              }
+              threadTiming.end();
+              return threadTiming;
+            },
+            appView.options().getThreadingModule(),
+            executorService);
+    merger.add(timings);
+    merger.end();
+    return connectedComponentMergers;
+  }
+
+  private VerticalClassMergerResult applyConnectedComponentMergers(
+      Collection<ConnectedComponentVerticalClassMerger> connectedComponentMergers,
+      ImmediateProgramSubtypingInfo immediateSubtypingInfo,
+      ExecutorService executorService,
+      Timing timing)
+      throws ExecutionException {
+    TimingMerger merger = timing.beginMerger("Merge classes", executorService);
+    VerticalClassMergerResult.Builder verticalClassMergerResult =
+        VerticalClassMergerResult.builder(appView);
+    Collection<Timing> timings =
+        ThreadUtils.processItemsWithResults(
+            connectedComponentMergers,
+            connectedComponentMerger -> {
               Timing threadTiming = Timing.create("Merge classes in component", options);
               VerticalClassMergerResult.Builder verticalClassMergerComponentResult =
-                  new ConnectedComponentVerticalClassMerger(appView)
-                      .run(connectedComponent, immediateSubtypingInfo, pinnedClasses, threadTiming);
+                  connectedComponentMerger.run(immediateSubtypingInfo);
               verticalClassMergerResult.merge(verticalClassMergerComponentResult);
               threadTiming.end();
               return threadTiming;
diff --git a/src/main/java/com/android/tools/r8/verticalclassmerging/VerticalClassMergerPolicyExecutor.java b/src/main/java/com/android/tools/r8/verticalclassmerging/VerticalClassMergerPolicyExecutor.java
index e69e96e..45796d3 100644
--- a/src/main/java/com/android/tools/r8/verticalclassmerging/VerticalClassMergerPolicyExecutor.java
+++ b/src/main/java/com/android/tools/r8/verticalclassmerging/VerticalClassMergerPolicyExecutor.java
@@ -45,20 +45,16 @@
   private final InternalOptions options;
   private final MainDexInfo mainDexInfo;
   private final Set<DexProgramClass> pinnedClasses;
-  private final VerticallyMergedClasses.Builder verticallyMergedClassesBuilder;
 
   VerticalClassMergerPolicyExecutor(
-      AppView<AppInfoWithLiveness> appView,
-      Set<DexProgramClass> pinnedClasses,
-      VerticallyMergedClasses.Builder verticallyMergedClassesInComponentBuilder) {
+      AppView<AppInfoWithLiveness> appView, Set<DexProgramClass> pinnedClasses) {
     this.appView = appView;
     this.options = appView.options();
     this.mainDexInfo = appView.appInfo().getMainDexInfo();
     this.pinnedClasses = pinnedClasses;
-    this.verticallyMergedClassesBuilder = verticallyMergedClassesInComponentBuilder;
   }
 
-  Set<DexProgramClass> run(
+  ConnectedComponentVerticalClassMerger run(
       Set<DexProgramClass> connectedComponent,
       ImmediateProgramSubtypingInfo immediateSubtypingInfo) {
     Set<DexProgramClass> mergeCandidates = Sets.newIdentityHashSet();
@@ -79,7 +75,7 @@
       }
       mergeCandidates.add(sourceClass);
     }
-    return mergeCandidates;
+    return new ConnectedComponentVerticalClassMerger(appView, mergeCandidates);
   }
 
   // Returns true if [clazz] is a merge candidate. Note that the result of the checks in this
@@ -161,6 +157,12 @@
         return false;
       }
     }
+
+    // Check with main dex classes to see if we are allowed to merge.
+    if (!mainDexInfo.canMerge(sourceClass, targetClass, appView.getSyntheticItems())) {
+      return false;
+    }
+
     return true;
   }
 
@@ -170,7 +172,6 @@
    * called before merging {@param sourceClass} into {@param targetClass}.
    */
   boolean isStillMergeCandidate(DexProgramClass sourceClass, DexProgramClass targetClass) {
-    assert !verticallyMergedClassesBuilder.isMergeTarget(sourceClass);
     // For interface types, this is more complicated, see:
     // https://docs.oracle.com/javase/specs/jvms/se9/html/jvms-5.html#jvms-5.5
     // We basically can't move the clinit, since it is not called when implementing classes have