Improve precision of deadlock after class merging detection

Bug: 205611444
Change-Id: I2b41b07c8c4c53354fe9dc60a940adef48ae7650
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
index e16a15f..202f78c 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
@@ -83,7 +83,8 @@
     // Run the policies on all program classes to produce a final grouping.
     List<Policy> policies =
         PolicyScheduler.getPolicies(appView, codeProvider, mode, runtimeTypeCheckInfo);
-    Collection<MergeGroup> groups = new PolicyExecutor().run(getInitialGroups(), policies, timing);
+    Collection<MergeGroup> groups =
+        new PolicyExecutor().run(getInitialGroups(), policies, executorService, timing);
 
     // If there are no groups, then end horizontal class merging.
     if (groups.isEmpty()) {
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/MultiClassPolicyWithPreprocessing.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/MultiClassPolicyWithPreprocessing.java
index 067fcd5..d634479 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/MultiClassPolicyWithPreprocessing.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/MultiClassPolicyWithPreprocessing.java
@@ -5,6 +5,8 @@
 package com.android.tools.r8.horizontalclassmerging;
 
 import java.util.Collection;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
 
 public abstract class MultiClassPolicyWithPreprocessing<T> extends Policy {
 
@@ -12,14 +14,15 @@
    * Apply the multi class policy to a group of program classes.
    *
    * @param group This is a group of program classes which can currently still be merged.
-   * @param data The result of calling {@link #preprocess(Collection)}.
+   * @param data The result of calling {@link #preprocess(Collection, ExecutorService)}.
    * @return The same collection of program classes split into new groups of candidates which can be
    *     merged. If the policy detects no issues then `group` will be returned unchanged. If classes
    *     cannot be merged with any other classes they are returned as singleton lists.
    */
   public abstract Collection<MergeGroup> apply(MergeGroup group, T data);
 
-  public abstract T preprocess(Collection<MergeGroup> groups);
+  public abstract T preprocess(Collection<MergeGroup> groups, ExecutorService executorService)
+      throws ExecutionException;
 
   @Override
   public boolean isMultiClassPolicyWithPreprocessing() {
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/PolicyExecutor.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/PolicyExecutor.java
index c1c4f64..c206e75 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/PolicyExecutor.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/PolicyExecutor.java
@@ -9,6 +9,8 @@
 import java.util.Collection;
 import java.util.Iterator;
 import java.util.LinkedList;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
 
 /**
  * This is a simple policy executor that ensures regular sequential execution of policies. It should
@@ -50,9 +52,12 @@
   }
 
   private <T> LinkedList<MergeGroup> applyMultiClassPolicyWithPreprocessing(
-      MultiClassPolicyWithPreprocessing<T> policy, LinkedList<MergeGroup> groups) {
+      MultiClassPolicyWithPreprocessing<T> policy,
+      LinkedList<MergeGroup> groups,
+      ExecutorService executorService)
+      throws ExecutionException {
     // For each group apply the multi class policy and add all the new groups together.
-    T data = policy.preprocess(groups);
+    T data = policy.preprocess(groups, executorService);
     LinkedList<MergeGroup> newGroups = new LinkedList<>();
     groups.forEach(
         group -> {
@@ -73,7 +78,11 @@
    * class groups.
    */
   public Collection<MergeGroup> run(
-      Collection<MergeGroup> inputGroups, Collection<Policy> policies, Timing timing) {
+      Collection<MergeGroup> inputGroups,
+      Collection<Policy> policies,
+      ExecutorService executorService,
+      Timing timing)
+      throws ExecutionException {
     LinkedList<MergeGroup> linkedGroups;
 
     if (inputGroups instanceof LinkedList) {
@@ -96,7 +105,7 @@
         assert policy.isMultiClassPolicyWithPreprocessing();
         linkedGroups =
             applyMultiClassPolicyWithPreprocessing(
-                policy.asMultiClassPolicyWithPreprocessing(), linkedGroups);
+                policy.asMultiClassPolicyWithPreprocessing(), linkedGroups, executorService);
       }
       timing.end();
 
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoClassInitializerCycles.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoClassInitializerCycles.java
index f8faefd..a6145ec 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoClassInitializerCycles.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoClassInitializerCycles.java
@@ -5,12 +5,12 @@
 package com.android.tools.r8.horizontalclassmerging.policies;
 
 import static com.android.tools.r8.graph.DexClassAndMethod.asProgramMethodOrNull;
-import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
 import static com.android.tools.r8.utils.MapUtils.ignoreKey;
 
 import com.android.tools.r8.code.CfOrDexInstruction;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexCallSite;
+import com.android.tools.r8.graph.DexClassAndMethod;
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
@@ -20,10 +20,12 @@
 import com.android.tools.r8.graph.UseRegistry;
 import com.android.tools.r8.horizontalclassmerging.MergeGroup;
 import com.android.tools.r8.horizontalclassmerging.MultiClassPolicyWithPreprocessing;
+import com.android.tools.r8.horizontalclassmerging.policies.deadlock.SingleCallerInformation;
 import com.android.tools.r8.ir.desugar.LambdaDescriptor;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.InternalOptions.HorizontalClassMergerOptions;
 import com.android.tools.r8.utils.SetUtils;
+import com.android.tools.r8.utils.TraversalContinuation;
 import com.android.tools.r8.utils.collections.ProgramMethodSet;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Sets;
@@ -32,11 +34,14 @@
 import java.util.Collections;
 import java.util.Deque;
 import java.util.IdentityHashMap;
+import java.util.LinkedHashMap;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.ListIterator;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
 
 /**
  * Disallows merging of classes when the merging could introduce class initialization deadlocks.
@@ -93,23 +98,38 @@
   // Mapping from each merge candidate to its merge group.
   final Map<DexProgramClass, MergeGroup> allGroups = new IdentityHashMap<>();
 
+  private SingleCallerInformation singleCallerInformation;
+
   public NoClassInitializerCycles(AppView<AppInfoWithLiveness> appView) {
     this.appView = appView;
   }
 
   @Override
   public Collection<MergeGroup> apply(MergeGroup group, Void nothing) {
-    Tracer tracer = new Tracer(group);
-    removeClassesWithPossibleClassInitializerDeadlock(group, tracer);
-
+    // Partition the merge group into smaller groups that may be merged. If the class initialization
+    // of a parent class may initialize a member of the merge group, then this member is not
+    // eligible for class merging, unless the only way to class initialize this member is from the
+    // class initialization of the parent class. In this case, the member may be merged with other
+    // group members that are also guaranteed to only be class initialized from the class
+    // initialization of the parent class.
+    List<MergeGroup> partitioning = partitionClassesWithPossibleClassInitializerDeadlock(group);
     List<MergeGroup> newGroups = new LinkedList<>();
-    for (DexProgramClass clazz : group) {
-      MergeGroup newGroup = getOrCreateGroupFor(clazz, newGroups, tracer);
-      if (newGroup != null) {
-        newGroup.add(clazz);
-      } else {
-        // Ineligible for merging.
+
+    // Revisit each partition. If the class initialization of a group member may initialize another
+    // class (not necessarily a group member), and vice versa, then class initialization could
+    // deadlock if the group member is merged with another class that is initialized concurrently.
+    for (MergeGroup partition : partitioning) {
+      List<MergeGroup> newGroupsFromPartition = new LinkedList<>();
+      Tracer tracer = new Tracer(partition);
+      for (DexProgramClass clazz : partition) {
+        MergeGroup newGroup = getOrCreateGroupFor(clazz, newGroupsFromPartition, tracer);
+        if (newGroup != null) {
+          newGroup.add(clazz);
+        } else {
+          // Ineligible for merging.
+        }
       }
+      newGroups.addAll(newGroupsFromPartition);
     }
     return removeTrivialGroups(newGroups);
   }
@@ -118,12 +138,12 @@
       DexProgramClass clazz, List<MergeGroup> groups, Tracer tracer) {
     assert !tracer.hasPossibleClassInitializerDeadlock(clazz);
 
-    ProgramMethod classInitializer = clazz.getProgramClassInitializer();
-    if (classInitializer != null) {
-      assert tracer.verifySeenSetIsEmpty();
-      assert tracer.verifyWorklistIsEmpty();
+    if (clazz.hasClassInitializer()) {
+      // Trace from the class initializer of this group member. If an execution path is found that
+      // leads back to the class initializer then this class may be involved in a deadlock, and we
+      // should not merge any other classes into it.
       tracer.setTracingRoot(clazz);
-      tracer.enqueueMethod(classInitializer);
+      tracer.enqueueTracingRoot(clazz.getProgramClassInitializer());
       tracer.trace();
       if (tracer.hasPossibleClassInitializerDeadlock(clazz)) {
         // Ineligible for merging.
@@ -163,11 +183,63 @@
    * If the class initializer of one of the classes in the merge group is reached, then that class
    * is not eligible for merging.
    */
-  private void removeClassesWithPossibleClassInitializerDeadlock(MergeGroup group, Tracer tracer) {
+  private List<MergeGroup> partitionClassesWithPossibleClassInitializerDeadlock(MergeGroup group) {
+    Set<DexProgramClass> superclasses = Sets.newIdentityHashSet();
+    appView
+        .appInfo()
+        .traverseSuperClasses(
+            group.iterator().next(),
+            (supertype, superclass, immediateSubclass) -> {
+              if (superclass != null && superclass.isProgramClass()) {
+                superclasses.add(superclass.asProgramClass());
+                return TraversalContinuation.CONTINUE;
+              }
+              return TraversalContinuation.BREAK;
+            });
+
+    // Run the tracer from the class initializers of the superclasses.
+    Tracer tracer = new Tracer(group);
     tracer.setTracingRoots(group);
-    tracer.enqueueParentClassInitializers(group);
+    for (DexProgramClass superclass : superclasses) {
+      if (superclass.hasClassInitializer()) {
+        tracer.enqueueTracingRoot(superclass.getProgramClassInitializer());
+      }
+    }
     tracer.trace();
-    group.removeIf(tracer::hasPossibleClassInitializerDeadlock);
+
+    MergeGroup notInitializedByInitializationOfParent = new MergeGroup();
+    Map<DexProgramClass, MergeGroup> partitioning = new LinkedHashMap<>();
+    for (DexProgramClass member : group) {
+      if (tracer.hasPossibleClassInitializerDeadlock(member)) {
+        DexProgramClass nearestLock = getNearestLock(member, superclasses);
+        if (nearestLock != null) {
+          partitioning.computeIfAbsent(nearestLock, ignoreKey(MergeGroup::new)).add(member);
+        } else {
+          // Ineligible for merging.
+        }
+      } else {
+        notInitializedByInitializationOfParent.add(member);
+      }
+    }
+
+    return ImmutableList.<MergeGroup>builder()
+        .add(notInitializedByInitializationOfParent)
+        .addAll(partitioning.values())
+        .build();
+  }
+
+  private DexProgramClass getNearestLock(
+      DexProgramClass clazz, Set<DexProgramClass> candidateOwners) {
+    ProgramMethodSet seen = ProgramMethodSet.create();
+    ProgramMethod singleCaller = singleCallerInformation.getSingleClassInitializerCaller(clazz);
+    while (singleCaller != null && seen.add(singleCaller)) {
+      if (singleCaller.getDefinition().isClassInitializer()
+          && candidateOwners.contains(singleCaller.getHolder())) {
+        return singleCaller.getHolder();
+      }
+      singleCaller = singleCallerInformation.getSingleCaller(singleCaller);
+    }
+    return null;
   }
 
   @Override
@@ -181,12 +253,15 @@
   }
 
   @Override
-  public Void preprocess(Collection<MergeGroup> groups) {
+  public Void preprocess(Collection<MergeGroup> groups, ExecutorService executorService)
+      throws ExecutionException {
     for (MergeGroup group : groups) {
       for (DexProgramClass clazz : group) {
         allGroups.put(clazz, group);
       }
     }
+    singleCallerInformation =
+        SingleCallerInformation.builder(appView).analyze(executorService).build();
     return null;
   }
 
@@ -222,34 +297,26 @@
       seenMethods.clear();
     }
 
+    void clearWorklist() {
+      worklist.clear();
+    }
+
     boolean markClassInitializerAsSeen(DexProgramClass clazz) {
       return seenClassInitializers.add(clazz);
     }
 
     boolean enqueueMethod(ProgramMethod method) {
       if (seenMethods.add(method)) {
-        worklist.add(method);
+        worklist.addLast(method);
         return true;
       }
       return false;
     }
 
-    void enqueueParentClassInitializers(MergeGroup group) {
-      DexProgramClass member = group.iterator().next();
-      enqueueParentClassInitializers(member);
-    }
-
-    void enqueueParentClassInitializers(DexProgramClass clazz) {
-      DexProgramClass superClass =
-          asProgramClassOrNull(appView.definitionFor(clazz.getSuperType()));
-      if (superClass == null) {
-        return;
-      }
-      ProgramMethod classInitializer = superClass.getProgramClassInitializer();
-      if (classInitializer != null) {
-        enqueueMethod(classInitializer);
-      }
-      enqueueParentClassInitializers(superClass);
+    void enqueueTracingRoot(ProgramMethod tracingRoot) {
+      boolean added = seenMethods.add(tracingRoot);
+      assert added;
+      worklist.add(tracingRoot);
     }
 
     void recordClassInitializerReachableFromTracingRoots(DexProgramClass clazz) {
@@ -284,20 +351,25 @@
           .contains(classBeingInitialized);
     }
 
+    private void processWorklist() {
+      while (!worklist.isEmpty()) {
+        ProgramMethod method = worklist.removeLast();
+        method.registerCodeReferences(new TracerUseRegistry(method));
+      }
+    }
+
     void setTracingRoot(DexProgramClass tracingRoot) {
       setTracingRoots(ImmutableList.of(tracingRoot));
     }
 
     void setTracingRoots(Collection<DexProgramClass> tracingRoots) {
+      assert verifySeenSetIsEmpty();
+      assert verifyWorklistIsEmpty();
       this.tracingRoots = tracingRoots;
     }
 
     void trace() {
-      // TODO(b/205611444): Avoid redundant tracing of the same methods.
-      while (!worklist.isEmpty()) {
-        ProgramMethod method = worklist.removeLast();
-        method.registerCodeReferences(new TracerUseRegistry(method));
-      }
+      processWorklist();
       clearSeen();
     }
 
@@ -322,6 +394,7 @@
         // Ensures that hasPossibleClassInitializerDeadlock() returns true for each tracing root.
         recordTracingRootsIneligibleForClassMerging();
         doBreak();
+        clearWorklist();
       }
 
       private void triggerClassInitializerIfNotAlreadyTriggeredInContext(DexType type) {
@@ -338,8 +411,6 @@
       }
 
       private boolean isClassAlreadyInitializedInCurrentContext(DexProgramClass clazz) {
-        // TODO(b/205611444): There is only a risk of a deadlock if the execution path comes from
-        //  outside the merge group. We could address this by updating this check.
         return appView.appInfo().isSubtype(getContext().getHolder(), clazz);
       }
 
@@ -400,7 +471,13 @@
 
       @Override
       public void registerInvokeInterface(DexMethod method) {
-        fail();
+        DexMethod rewrittenMethod =
+            appView.graphLens().lookupInvokeInterface(method, getContext()).getReference();
+        DexClassAndMethod resolvedMethod =
+            appView.appInfo().resolveMethodOnInterface(rewrittenMethod).getResolutionPair();
+        if (resolvedMethod != null) {
+          fail();
+        }
       }
 
       @Override
@@ -432,7 +509,17 @@
 
       @Override
       public void registerInvokeVirtual(DexMethod method) {
-        fail();
+        DexMethod rewrittenMethod =
+            appView.graphLens().lookupInvokeVirtual(method, getContext()).getReference();
+        DexClassAndMethod resolvedMethod =
+            appView.appInfo().resolveMethodOnClass(rewrittenMethod).getResolutionPair();
+        if (resolvedMethod != null) {
+          if (!resolvedMethod.getHolder().isEffectivelyFinal(appView)) {
+            fail();
+          } else if (resolvedMethod.isProgramMethod()) {
+            enqueueMethod(resolvedMethod.asProgramMethod());
+          }
+        }
       }
 
       @Override
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoConstructorCollisions.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoConstructorCollisions.java
index 191adc4..e77b31a 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoConstructorCollisions.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoConstructorCollisions.java
@@ -25,6 +25,7 @@
 import java.util.IdentityHashMap;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ExecutorService;
 
 /**
  * In the final round, we're not allowed to resolve constructor collisions by appending null
@@ -66,7 +67,7 @@
    * lead to constructor collisions.
    */
   @Override
-  public Set<DexType> preprocess(Collection<MergeGroup> groups) {
+  public Set<DexType> preprocess(Collection<MergeGroup> groups, ExecutorService executorService) {
     // Build a mapping from types to groups.
     Map<DexType, MergeGroup> groupsByType = new IdentityHashMap<>();
     for (MergeGroup group : groups) {
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoDefaultInterfaceMethodCollisions.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoDefaultInterfaceMethodCollisions.java
index a2362bd..5ebe1a6 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoDefaultInterfaceMethodCollisions.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoDefaultInterfaceMethodCollisions.java
@@ -36,6 +36,7 @@
 import java.util.IdentityHashMap;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ExecutorService;
 import java.util.function.Function;
 
 /**
@@ -143,7 +144,8 @@
   }
 
   @Override
-  public Map<DexType, InterfaceInfo> preprocess(Collection<MergeGroup> groups) {
+  public Map<DexType, InterfaceInfo> preprocess(
+      Collection<MergeGroup> groups, ExecutorService executorService) {
     SubtypingInfo subtypingInfo = new SubtypingInfo(appView);
     Collection<DexProgramClass> classesOfInterest = computeClassesOfInterest(subtypingInfo);
     Map<DexType, DexMethodSignatureSet> inheritedClassMethodsPerClass =
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/OnlyDirectlyConnectedOrUnrelatedInterfaces.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/OnlyDirectlyConnectedOrUnrelatedInterfaces.java
index c53395c..f86a398 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/OnlyDirectlyConnectedOrUnrelatedInterfaces.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/OnlyDirectlyConnectedOrUnrelatedInterfaces.java
@@ -27,6 +27,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ExecutorService;
 import java.util.function.Function;
 
 /**
@@ -160,7 +161,7 @@
   }
 
   @Override
-  public SubtypingInfo preprocess(Collection<MergeGroup> groups) {
+  public SubtypingInfo preprocess(Collection<MergeGroup> groups, ExecutorService executorService) {
     return new SubtypingInfo(appView);
   }
 
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/deadlock/SingleCallerInformation.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/deadlock/SingleCallerInformation.java
new file mode 100644
index 0000000..0e35bf3
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/deadlock/SingleCallerInformation.java
@@ -0,0 +1,262 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.horizontalclassmerging.policies.deadlock;
+
+import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.graph.UseRegistry;
+import com.android.tools.r8.utils.ThreadUtils;
+import com.android.tools.r8.utils.collections.ProgramMethodMap;
+import java.util.IdentityHashMap;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+
+/**
+ * Stores the single caller (if any) for each non-virtual method. Virtual methods are not considered
+ * since computing single caller information for such methods is expensive (it involves computing
+ * the possible dispatch targets for each virtual invoke).
+ *
+ * <p>Unlike the {@link com.android.tools.r8.ir.conversion.CallGraph} that is used to determine if a
+ * method can be single caller inlined, this considers a method that is called from multiple call
+ * sites in the same method to have a single caller.
+ */
+// TODO(b/205611444): account for -keep rules.
+public class SingleCallerInformation {
+
+  private final ProgramMethodMap<ProgramMethod> singleCallers;
+  private final Map<DexProgramClass, ProgramMethod> singleClinitCallers;
+
+  SingleCallerInformation(
+      ProgramMethodMap<ProgramMethod> singleCallers,
+      Map<DexProgramClass, ProgramMethod> singleClinitCallers) {
+    this.singleCallers = singleCallers;
+    this.singleClinitCallers = singleClinitCallers;
+  }
+
+  public static Builder builder(AppView<? extends AppInfoWithClassHierarchy> appView) {
+    return new Builder(appView);
+  }
+
+  public ProgramMethod getSingleCaller(ProgramMethod method) {
+    return singleCallers.get(method);
+  }
+
+  public ProgramMethod getSingleClassInitializerCaller(DexProgramClass clazz) {
+    return singleClinitCallers.get(clazz);
+  }
+
+  public static class Builder {
+
+    private final AppView<? extends AppInfoWithClassHierarchy> appView;
+
+    // The single callers for each method and class initializer.
+    // If a method is not in the map, then a call to that method has never been seen.
+    // If a method is mapped to Optional.empty(), then the method has multiple calling contexts.
+    // If a method is mapped to Optional.of(m), then the method is only called from method m.
+    final ProgramMethodMap<Optional<ProgramMethod>> callers = ProgramMethodMap.createConcurrent();
+    final Map<DexProgramClass, Optional<ProgramMethod>> clinitCallers = new ConcurrentHashMap<>();
+
+    Builder(AppView<? extends AppInfoWithClassHierarchy> appView) {
+      this.appView = appView;
+    }
+
+    public Builder analyze(ExecutorService executorService) throws ExecutionException {
+      ThreadUtils.processItems(
+          appView.appInfo()::forEachMethod, this::processMethod, executorService);
+      return this;
+    }
+
+    public SingleCallerInformation build() {
+      ProgramMethodMap<ProgramMethod> singleCallers = ProgramMethodMap.create();
+      callers.forEach(
+          (method, callers) -> callers.ifPresent(caller -> singleCallers.put(method, caller)));
+      Map<DexProgramClass, ProgramMethod> singleClinitCallers = new IdentityHashMap<>();
+      clinitCallers.forEach(
+          (clazz, callers) -> callers.ifPresent(caller -> singleClinitCallers.put(clazz, caller)));
+      return new SingleCallerInformation(singleCallers, singleClinitCallers);
+    }
+
+    private void processMethod(ProgramMethod method) {
+      method.registerCodeReferences(new InvokeExtractor(appView, method));
+    }
+
+    private class InvokeExtractor extends UseRegistry<ProgramMethod> {
+
+      private final AppView<? extends AppInfoWithClassHierarchy> appView;
+
+      InvokeExtractor(AppView<? extends AppInfoWithClassHierarchy> appView, ProgramMethod context) {
+        super(appView, context);
+        this.appView = appView;
+      }
+
+      private void recordDispatchTarget(ProgramMethod target) {
+        callers.compute(
+            target,
+            (key, value) -> {
+              if (value == null) {
+                // This target is now called from the current context (only).
+                return Optional.of(getContext());
+              }
+              // If the target is only called from the current context, then that is still the
+              // case.
+              if (value.orElse(null) == getContext()) {
+                return value;
+              }
+              // The target is now called from more than one place.
+              return Optional.empty();
+            });
+      }
+
+      private void triggerClassInitializerIfNotAlreadyTriggeredInContext(DexType type) {
+        DexProgramClass clazz = type.asProgramClass(appView);
+        if (clazz != null) {
+          triggerClassInitializerIfNotAlreadyTriggeredInContext(clazz);
+        }
+      }
+
+      private void triggerClassInitializerIfNotAlreadyTriggeredInContext(DexProgramClass clazz) {
+        if (!isClassAlreadyInitializedInCurrentContext(clazz)) {
+          triggerClassInitializer(clazz);
+        }
+      }
+
+      private boolean isClassAlreadyInitializedInCurrentContext(DexProgramClass clazz) {
+        return appView.appInfo().isSubtype(getContext().getHolder(), clazz);
+      }
+
+      private void triggerClassInitializer(DexType type) {
+        DexProgramClass clazz = type.asProgramClass(appView);
+        if (clazz != null) {
+          triggerClassInitializer(clazz);
+        }
+      }
+
+      private void triggerClassInitializer(DexProgramClass clazz) {
+        Optional<ProgramMethod> callers = clinitCallers.get(clazz);
+        if (callers != null) {
+          if (!callers.isPresent()) {
+            // Optional.empty() represents that this class initializer has multiple (unknown)
+            // callers. Since this <clinit> and all of the parent <clinit>s are already triggered
+            // from multiple places, there is no need to record it is also triggered from the
+            // current context.
+            return;
+          }
+          if (callers.get() == getContext()) {
+            // This <clinit> is already triggered from the current context. No need to record this
+            // again.
+            return;
+          }
+        }
+
+        // Record that the given class is now initialized from the current context.
+        clinitCallers.compute(
+            clazz,
+            (key, value) -> {
+              if (value == null) {
+                // This <clinit> was not triggered before.
+                return Optional.of(getContext());
+              }
+              // This <clinit> was triggered from another context than the current.
+              assert value.orElse(null) != getContext();
+              return Optional.empty();
+            });
+
+        // Repeat for the parent classes.
+        triggerClassInitializer(clazz.getSuperType());
+      }
+
+      @Override
+      public void registerInitClass(DexType type) {
+        DexType rewrittenType = appView.graphLens().lookupType(type);
+        triggerClassInitializerIfNotAlreadyTriggeredInContext(rewrittenType);
+      }
+
+      @Override
+      public void registerInstanceFieldRead(DexField field) {
+        // Intentionally empty.
+      }
+
+      @Override
+      public void registerInstanceFieldWrite(DexField field) {
+        // Intentionally empty.
+      }
+
+      @Override
+      public void registerInvokeDirect(DexMethod method) {
+        DexMethod rewrittenMethod =
+            appView.graphLens().lookupInvokeDirect(method, getContext()).getReference();
+        DexProgramClass holder = rewrittenMethod.getHolderType().asProgramClass(appView);
+        ProgramMethod target = rewrittenMethod.lookupOnProgramClass(holder);
+        if (target != null) {
+          recordDispatchTarget(target);
+        }
+      }
+
+      @Override
+      public void registerInvokeInterface(DexMethod method) {
+        // Intentionally empty, as we don't aim to collect single caller information for virtual
+        // methods.
+      }
+
+      @Override
+      public void registerInvokeStatic(DexMethod method) {
+        DexMethod rewrittenMethod =
+            appView.graphLens().lookupInvokeDirect(method, getContext()).getReference();
+        ProgramMethod target =
+            appView
+                .appInfo()
+                .unsafeResolveMethodDueToDexFormat(rewrittenMethod)
+                .getResolvedProgramMethod();
+        if (target != null) {
+          recordDispatchTarget(target);
+          triggerClassInitializerIfNotAlreadyTriggeredInContext(target.getHolder());
+        }
+      }
+
+      @Override
+      public void registerInvokeSuper(DexMethod method) {
+        // Intentionally empty, as we don't aim to collect single caller information for virtual
+        // methods.
+      }
+
+      @Override
+      public void registerInvokeVirtual(DexMethod method) {
+        // Intentionally empty, as we don't aim to collect single caller information for virtual
+        // methods.
+      }
+
+      @Override
+      public void registerNewInstance(DexType type) {
+        DexType rewrittenType = appView.graphLens().lookupType(type);
+        triggerClassInitializerIfNotAlreadyTriggeredInContext(rewrittenType);
+      }
+
+      @Override
+      public void registerStaticFieldRead(DexField field) {
+        DexField rewrittenField = appView.graphLens().lookupField(field);
+        triggerClassInitializerIfNotAlreadyTriggeredInContext(rewrittenField.getHolderType());
+      }
+
+      @Override
+      public void registerStaticFieldWrite(DexField field) {
+        DexField rewrittenField = appView.graphLens().lookupField(field);
+        triggerClassInitializerIfNotAlreadyTriggeredInContext(rewrittenField.getHolderType());
+      }
+
+      @Override
+      public void registerTypeReference(DexType type) {
+        // Intentionally empty.
+      }
+    }
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/utils/collections/ProgramMemberMap.java b/src/main/java/com/android/tools/r8/utils/collections/ProgramMemberMap.java
index 3a648fe..9c1c5af 100644
--- a/src/main/java/com/android/tools/r8/utils/collections/ProgramMemberMap.java
+++ b/src/main/java/com/android/tools/r8/utils/collections/ProgramMemberMap.java
@@ -8,6 +8,7 @@
 import com.google.common.base.Equivalence.Wrapper;
 import java.util.Map;
 import java.util.function.BiConsumer;
+import java.util.function.BiFunction;
 import java.util.function.BiPredicate;
 import java.util.function.Function;
 import java.util.function.Supplier;
@@ -28,6 +29,10 @@
     backing.clear();
   }
 
+  public V compute(K member, BiFunction<K, V, V> fn) {
+    return backing.compute(wrap(member), (key, value) -> fn.apply(member, value));
+  }
+
   public V computeIfAbsent(K member, Function<K, V> fn) {
     return backing.computeIfAbsent(wrap(member), key -> fn.apply(key.get()));
   }
@@ -44,6 +49,10 @@
     return backing.get(wrap(member));
   }
 
+  public V getOrDefault(K member, V defaultValue) {
+    return backing.getOrDefault(wrap(member), defaultValue);
+  }
+
   public boolean isEmpty() {
     return backing.isEmpty();
   }
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/ClinitDeadlockAfterMergingSingletonClassesInstantiatedByCompanionTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/ClinitDeadlockAfterMergingSingletonClassesInstantiatedByCompanionTest.java
index 2ea472e..f7ca4c1 100644
--- a/src/test/java/com/android/tools/r8/classmerging/horizontal/ClinitDeadlockAfterMergingSingletonClassesInstantiatedByCompanionTest.java
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/ClinitDeadlockAfterMergingSingletonClassesInstantiatedByCompanionTest.java
@@ -9,7 +9,6 @@
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.classmerging.horizontal.ClinitDeadlockAfterMergingSingletonClassesInstantiatedByCompanionTest.Host.Companion.HostA;
 import com.android.tools.r8.classmerging.horizontal.ClinitDeadlockAfterMergingSingletonClassesInstantiatedByCompanionTest.Host.Companion.HostB;
-import com.android.tools.r8.utils.codeinspector.HorizontallyMergedClassesInspector;
 import com.google.common.collect.ImmutableList;
 import java.util.List;
 import org.junit.Test;
@@ -43,9 +42,12 @@
             "  public static void thread0();",
             "  public static void thread" + thread + "();",
             "}")
-        // TODO(b/205611444): HostA and HostB should be merged when thread is 1.
         .addHorizontallyMergedClassesInspector(
-            HorizontallyMergedClassesInspector::assertNoClassesMerged)
+            inspector ->
+                inspector
+                    .applyIf(
+                        thread == 1, i -> i.assertIsCompleteMergeGroup(HostA.class, HostB.class))
+                    .assertNoOtherClassesMerged())
         .addOptionsModification(
             options ->
                 options.horizontalClassMergerOptions().setEnableClassInitializerDeadlockDetection())