Preliminary policies for interface merging

Bug: 173990042
Change-Id: I43048879b33a4a69592b96aaf6d7918af2f2aee7
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/MergeGroup.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/MergeGroup.java
index 4b9ad6a..b99cb2f 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/MergeGroup.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/MergeGroup.java
@@ -6,8 +6,10 @@
 
 package com.android.tools.r8.horizontalclassmerging;
 
+import com.android.tools.r8.graph.DexClass;
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.utils.IterableUtils;
 import com.android.tools.r8.utils.IteratorUtils;
 import com.google.common.collect.Iterables;
 import java.util.Collection;
@@ -16,7 +18,7 @@
 import java.util.function.Consumer;
 import java.util.function.Predicate;
 
-public class MergeGroup implements Iterable<DexProgramClass> {
+public class MergeGroup implements Collection<DexProgramClass> {
 
   public static class Metadata {}
 
@@ -46,20 +48,33 @@
     }
   }
 
-  public void add(DexProgramClass clazz) {
-    classes.add(clazz);
+  @Override
+  public boolean add(DexProgramClass clazz) {
+    return classes.add(clazz);
   }
 
-  public void add(MergeGroup group) {
-    classes.addAll(group.getClasses());
+  public boolean add(MergeGroup group) {
+    return classes.addAll(group.getClasses());
   }
 
-  public void addAll(Collection<DexProgramClass> classes) {
-    this.classes.addAll(classes);
+  @Override
+  public boolean addAll(Collection<? extends DexProgramClass> classes) {
+    return this.classes.addAll(classes);
   }
 
-  public void addFirst(DexProgramClass clazz) {
-    classes.addFirst(clazz);
+  @Override
+  public void clear() {
+    classes.clear();
+  }
+
+  @Override
+  public boolean contains(Object o) {
+    return classes.contains(o);
+  }
+
+  @Override
+  public boolean containsAll(Collection<?> collection) {
+    return classes.containsAll(collection);
   }
 
   public void forEachSource(Consumer<DexProgramClass> consumer) {
@@ -114,6 +129,12 @@
     return classes.isEmpty();
   }
 
+  public boolean isInterfaceGroup() {
+    assert !isEmpty();
+    assert IterableUtils.allIdentical(getClasses(), DexClass::isInterface);
+    return getClasses().getFirst().isInterface();
+  }
+
   @Override
   public Iterator<DexProgramClass> iterator() {
     return classes.iterator();
@@ -123,15 +144,41 @@
     return classes.size();
   }
 
+  @Override
+  public boolean remove(Object o) {
+    return classes.remove(o);
+  }
+
+  @Override
+  public boolean removeAll(Collection<?> collection) {
+    return classes.removeAll(collection);
+  }
+
   public DexProgramClass removeFirst(Predicate<DexProgramClass> predicate) {
     return IteratorUtils.removeFirst(iterator(), predicate);
   }
 
-  public boolean removeIf(Predicate<DexProgramClass> predicate) {
+  @Override
+  public boolean removeIf(Predicate<? super DexProgramClass> predicate) {
     return classes.removeIf(predicate);
   }
 
   public DexProgramClass removeLast() {
     return classes.removeLast();
   }
+
+  @Override
+  public boolean retainAll(Collection<?> collection) {
+    return collection.retainAll(collection);
+  }
+
+  @Override
+  public Object[] toArray() {
+    return classes.toArray();
+  }
+
+  @Override
+  public <T> T[] toArray(T[] ts) {
+    return classes.toArray(ts);
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/MultiClassPolicyWithPreprocessing.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/MultiClassPolicyWithPreprocessing.java
new file mode 100644
index 0000000..e06a7ef
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/MultiClassPolicyWithPreprocessing.java
@@ -0,0 +1,23 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.horizontalclassmerging;
+
+import java.util.Collection;
+
+public abstract class MultiClassPolicyWithPreprocessing<T> extends Policy {
+
+  /**
+   * Apply the multi class policy to a group of program classes.
+   *
+   * @param group This is a group of program classes which can currently still be merged.
+   * @param data The result of calling {@link #preprocess(Collection)}.
+   * @return The same collection of program classes split into new groups of candidates which can be
+   *     merged. If the policy detects no issues then `group` will be returned unchanged. If classes
+   *     cannot be merged with any other classes they are returned as singleton lists.
+   */
+  public abstract Collection<MergeGroup> apply(MergeGroup group, T data);
+
+  public abstract T preprocess(Collection<MergeGroup> groups);
+}
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoDefaultInterfaceMethodCollisions.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoDefaultInterfaceMethodCollisions.java
new file mode 100644
index 0000000..0851469
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoDefaultInterfaceMethodCollisions.java
@@ -0,0 +1,338 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.horizontalclassmerging.policies;
+
+import static java.util.Collections.emptyMap;
+import static java.util.Collections.emptySet;
+
+import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.BottomUpClassHierarchyTraversal;
+import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexMethodSignature;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.SubtypingInfo;
+import com.android.tools.r8.graph.TopDownClassHierarchyTraversal;
+import com.android.tools.r8.horizontalclassmerging.MergeGroup;
+import com.android.tools.r8.horizontalclassmerging.MultiClassPolicyWithPreprocessing;
+import com.android.tools.r8.horizontalclassmerging.policies.NoDefaultInterfaceMethodCollisions.InterfaceInfo;
+import com.android.tools.r8.utils.ListUtils;
+import com.android.tools.r8.utils.MapUtils;
+import com.android.tools.r8.utils.SetUtils;
+import com.android.tools.r8.utils.collections.DexMethodSignatureSet;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Sets;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.IdentityHashMap;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * This policy prevents that interface merging changes semantics of invoke-interface/invoke-virtual
+ * instructions that dispatch to default interface methods.
+ *
+ * <p>As a simple example, consider the following snippet of code. If we merge interfaces I and K,
+ * then we effectively add the default interface method K.m() to I, which would change the semantics
+ * of calls to A.m().
+ *
+ * <pre>
+ *   interface I {}
+ *   interface J {
+ *     default void m() { print("J"); }
+ *   }
+ *   interface K {
+ *     default void m() { print("K"); }
+ *   }
+ *   class A implements I, J {}
+ * </pre>
+ *
+ * Note that we also cannot merge I with K, even if K does not declare any methods directly:
+ *
+ * <pre>
+ *   interface K0 {
+ *     default void m() { print("K"); }
+ *   }
+ *   interface K extends K0 {}
+ * </pre>
+ *
+ * Also, note that this is not a problem if class A overrides void m().
+ */
+public class NoDefaultInterfaceMethodCollisions
+    extends MultiClassPolicyWithPreprocessing<Map<DexType, InterfaceInfo>> {
+
+  private final AppView<? extends AppInfoWithClassHierarchy> appView;
+
+  public NoDefaultInterfaceMethodCollisions(AppView<? extends AppInfoWithClassHierarchy> appView) {
+    this.appView = appView;
+  }
+
+  @Override
+  public Collection<MergeGroup> apply(MergeGroup group, Map<DexType, InterfaceInfo> infos) {
+    if (!group.isInterfaceGroup()) {
+      return ImmutableList.of(group);
+    }
+
+    // For each interface I in the group, check that each (non-interface) subclass of I does not
+    // inherit a default method that is also declared by another interface J in the merge group.
+    //
+    // Note that the primary piece of work is carried out in the preprocess() method
+    //
+    // TODO(b/173990042): Consider forming multiple groups instead of just filtering. In practice,
+    //  this rarely leads to much filtering, though, since the use of default methods is somewhat
+    //  limited.
+    MergeGroup newGroup = new MergeGroup();
+    for (DexProgramClass clazz : group) {
+      Set<DexMethod> newDefaultMethodsAddedToClassByMerge =
+          computeNewDefaultMethodsAddedToClassByMerge(clazz, group, infos);
+      if (isSafeToAddDefaultMethodsToClass(clazz, newDefaultMethodsAddedToClassByMerge, infos)) {
+        newGroup.add(clazz);
+      }
+    }
+    return newGroup.isTrivial() ? Collections.emptyList() : ListUtils.newLinkedList(newGroup);
+  }
+
+  private Set<DexMethod> computeNewDefaultMethodsAddedToClassByMerge(
+      DexProgramClass clazz, MergeGroup group, Map<DexType, InterfaceInfo> infos) {
+    // Run through the other classes in the merge group, and add the default interface methods that
+    // they declare (or inherit from a super interface) to a set.
+    Set<DexMethod> newDefaultMethodsAddedToClassByMerge = Sets.newIdentityHashSet();
+    for (DexProgramClass other : group) {
+      if (other != clazz) {
+        Collection<Set<DexMethod>> inheritedDefaultMethodsFromOther =
+            infos.get(other.getType()).getInheritedDefaultMethods().values();
+        inheritedDefaultMethodsFromOther.forEach(newDefaultMethodsAddedToClassByMerge::addAll);
+      }
+    }
+    return newDefaultMethodsAddedToClassByMerge;
+  }
+
+  private boolean isSafeToAddDefaultMethodsToClass(
+      DexProgramClass clazz,
+      Set<DexMethod> newDefaultMethodsAddedToClassByMerge,
+      Map<DexType, InterfaceInfo> infos) {
+    // Check if there is a subclass of this interface, which inherits a default interface method
+    // that would also be added by to this interface by merging the interfaces in the group.
+    Map<DexMethodSignature, Set<DexMethod>> defaultMethodsInheritedBySubclassesOfClass =
+        infos.get(clazz.getType()).getDefaultMethodsInheritedBySubclasses();
+    for (DexMethod newDefaultMethodAddedToClassByMerge : newDefaultMethodsAddedToClassByMerge) {
+      Set<DexMethod> defaultMethodsInheritedBySubclassesOfClassWithSameSignature =
+          defaultMethodsInheritedBySubclassesOfClass.getOrDefault(
+              newDefaultMethodAddedToClassByMerge.getSignature(), emptySet());
+      // Look for a method different from the method we're adding.
+      for (DexMethod method : defaultMethodsInheritedBySubclassesOfClassWithSameSignature) {
+        if (method != newDefaultMethodAddedToClassByMerge) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  @Override
+  public Map<DexType, InterfaceInfo> preprocess(Collection<MergeGroup> groups) {
+    SubtypingInfo subtypingInfo = new SubtypingInfo(appView);
+    Collection<DexProgramClass> classesOfInterest = computeClassesOfInterest(subtypingInfo);
+    Map<DexType, DexMethodSignatureSet> inheritedClassMethodsPerClass =
+        computeInheritedClassMethodsPerProgramClass(classesOfInterest);
+    Map<DexType, Map<DexMethodSignature, Set<DexMethod>>> inheritedDefaultMethodsPerClass =
+        computeInheritedDefaultMethodsPerProgramType(
+            classesOfInterest, inheritedClassMethodsPerClass);
+
+    // Finally, do a bottom-up traversal, pushing the inherited default methods upwards.
+    Map<DexType, Map<DexMethodSignature, Set<DexMethod>>>
+        defaultMethodsInheritedBySubclassesPerClass =
+            computeDefaultMethodsInheritedBySubclassesPerProgramClass(
+                classesOfInterest, inheritedDefaultMethodsPerClass, subtypingInfo);
+
+    // Store the computed information for each interface that is subject to merging.
+    Map<DexType, InterfaceInfo> infos = new IdentityHashMap<>();
+    for (MergeGroup group : groups) {
+      if (group.isInterfaceGroup()) {
+        for (DexProgramClass clazz : group) {
+          infos.put(
+              clazz.getType(),
+              new InterfaceInfo(
+                  inheritedDefaultMethodsPerClass.getOrDefault(clazz.getType(), emptyMap()),
+                  defaultMethodsInheritedBySubclassesPerClass.getOrDefault(
+                      clazz.getType(), emptyMap())));
+        }
+      }
+    }
+    return infos;
+  }
+
+  /** Returns the set of program classes that must be considered during preprocessing. */
+  private Collection<DexProgramClass> computeClassesOfInterest(SubtypingInfo subtypingInfo) {
+    // TODO(b/173990042): Limit result to the set of classes that are in the same as one of
+    //  the interfaces that is subject to merging.
+    return appView.appInfo().classes();
+  }
+
+  /**
+   * For each class, computes the (transitive) set of virtual methods that is declared on the class
+   * itself or one of its (non-interface) super classes.
+   */
+  private Map<DexType, DexMethodSignatureSet> computeInheritedClassMethodsPerProgramClass(
+      Collection<DexProgramClass> classesOfInterest) {
+    Map<DexType, DexMethodSignatureSet> inheritedClassMethodsPerClass = new IdentityHashMap<>();
+    TopDownClassHierarchyTraversal.forAllClasses(appView)
+        .excludeInterfaces()
+        .visit(
+            classesOfInterest,
+            clazz -> {
+              DexMethodSignatureSet classMethods =
+                  DexMethodSignatureSet.create(
+                      inheritedClassMethodsPerClass.getOrDefault(
+                          clazz.getSuperType(), DexMethodSignatureSet.empty()));
+              for (DexEncodedMethod method : clazz.virtualMethods()) {
+                classMethods.add(method.getSignature());
+              }
+              inheritedClassMethodsPerClass.put(clazz.getType(), classMethods);
+            });
+    inheritedClassMethodsPerClass
+        .keySet()
+        .removeIf(type -> !appView.definitionFor(type).isProgramClass());
+    return inheritedClassMethodsPerClass;
+  }
+
+  /**
+   * For each class or interface, computes the (transitive) set of virtual methods that is declared
+   * on the class itself or one of its (non-interface) super classes.
+   */
+  private Map<DexType, Map<DexMethodSignature, Set<DexMethod>>>
+      computeInheritedDefaultMethodsPerProgramType(
+          Collection<DexProgramClass> classesOfInterest,
+          Map<DexType, DexMethodSignatureSet> inheritedClassMethodsPerClass) {
+    Map<DexType, Map<DexMethodSignature, Set<DexMethod>>> inheritedDefaultMethodsPerType =
+        new IdentityHashMap<>();
+    TopDownClassHierarchyTraversal.forAllClasses(appView)
+        .visit(
+            classesOfInterest,
+            clazz -> {
+              // Compute the set of default method signatures that this class inherits from its
+              // super class and interfaces.
+              Map<DexMethodSignature, Set<DexMethod>> inheritedDefaultMethods = new HashMap<>();
+              for (DexType supertype : clazz.allImmediateSupertypes()) {
+                Map<DexMethodSignature, Set<DexMethod>> inheritedDefaultMethodsFromSuperType =
+                    inheritedDefaultMethodsPerType.getOrDefault(supertype, emptyMap());
+                inheritedDefaultMethodsFromSuperType.forEach(
+                    (signature, methods) ->
+                        inheritedDefaultMethods
+                            .computeIfAbsent(signature, ignore -> Sets.newIdentityHashSet())
+                            .addAll(methods));
+              }
+
+              // If this is an interface, also include the default methods it declares.
+              if (clazz.isInterface()) {
+                for (DexEncodedMethod method :
+                    clazz.virtualMethods(DexEncodedMethod::isDefaultMethod)) {
+                  inheritedDefaultMethods
+                      .computeIfAbsent(method.getSignature(), ignore -> Sets.newIdentityHashSet())
+                      .add(method.getReference());
+                }
+              }
+
+              // Remove all default methods that are declared as (non-interface) class methods on
+              // the current class.
+              inheritedDefaultMethods
+                  .keySet()
+                  .removeAll(
+                      inheritedClassMethodsPerClass.getOrDefault(
+                          clazz.getType(), DexMethodSignatureSet.empty()));
+
+              if (!inheritedDefaultMethods.isEmpty()) {
+                inheritedDefaultMethodsPerType.put(clazz.getType(), inheritedDefaultMethods);
+              }
+            });
+    inheritedDefaultMethodsPerType
+        .keySet()
+        .removeIf(type -> !appView.definitionFor(type).isProgramClass());
+    return inheritedDefaultMethodsPerType;
+  }
+
+  /**
+   * Performs a bottom-up traversal of the hierarchy, where the inherited default methods of each
+   * class are pushed upwards. This accumulates the set of default methods that are inherited by all
+   * subclasses of a given interface.
+   */
+  private Map<DexType, Map<DexMethodSignature, Set<DexMethod>>>
+      computeDefaultMethodsInheritedBySubclassesPerProgramClass(
+          Collection<DexProgramClass> classesOfInterest,
+          Map<DexType, Map<DexMethodSignature, Set<DexMethod>>> inheritedDefaultMethodsPerClass,
+          SubtypingInfo subtypingInfo) {
+    // Copy the map from classes to their inherited default methods.
+    Map<DexType, Map<DexMethodSignature, Set<DexMethod>>>
+        defaultMethodsInheritedBySubclassesPerClass =
+            MapUtils.clone(
+                inheritedDefaultMethodsPerClass,
+                new HashMap<>(),
+                outerValue ->
+                    MapUtils.clone(outerValue, new HashMap<>(), SetUtils::newIdentityHashSet));
+    BottomUpClassHierarchyTraversal.forProgramClasses(appView, subtypingInfo)
+        .visit(
+            classesOfInterest,
+            clazz -> {
+              // Push the current class' default methods upwards to all super classes.
+              Map<DexMethodSignature, Set<DexMethod>> defaultMethodsToPropagate =
+                  defaultMethodsInheritedBySubclassesPerClass.getOrDefault(
+                      clazz.getType(), emptyMap());
+              for (DexType supertype : clazz.allImmediateSupertypes()) {
+                Map<DexMethodSignature, Set<DexMethod>>
+                    defaultMethodsInheritedBySubclassesForSupertype =
+                        defaultMethodsInheritedBySubclassesPerClass.computeIfAbsent(
+                            supertype, ignore -> new HashMap<>());
+                defaultMethodsToPropagate.forEach(
+                    (signature, methods) ->
+                        defaultMethodsInheritedBySubclassesForSupertype
+                            .computeIfAbsent(signature, ignore -> Sets.newIdentityHashSet())
+                            .addAll(methods));
+              }
+            });
+    defaultMethodsInheritedBySubclassesPerClass
+        .keySet()
+        .removeIf(type -> !appView.definitionFor(type).isProgramClass());
+    return defaultMethodsInheritedBySubclassesPerClass;
+  }
+
+  @Override
+  public String getName() {
+    return "NoDefaultInterfaceMethodCollisions";
+  }
+
+  @Override
+  public boolean shouldSkipPolicy() {
+    return !appView.options().horizontalClassMergerOptions().isInterfaceMergingEnabled();
+  }
+
+  static class InterfaceInfo {
+
+    // The set of default interface methods (grouped by signature) that this interface declares or
+    // inherits from one of its (transitive) super interfaces.
+    private final Map<DexMethodSignature, Set<DexMethod>> inheritedDefaultMethods;
+
+    // The set of default interface methods (grouped by signature) that subclasses of this interface
+    // inherits from one of its (transitively) implemented super interfaces.
+    private final Map<DexMethodSignature, Set<DexMethod>> defaultMethodsInheritedBySubclasses;
+
+    InterfaceInfo(
+        Map<DexMethodSignature, Set<DexMethod>> inheritedDefaultMethods,
+        Map<DexMethodSignature, Set<DexMethod>> defaultMethodsInheritedBySubclasses) {
+      this.inheritedDefaultMethods = inheritedDefaultMethods;
+      this.defaultMethodsInheritedBySubclasses = defaultMethodsInheritedBySubclasses;
+    }
+
+    Map<DexMethodSignature, Set<DexMethod>> getInheritedDefaultMethods() {
+      return inheritedDefaultMethods;
+    }
+
+    Map<DexMethodSignature, Set<DexMethod>> getDefaultMethodsInheritedBySubclasses() {
+      return defaultMethodsInheritedBySubclasses;
+    }
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoDefaultInterfaceMethodMerging.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoDefaultInterfaceMethodMerging.java
new file mode 100644
index 0000000..74ae0ab
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoDefaultInterfaceMethodMerging.java
@@ -0,0 +1,81 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.horizontalclassmerging.policies;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.horizontalclassmerging.MergeGroup;
+import com.android.tools.r8.horizontalclassmerging.MultiClassPolicy;
+import com.android.tools.r8.utils.InternalOptions;
+import com.android.tools.r8.utils.ListUtils;
+import com.android.tools.r8.utils.collections.DexMethodSignatureSet;
+import com.google.common.collect.Lists;
+import java.util.Collection;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Map.Entry;
+
+/**
+ * For interfaces, we cannot introduce an instance field `int $r8$classId`. Therefore, we can't
+ * merge two interfaces that declare the same default interface method.
+ *
+ * <p>This policy attempts to split a merge group consisting of interfaces into smaller merge groups
+ * such that each pairs of interfaces in each merge group does not have conflicting default
+ * interface methods.
+ */
+public class NoDefaultInterfaceMethodMerging extends MultiClassPolicy {
+
+  private final InternalOptions options;
+
+  public NoDefaultInterfaceMethodMerging(AppView<?> appView) {
+    this.options = appView.options();
+  }
+
+  @Override
+  public Collection<MergeGroup> apply(MergeGroup group) {
+    if (!group.isInterfaceGroup()) {
+      return ListUtils.newLinkedList(group);
+    }
+
+    // Split the group into smaller groups such that no default methods collide.
+    Map<MergeGroup, DexMethodSignatureSet> newGroups = new LinkedHashMap<>();
+    for (DexProgramClass clazz : group) {
+      addClassToGroup(clazz, newGroups);
+    }
+
+    return removeTrivialGroups(Lists.newLinkedList(newGroups.keySet()));
+  }
+
+  private void addClassToGroup(
+      DexProgramClass clazz, Map<MergeGroup, DexMethodSignatureSet> newGroups) {
+    DexMethodSignatureSet classSignatures = DexMethodSignatureSet.create();
+    classSignatures.addAllMethods(clazz.virtualMethods(DexEncodedMethod::isDefaultMethod));
+
+    // Find a group that does not have any collisions with `clazz`.
+    for (Entry<MergeGroup, DexMethodSignatureSet> entry : newGroups.entrySet()) {
+      MergeGroup group = entry.getKey();
+      DexMethodSignatureSet groupSignatures = entry.getValue();
+      if (!groupSignatures.containsAnyOf(classSignatures)) {
+        groupSignatures.addAll(classSignatures);
+        group.add(clazz);
+        return;
+      }
+    }
+
+    // Else create a new group.
+    newGroups.put(new MergeGroup(clazz), classSignatures);
+  }
+
+  @Override
+  public String getName() {
+    return "NoDefaultInterfaceMethodMerging";
+  }
+
+  @Override
+  public boolean shouldSkipPolicy() {
+    return !options.horizontalClassMergerOptions().isInterfaceMergingEnabled();
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/OnlyDirectlyConnectedOrUnrelatedInterfaces.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/OnlyDirectlyConnectedOrUnrelatedInterfaces.java
new file mode 100644
index 0000000..fecf958
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/OnlyDirectlyConnectedOrUnrelatedInterfaces.java
@@ -0,0 +1,186 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.horizontalclassmerging.policies;
+
+import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
+
+import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.horizontalclassmerging.MergeGroup;
+import com.android.tools.r8.horizontalclassmerging.MultiClassPolicy;
+import com.android.tools.r8.utils.WorkList;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterators;
+import com.google.common.collect.Sets;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.IdentityHashMap;
+import java.util.Iterator;
+import java.util.LinkedHashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Consumer;
+
+/**
+ * This policy ensures that we do not create cycles in the class hierarchy as a result of interface
+ * merging.
+ *
+ * <p>Example: Consider that we have the following three interfaces:
+ *
+ * <pre>
+ *   interface I extends ... {}
+ *   interface J extends I, ... {}
+ *   interface K extends J, ... {}
+ * </pre>
+ *
+ * <p>In this case, it would be possible to merge the groups {I, J}, {J, K}, and {I, J, K}. Common
+ * to these merge groups is that each interface in the merge group can reach all other interfaces in
+ * the same merge group in the class hierarchy, without visiting any interfaces outside the merge
+ * group.
+ *
+ * <p>The group {I, K} cannot safely be merged, as this would lead to a cycle in the class
+ * hierarchy:
+ *
+ * <pre>
+ *   interface IK extends J, ... {}
+ *   interface J extends IK, ... {}
+ * </pre>
+ */
+public class OnlyDirectlyConnectedOrUnrelatedInterfaces extends MultiClassPolicy {
+
+  private final AppView<? extends AppInfoWithClassHierarchy> appView;
+
+  public OnlyDirectlyConnectedOrUnrelatedInterfaces(
+      AppView<? extends AppInfoWithClassHierarchy> appView) {
+    this.appView = appView;
+  }
+
+  @Override
+  public Collection<MergeGroup> apply(MergeGroup group) {
+    if (!group.isInterfaceGroup()) {
+      return ImmutableList.of(group);
+    }
+
+    Set<DexProgramClass> classes = new LinkedHashSet<>(group.getClasses());
+    Map<DexProgramClass, Set<DexProgramClass>> ineligibleForMerging =
+        computeIneligibleForMergingGraph(classes);
+    if (ineligibleForMerging.isEmpty()) {
+      return ImmutableList.of(group);
+    }
+
+    // Extract sub-merge groups from the graph in such a way that all pairs of interfaces in each
+    // merge group are not connected by an edge in the graph.
+    List<MergeGroup> newGroups = new LinkedList<>();
+    while (!classes.isEmpty()) {
+      Iterator<DexProgramClass> iterator = classes.iterator();
+      MergeGroup newGroup = new MergeGroup(iterator.next());
+      Iterators.addAll(
+          newGroup,
+          Iterators.filter(
+              iterator,
+              candidate -> !isConnectedToGroup(candidate, newGroup, ineligibleForMerging)));
+      if (!newGroup.isTrivial()) {
+        newGroups.add(newGroup);
+      }
+      classes.removeAll(newGroup.getClasses());
+    }
+    return newGroups;
+  }
+
+  /**
+   * Computes an undirected graph, where the nodes are the interfaces from the merge group, and an
+   * edge I <-> J represents that I and J are not eligible for merging.
+   *
+   * <p>We will insert an edge I <-> J, if interface I inherits from interface J, and the path from
+   * I to J in the class hierarchy includes an interface K that is outside the merge group. Note
+   * that if I extends J directly we will not insert an edge I <-> J (unless there are multiple
+   * paths in the class hierarchy from I to J, and one of the paths goes through an interface
+   * outside the merge group).
+   */
+  private Map<DexProgramClass, Set<DexProgramClass>> computeIneligibleForMergingGraph(
+      Set<DexProgramClass> classes) {
+    Map<DexProgramClass, Set<DexProgramClass>> ineligibleForMerging = new IdentityHashMap<>();
+    for (DexProgramClass clazz : classes) {
+      forEachIndirectlyReachableInterfaceInMergeGroup(
+          clazz,
+          classes,
+          other ->
+              ineligibleForMerging
+                  .computeIfAbsent(clazz, ignore -> Sets.newIdentityHashSet())
+                  .add(other));
+    }
+    return ineligibleForMerging;
+  }
+
+  private void forEachIndirectlyReachableInterfaceInMergeGroup(
+      DexProgramClass clazz, Set<DexProgramClass> classes, Consumer<DexProgramClass> consumer) {
+    // First find the set of interfaces that can be reached via paths in the class hierarchy from
+    // the given interface, without visiting any interfaces outside the merge group.
+    WorkList<DexType> workList = WorkList.newIdentityWorkList(clazz.getInterfaces());
+    while (workList.hasNext()) {
+      DexProgramClass directlyReachableInterface =
+          asProgramClassOrNull(appView.definitionFor(workList.next()));
+      if (directlyReachableInterface == null) {
+        continue;
+      }
+      // If the implemented interface is a member of the merge group, then include it's interfaces.
+      if (classes.contains(directlyReachableInterface)) {
+        workList.addIfNotSeen(directlyReachableInterface.getInterfaces());
+      }
+    }
+
+    // Initialize a new worklist with the first layer of indirectly reachable interface types.
+    Set<DexType> directlyReachableInterfaceTypes = workList.getSeenSet();
+    workList = WorkList.newIdentityWorkList();
+    for (DexType directlyReachableInterfaceType : directlyReachableInterfaceTypes) {
+      DexProgramClass directlyReachableInterface =
+          asProgramClassOrNull(appView.definitionFor(directlyReachableInterfaceType));
+      if (directlyReachableInterface != null) {
+        workList.addIfNotSeen(directlyReachableInterface.getInterfaces());
+      }
+    }
+
+    // Report all interfaces from the merge group that are reachable in the class hierarchy from the
+    // worklist.
+    while (workList.hasNext()) {
+      DexProgramClass indirectlyReachableInterface =
+          asProgramClassOrNull(appView.definitionFor(workList.next()));
+      if (indirectlyReachableInterface == null) {
+        continue;
+      }
+      if (classes.contains(indirectlyReachableInterface)) {
+        consumer.accept(indirectlyReachableInterface);
+      }
+      workList.addIfNotSeen(indirectlyReachableInterface.getInterfaces());
+    }
+  }
+
+  private boolean isConnectedToGroup(
+      DexProgramClass clazz,
+      MergeGroup group,
+      Map<DexProgramClass, Set<DexProgramClass>> ineligibleForMerging) {
+    for (DexProgramClass member : group) {
+      if (ineligibleForMerging.getOrDefault(clazz, Collections.emptySet()).contains(member)
+          || ineligibleForMerging.getOrDefault(member, Collections.emptySet()).contains(clazz)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  @Override
+  public String getName() {
+    return "OnlyDirectlyConnectedOrUnrelatedInterfaces";
+  }
+
+  @Override
+  public boolean shouldSkipPolicy() {
+    return !appView.options().horizontalClassMergerOptions().isInterfaceMergingEnabled();
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index 94b8404..40ce5b8 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -1195,10 +1195,11 @@
   public static class HorizontalClassMergerOptions {
 
     // TODO(b/138781768): Set enable to true when this bug is resolved.
-    public boolean enable =
+    private boolean enable =
         !Version.isDevelopmentVersion()
             || System.getProperty("com.android.tools.r8.disableHorizontalClassMerging") == null;
-    public boolean enableConstructorMerging = true;
+    private boolean enableConstructorMerging = true;
+    private boolean enableInterfaceMerging = false;
 
     public int maxGroupSize = 30;
 
@@ -1229,6 +1230,10 @@
     public boolean isEnabled() {
       return enable;
     }
+
+    public boolean isInterfaceMergingEnabled() {
+      return enableInterfaceMerging;
+    }
   }
 
   public static class ProtoShrinkingOptions {
diff --git a/src/main/java/com/android/tools/r8/utils/IterableUtils.java b/src/main/java/com/android/tools/r8/utils/IterableUtils.java
index 1a61b44..89063c4 100644
--- a/src/main/java/com/android/tools/r8/utils/IterableUtils.java
+++ b/src/main/java/com/android/tools/r8/utils/IterableUtils.java
@@ -19,6 +19,25 @@
 
 public class IterableUtils {
 
+  public static <S, T> boolean allIdentical(Iterable<S> iterable) {
+    return allIdentical(iterable, Function.identity());
+  }
+
+  public static <S, T> boolean allIdentical(Iterable<S> iterable, Function<S, T> fn) {
+    Iterator<S> iterator = iterable.iterator();
+    if (!iterator.hasNext()) {
+      return true;
+    }
+    T first = fn.apply(iterator.next());
+    while (iterator.hasNext()) {
+      T other = fn.apply(iterator.next());
+      if (other != first) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   public static <S, T> boolean any(
       Iterable<S> iterable, Function<S, T> transform, Predicate<T> predicate) {
     for (S element : iterable) {
diff --git a/src/main/java/com/android/tools/r8/utils/ListUtils.java b/src/main/java/com/android/tools/r8/utils/ListUtils.java
index 7b4976c..18ce187 100644
--- a/src/main/java/com/android/tools/r8/utils/ListUtils.java
+++ b/src/main/java/com/android/tools/r8/utils/ListUtils.java
@@ -7,6 +7,7 @@
 import com.google.common.collect.ImmutableList;
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Optional;
 import java.util.function.BiFunction;
@@ -153,6 +154,18 @@
     return builder.build();
   }
 
+  public static <T> LinkedList<T> newLinkedList(T element) {
+    LinkedList<T> list = new LinkedList<>();
+    list.add(element);
+    return list;
+  }
+
+  public static <T> LinkedList<T> newLinkedList(ForEachable<T> forEachable) {
+    LinkedList<T> list = new LinkedList<>();
+    forEachable.forEach(list::add);
+    return list;
+  }
+
   public static <T> Optional<T> removeFirstMatch(List<T> list, Predicate<T> element) {
     int index = firstIndexMatching(list, element);
     if (index >= 0) {
diff --git a/src/main/java/com/android/tools/r8/utils/MapUtils.java b/src/main/java/com/android/tools/r8/utils/MapUtils.java
index abae424..c2cf607 100644
--- a/src/main/java/com/android/tools/r8/utils/MapUtils.java
+++ b/src/main/java/com/android/tools/r8/utils/MapUtils.java
@@ -12,6 +12,12 @@
 
 public class MapUtils {
 
+  public static <K, V> Map<K, V> clone(
+      Map<K, V> mapToClone, Map<K, V> newMap, Function<V, V> valueCloner) {
+    mapToClone.forEach((key, value) -> newMap.put(key, valueCloner.apply(value)));
+    return newMap;
+  }
+
   public static <K, V> K firstKey(Map<K, V> map) {
     return map.keySet().iterator().next();
   }
diff --git a/src/main/java/com/android/tools/r8/utils/StreamUtils.java b/src/main/java/com/android/tools/r8/utils/StreamUtils.java
index 8d4ebc3..53a3010 100644
--- a/src/main/java/com/android/tools/r8/utils/StreamUtils.java
+++ b/src/main/java/com/android/tools/r8/utils/StreamUtils.java
@@ -8,6 +8,7 @@
 import java.io.InputStream;
 
 public class StreamUtils {
+
   /**
    * Read all data from the stream into a byte[], close the stream and return the bytes.
    * @return The bytes of the stream
diff --git a/src/main/java/com/android/tools/r8/utils/collections/DexMethodSignatureSet.java b/src/main/java/com/android/tools/r8/utils/collections/DexMethodSignatureSet.java
index 5a4bc89..a71c116 100644
--- a/src/main/java/com/android/tools/r8/utils/collections/DexMethodSignatureSet.java
+++ b/src/main/java/com/android/tools/r8/utils/collections/DexMethodSignatureSet.java
@@ -8,13 +8,19 @@
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexMethodSignature;
+import com.google.common.collect.Iterables;
+import java.util.Collection;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.Iterator;
 import java.util.LinkedHashSet;
 import java.util.Set;
 import java.util.function.Function;
 
-public class DexMethodSignatureSet implements Iterable<DexMethodSignature> {
+public class DexMethodSignatureSet implements Collection<DexMethodSignature> {
+
+  private static final DexMethodSignatureSet EMPTY =
+      new DexMethodSignatureSet(Collections.emptySet());
 
   private final Set<DexMethodSignature> backing;
 
@@ -34,6 +40,10 @@
     return new DexMethodSignatureSet(new LinkedHashSet<>());
   }
 
+  public static DexMethodSignatureSet empty() {
+    return EMPTY;
+  }
+
   public boolean add(DexMethodSignature signature) {
     return backing.add(signature);
   }
@@ -50,8 +60,9 @@
     return add(method.getReference());
   }
 
-  public void addAll(Iterable<DexMethodSignature> signatures) {
-    signatures.forEach(this::add);
+  @Override
+  public boolean addAll(Collection<? extends DexMethodSignature> collection) {
+    return backing.addAll(collection);
   }
 
   public void addAllMethods(Iterable<DexEncodedMethod> methods) {
@@ -64,19 +75,53 @@
 
   public <T> void addAll(Iterable<T> elements, Function<T, Iterable<DexMethodSignature>> fn) {
     for (T element : elements) {
-      addAll(fn.apply(element));
+      Iterables.addAll(this, fn.apply(element));
     }
   }
 
+  @Override
+  public void clear() {
+    backing.clear();
+  }
+
+  @Override
+  public boolean contains(Object o) {
+    return backing.contains(o);
+  }
+
   public boolean contains(DexMethodSignature signature) {
     return backing.contains(signature);
   }
 
   @Override
+  public boolean containsAll(Collection<?> collection) {
+    return backing.containsAll(collection);
+  }
+
+  public boolean containsAnyOf(Iterable<DexMethodSignature> signatures) {
+    for (DexMethodSignature signature : signatures) {
+      if (contains(signature)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  @Override
+  public boolean isEmpty() {
+    return backing.isEmpty();
+  }
+
+  @Override
   public Iterator<DexMethodSignature> iterator() {
     return backing.iterator();
   }
 
+  @Override
+  public boolean remove(Object o) {
+    return backing.remove(o);
+  }
+
   public boolean remove(DexMethodSignature signature) {
     return backing.remove(signature);
   }
@@ -85,6 +130,11 @@
     return remove(method.getSignature());
   }
 
+  @Override
+  public boolean removeAll(Collection<?> collection) {
+    return backing.removeAll(collection);
+  }
+
   public void removeAll(Iterable<DexMethodSignature> signatures) {
     signatures.forEach(this::remove);
   }
@@ -92,4 +142,24 @@
   public void removeAllMethods(Iterable<DexEncodedMethod> methods) {
     methods.forEach(this::remove);
   }
+
+  @Override
+  public boolean retainAll(Collection<?> collection) {
+    return backing.retainAll(collection);
+  }
+
+  @Override
+  public int size() {
+    return backing.size();
+  }
+
+  @Override
+  public Object[] toArray() {
+    return backing.toArray();
+  }
+
+  @Override
+  public <T> T[] toArray(T[] ts) {
+    return backing.toArray(ts);
+  }
 }