Ensure fixup of attributes in class merger

Change-Id: I64b7aedc995803642971816029f23f3353d45a5d
diff --git a/src/main/java/com/android/tools/r8/graph/DexClass.java b/src/main/java/com/android/tools/r8/graph/DexClass.java
index ce36598..182c91c 100644
--- a/src/main/java/com/android/tools/r8/graph/DexClass.java
+++ b/src/main/java/com/android/tools/r8/graph/DexClass.java
@@ -76,7 +76,7 @@
    */
   private NestHostClassAttribute nestHost;
 
-  private final List<NestMemberClassAttribute> nestMembers;
+  private List<NestMemberClassAttribute> nestMembers;
 
   /** Generic signature information if the attribute is present in the input */
   protected ClassSignature classSignature;
@@ -1042,6 +1042,10 @@
     this.nestHost = new NestHostClassAttribute(type);
   }
 
+  public void setNestHostAttribute(NestHostClassAttribute nestHostAttribute) {
+    this.nestHost = nestHostAttribute;
+  }
+
   public boolean isNestHost() {
     return !nestMembers.isEmpty();
   }
@@ -1069,10 +1073,22 @@
     return nestHost;
   }
 
+  public boolean hasNestMemberAttributes() {
+    return nestMembers != null && !nestMembers.isEmpty();
+  }
+
   public List<NestMemberClassAttribute> getNestMembersClassAttributes() {
     return nestMembers;
   }
 
+  public void setNestMemberAttributes(List<NestMemberClassAttribute> nestMemberAttributes) {
+    this.nestMembers = nestMemberAttributes;
+  }
+
+  public void removeNestMemberAttributes(Predicate<NestMemberClassAttribute> predicate) {
+    nestMembers.removeIf(predicate);
+  }
+
   /** Returns kotlin class info if the class is synthesized by kotlin compiler. */
   public abstract KotlinClassLevelInfo getKotlinInfo();
 
diff --git a/src/main/java/com/android/tools/r8/graph/EnclosingMethodAttribute.java b/src/main/java/com/android/tools/r8/graph/EnclosingMethodAttribute.java
index 1a3984c..fde5f6c 100644
--- a/src/main/java/com/android/tools/r8/graph/EnclosingMethodAttribute.java
+++ b/src/main/java/com/android/tools/r8/graph/EnclosingMethodAttribute.java
@@ -48,6 +48,10 @@
     }
   }
 
+  public boolean hasEnclosingMethod() {
+    return enclosingMethod != null;
+  }
+
   public DexMethod getEnclosingMethod() {
     return enclosingMethod;
   }
diff --git a/src/main/java/com/android/tools/r8/graph/TreeFixerBase.java b/src/main/java/com/android/tools/r8/graph/TreeFixerBase.java
index 5bbc1ee..969ac7d 100644
--- a/src/main/java/com/android/tools/r8/graph/TreeFixerBase.java
+++ b/src/main/java/com/android/tools/r8/graph/TreeFixerBase.java
@@ -141,7 +141,7 @@
     return newClass;
   }
 
-  private EnclosingMethodAttribute fixupEnclosingMethodAttribute(
+  protected EnclosingMethodAttribute fixupEnclosingMethodAttribute(
       EnclosingMethodAttribute enclosingMethodAttribute) {
     if (enclosingMethodAttribute == null) {
       return null;
@@ -190,7 +190,7 @@
     return dexItemFactory.createField(newHolder, newType, field.name);
   }
 
-  private List<InnerClassAttribute> fixupInnerClassAttributes(
+  protected List<InnerClassAttribute> fixupInnerClassAttributes(
       List<InnerClassAttribute> innerClassAttributes) {
     if (innerClassAttributes.isEmpty()) {
       return innerClassAttributes;
@@ -240,13 +240,13 @@
         fixupType(method.holder), fixupProto(method.proto), method.name);
   }
 
-  private NestHostClassAttribute fixupNestHost(NestHostClassAttribute nestHostClassAttribute) {
+  protected NestHostClassAttribute fixupNestHost(NestHostClassAttribute nestHostClassAttribute) {
     return nestHostClassAttribute != null
         ? new NestHostClassAttribute(fixupType(nestHostClassAttribute.getNestHost()))
         : null;
   }
 
-  private List<NestMemberClassAttribute> fixupNestMemberAttributes(
+  protected List<NestMemberClassAttribute> fixupNestMemberAttributes(
       List<NestMemberClassAttribute> nestMemberAttributes) {
     if (nestMemberAttributes.isEmpty()) {
       return nestMemberAttributes;
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/ClassMerger.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/ClassMerger.java
index de2c86f..5eea854 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/ClassMerger.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/ClassMerger.java
@@ -240,6 +240,24 @@
     }
   }
 
+  void fixNestMemberAttributes() {
+    if (group.getTarget().isInANest() && !group.getTarget().hasNestMemberAttributes()) {
+      for (DexProgramClass clazz : group.getSources()) {
+        if (clazz.hasNestMemberAttributes()) {
+          // The nest host has been merged into a nest member.
+          group.getTarget().clearNestHost();
+          group.getTarget().setNestMemberAttributes(clazz.getNestMembersClassAttributes());
+          group
+              .getTarget()
+              .removeNestMemberAttributes(
+                  nestMemberAttribute ->
+                      nestMemberAttribute.getNestMember() == group.getTarget().getType());
+          break;
+        }
+      }
+    }
+  }
+
   private void mergeAnnotations() {
     assert group.getClasses().stream().filter(DexDefinition::hasAnnotations).count() <= 1;
     for (DexProgramClass clazz : group.getSources()) {
@@ -285,6 +303,7 @@
 
   public void mergeGroup(SyntheticArgumentClass syntheticArgumentClass) {
     fixAccessFlags();
+    fixNestMemberAttributes();
 
     if (group.hasClassIdField()) {
       appendClassIdField();
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/TreeFixer.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/TreeFixer.java
index 34c7362..442489d 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/TreeFixer.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/TreeFixer.java
@@ -18,6 +18,7 @@
 import com.android.tools.r8.graph.DexString;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.DexTypeList;
+import com.android.tools.r8.graph.EnclosingMethodAttribute;
 import com.android.tools.r8.graph.TreeFixerBase;
 import com.android.tools.r8.ir.conversion.ExtraUnusedNullParameter;
 import com.android.tools.r8.shaking.AnnotationFixer;
@@ -116,6 +117,7 @@
   public HorizontalClassMergerGraphLens fixupTypeReferences() {
     List<DexProgramClass> classes = appView.appInfo().classesWithDeterministicOrder();
     Iterables.filter(classes, DexProgramClass::isInterface).forEach(this::fixupInterfaceClass);
+    classes.forEach(this::fixupAttributes);
     classes.forEach(this::fixupProgramClassSuperTypes);
     SubtypingForrestForClasses subtypingForrest = new SubtypingForrestForClasses(appView);
     // TODO(b/170078037): parallelize this code segment.
@@ -127,6 +129,21 @@
     return lens;
   }
 
+  private void fixupAttributes(DexProgramClass clazz) {
+    if (clazz.hasEnclosingMethodAttribute()) {
+      EnclosingMethodAttribute enclosingMethodAttribute = clazz.getEnclosingMethodAttribute();
+      if (mergedClasses.hasBeenMergedIntoDifferentType(
+          enclosingMethodAttribute.getEnclosingType())) {
+        clazz.clearEnclosingMethodAttribute();
+      } else {
+        clazz.setEnclosingMethodAttribute(fixupEnclosingMethodAttribute(enclosingMethodAttribute));
+      }
+    }
+    clazz.setInnerClasses(fixupInnerClassAttributes(clazz.getInnerClasses()));
+    clazz.setNestHostAttribute(fixupNestHost(clazz.getNestHostClassAttribute()));
+    clazz.setNestMemberAttributes(fixupNestMemberAttributes(clazz.getNestMembersClassAttributes()));
+  }
+
   private void fixupProgramClassSuperTypes(DexProgramClass clazz) {
     clazz.superType = fixupType(clazz.superType);
     clazz.setInterfaces(fixupInterfaces(clazz, clazz.getInterfaces()));
diff --git a/src/test/examplesJava11/horizontalclassmerging/BasicNestHostHorizontalClassMerging.java b/src/test/examplesJava11/horizontalclassmerging/BasicNestHostHorizontalClassMerging.java
deleted file mode 100644
index 2362165..0000000
--- a/src/test/examplesJava11/horizontalclassmerging/BasicNestHostHorizontalClassMerging.java
+++ /dev/null
@@ -1,43 +0,0 @@
-// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
-// for details. All rights reserved. Use of this source code is governed by a
-// BSD-style license that can be found in the LICENSE file.
-
-package horizontalclassmerging;
-
-import com.android.tools.r8.NeverClassInline;
-import com.android.tools.r8.NeverInline;
-
-public class BasicNestHostHorizontalClassMerging {
-  // Prevent merging with BasicNestHostHorizontalClassMerging2.
-  private String name;
-
-  private BasicNestHostHorizontalClassMerging(String name) {
-    this.name = name;
-  }
-
-  @NeverInline
-  private void print(String v) {
-    System.out.println(name + ": " + v);
-  }
-
-  public static void main(String[] args) {
-    BasicNestHostHorizontalClassMerging host = new BasicNestHostHorizontalClassMerging("1");
-    new A(host);
-    new B(host);
-    BasicNestHostHorizontalClassMerging2.main(args);
-  }
-
-  @NeverClassInline
-  public static class A {
-    public A(BasicNestHostHorizontalClassMerging parent) {
-      parent.print("a");
-    }
-  }
-
-  @NeverClassInline
-  public static class B {
-    public B(BasicNestHostHorizontalClassMerging parent) {
-      parent.print("b");
-    }
-  }
-}
diff --git a/src/test/examplesJava11/horizontalclassmerging/BasicNestHostHorizontalClassMerging2.java b/src/test/examplesJava11/horizontalclassmerging/BasicNestHostHorizontalClassMerging2.java
deleted file mode 100644
index 78718e1..0000000
--- a/src/test/examplesJava11/horizontalclassmerging/BasicNestHostHorizontalClassMerging2.java
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
-// for details. All rights reserved. Use of this source code is governed by a
-// BSD-style license that can be found in the LICENSE file.
-
-package horizontalclassmerging;
-
-import com.android.tools.r8.NeverClassInline;
-import com.android.tools.r8.NeverInline;
-
-public class BasicNestHostHorizontalClassMerging2 {
-  @NeverInline
-  public static void main(String[] args) {
-    new A();
-    new B();
-  }
-
-  @NeverInline
-  private static void print(String v) {
-    System.out.println("2: " + v);
-  }
-
-  @NeverClassInline
-  public static class A {
-    public A() {
-      print("a");
-    }
-  }
-
-  @NeverClassInline
-  public static class B {
-    public B() {
-      print("b");
-    }
-  }
-}
diff --git a/src/test/examplesJava11/horizontalclassmerging/NestClassMergingTest.java b/src/test/examplesJava11/horizontalclassmerging/NestClassMergingTest.java
new file mode 100644
index 0000000..b8ef472
--- /dev/null
+++ b/src/test/examplesJava11/horizontalclassmerging/NestClassMergingTest.java
@@ -0,0 +1,17 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package horizontalclassmerging;
+
+public class NestClassMergingTest {
+
+  public static void main(String[] args) {
+    NestHostA hostA = new NestHostA();
+    new NestHostA.NestMemberA();
+    new NestHostA.NestMemberB(hostA);
+    NestHostB hostB = new NestHostB();
+    new NestHostB.NestMemberA();
+    new NestHostB.NestMemberB(hostB);
+  }
+}
diff --git a/src/test/examplesJava11/horizontalclassmerging/NestHostA.java b/src/test/examplesJava11/horizontalclassmerging/NestHostA.java
new file mode 100644
index 0000000..13973b6
--- /dev/null
+++ b/src/test/examplesJava11/horizontalclassmerging/NestHostA.java
@@ -0,0 +1,36 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package horizontalclassmerging;
+
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.NeverInline;
+
+@NeverClassInline
+public class NestHostA {
+
+  @NeverInline
+  private void privatePrint(String v) {
+    System.out.println(v);
+  }
+
+  @NeverInline
+  private static void privateStaticPrint(String v) {
+    System.out.println(v);
+  }
+
+  @NeverClassInline
+  public static class NestMemberA {
+    public NestMemberA() {
+      NestHostA.privateStaticPrint("NestHostA$NestMemberA");
+    }
+  }
+
+  @NeverClassInline
+  public static class NestMemberB {
+    public NestMemberB(NestHostA host) {
+      host.privatePrint("NestHostA$NestMemberB");
+    }
+  }
+}
diff --git a/src/test/examplesJava11/horizontalclassmerging/NestHostB.java b/src/test/examplesJava11/horizontalclassmerging/NestHostB.java
new file mode 100644
index 0000000..99ac10b
--- /dev/null
+++ b/src/test/examplesJava11/horizontalclassmerging/NestHostB.java
@@ -0,0 +1,36 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package horizontalclassmerging;
+
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.NeverInline;
+
+@NeverClassInline
+public class NestHostB {
+
+  @NeverInline
+  private void privatePrint(String v) {
+    System.out.println(v);
+  }
+
+  @NeverInline
+  private static void privateStaticPrint(String v) {
+    System.out.println(v);
+  }
+
+  @NeverClassInline
+  public static class NestMemberA {
+    public NestMemberA() {
+      NestHostB.privateStaticPrint("NestHostB$NestMemberA");
+    }
+  }
+
+  @NeverClassInline
+  public static class NestMemberB {
+    public NestMemberB(NestHostB host) {
+      host.privatePrint("NestHostB$NestMemberB");
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/R8TestBuilder.java b/src/test/java/com/android/tools/r8/R8TestBuilder.java
index 5594bee..97cf033 100644
--- a/src/test/java/com/android/tools/r8/R8TestBuilder.java
+++ b/src/test/java/com/android/tools/r8/R8TestBuilder.java
@@ -493,6 +493,13 @@
     return addInternalKeepRules("-nohorizontalclassmerging class " + clazz);
   }
 
+  public T addNoHorizontalClassMergingRule(String... classes) {
+    for (String clazz : classes) {
+      addNoHorizontalClassMergingRule(clazz);
+    }
+    return self();
+  }
+
   public T enableMemberValuePropagationAnnotations() {
     return enableMemberValuePropagationAnnotations(true);
   }
diff --git a/src/test/java/com/android/tools/r8/TestBase.java b/src/test/java/com/android/tools/r8/TestBase.java
index dcc3cee..b24d16a 100644
--- a/src/test/java/com/android/tools/r8/TestBase.java
+++ b/src/test/java/com/android/tools/r8/TestBase.java
@@ -1715,8 +1715,16 @@
     return clazz.getTypeName();
   }
 
-  public static String examplesTypeName(Class<? extends ExamplesClass> clazz) throws Exception {
-    return ReflectiveBuildPathUtils.resolveClassName(clazz);
+  public static ClassReference examplesClassReference(Class<? extends ExamplesClass> clazz) {
+    return Reference.classFromTypeName(examplesTypeName(clazz));
+  }
+
+  public static String examplesTypeName(Class<? extends ExamplesClass> clazz) {
+    try {
+      return ReflectiveBuildPathUtils.resolveClassName(clazz);
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
   }
 
   public static AndroidApiLevel apiLevelWithDefaultInterfaceMethodsSupport() {
diff --git a/src/test/java/com/android/tools/r8/TestParameters.java b/src/test/java/com/android/tools/r8/TestParameters.java
index 5a7423a..1df89b1 100644
--- a/src/test/java/com/android/tools/r8/TestParameters.java
+++ b/src/test/java/com/android/tools/r8/TestParameters.java
@@ -43,6 +43,11 @@
         .isGreaterThanOrEqualTo(TestBase.apiLevelWithDefaultInterfaceMethodsSupport());
   }
 
+  public boolean canUseNestBasedAccesses() {
+    assert isCfRuntime() || isDexRuntime();
+    return isCfRuntime() && getRuntime().asCf().isNewerThanOrEqual(CfVm.JDK11);
+  }
+
   // Convenience predicates.
   public boolean isDexRuntime() {
     return runtime.isDex();
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/NestClassMergingTestRunner.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/NestClassMergingTestRunner.java
new file mode 100644
index 0000000..bf50ab1
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/NestClassMergingTestRunner.java
@@ -0,0 +1,269 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.classmerging.horizontal;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assume.assumeTrue;
+
+import com.android.tools.r8.R8FullTestBuilder;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.TestRuntime.CfVm;
+import com.android.tools.r8.ThrowableConsumer;
+import com.android.tools.r8.classmerging.horizontal.NestClassMergingTestRunner.R.horizontalclassmerging.NestClassMergingTest;
+import com.android.tools.r8.classmerging.horizontal.NestClassMergingTestRunner.R.horizontalclassmerging.NestHostA;
+import com.android.tools.r8.classmerging.horizontal.NestClassMergingTestRunner.R.horizontalclassmerging.NestHostB;
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.references.ClassReference;
+import com.android.tools.r8.utils.ReflectiveBuildPathUtils.ExamplesClass;
+import com.android.tools.r8.utils.ReflectiveBuildPathUtils.ExamplesJava11RootPackage;
+import com.android.tools.r8.utils.ReflectiveBuildPathUtils.ExamplesPackage;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Streams;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.junit.Test;
+import org.junit.runners.Parameterized;
+
+public class NestClassMergingTestRunner extends HorizontalClassMergingTestBase {
+
+  public static class R extends ExamplesJava11RootPackage {
+    public static class horizontalclassmerging extends ExamplesPackage {
+      public static class NestClassMergingTest extends ExamplesClass {}
+
+      public static class NestHostA extends ExamplesClass {
+        public static class NestMemberA extends ExamplesClass {}
+
+        public static class NestMemberB extends ExamplesClass {}
+      }
+
+      public static class NestHostB extends ExamplesClass {
+        public static class NestMemberA extends ExamplesClass {}
+
+        public static class NestMemberB extends ExamplesClass {}
+      }
+    }
+  }
+
+  public NestClassMergingTestRunner(TestParameters parameters) {
+    super(parameters);
+  }
+
+  @Parameterized.Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters()
+        .withCfRuntimesStartingFromIncluding(CfVm.JDK11)
+        .withDexRuntimes()
+        .withAllApiLevels()
+        .build();
+  }
+
+  @Test
+  public void test() throws Exception {
+    runTest(
+        builder ->
+            builder.addHorizontallyMergedClassesInspector(
+                inspector -> {
+                  if (parameters.canUseNestBasedAccesses()) {
+                    inspector
+                        .assertIsCompleteMergeGroup(
+                            classRef(NestHostA.class),
+                            classRef(NestHostA.NestMemberA.class),
+                            classRef(NestHostA.NestMemberB.class))
+                        .assertIsCompleteMergeGroup(
+                            classRef(NestHostB.class),
+                            classRef(NestHostB.NestMemberA.class),
+                            classRef(NestHostB.NestMemberB.class));
+                  } else {
+                    inspector.assertIsCompleteMergeGroup(
+                        classRef(NestHostA.class),
+                        classRef(NestHostA.NestMemberA.class),
+                        classRef(NestHostA.NestMemberB.class),
+                        classRef(NestHostB.class),
+                        classRef(NestHostB.NestMemberA.class),
+                        classRef(NestHostB.NestMemberB.class));
+                  }
+                }));
+  }
+
+  @Test
+  public void testMergeHostIntoNestMemberA() throws Exception {
+    assumeTrue(parameters.isCfRuntime());
+    runTest(
+        builder ->
+            builder
+                .addHorizontallyMergedClassesInspector(
+                    inspector ->
+                        inspector
+                            .assertIsCompleteMergeGroup(
+                                classRef(NestHostA.class), classRef(NestHostA.NestMemberA.class))
+                            .assertIsCompleteMergeGroup(
+                                classRef(NestHostB.class), classRef(NestHostB.NestMemberA.class))
+                            .assertClassReferencesNotMerged(
+                                classRef(NestHostA.NestMemberB.class),
+                                classRef(NestHostB.NestMemberB.class)))
+                .addNoHorizontalClassMergingRule(
+                    examplesTypeName(NestHostA.NestMemberB.class),
+                    examplesTypeName(NestHostB.NestMemberB.class))
+                .addOptionsModification(
+                    options -> {
+                      options.testing.horizontalClassMergingTarget =
+                          (canditates, target) -> {
+                            Set<ClassReference> candidateClassReferences =
+                                Streams.stream(canditates)
+                                    .map(DexClass::getClassReference)
+                                    .collect(Collectors.toSet());
+                            if (candidateClassReferences.contains(classRef(NestHostA.class))) {
+                              assertEquals(
+                                  ImmutableSet.of(
+                                      classRef(NestHostA.class),
+                                      classRef(NestHostA.NestMemberA.class)),
+                                  candidateClassReferences);
+                            } else {
+                              assertEquals(
+                                  ImmutableSet.of(
+                                      classRef(NestHostB.class),
+                                      classRef(NestHostB.NestMemberA.class)),
+                                  candidateClassReferences);
+                            }
+                            return Iterables.find(
+                                canditates,
+                                candidate -> {
+                                  ClassReference classReference = candidate.getClassReference();
+                                  return classReference.equals(
+                                          classRef(NestHostA.NestMemberA.class))
+                                      || classReference.equals(
+                                          classRef(NestHostB.NestMemberA.class));
+                                });
+                          };
+                    }));
+  }
+
+  @Test
+  public void testMergeHostIntoNestMemberB() throws Exception {
+    assumeTrue(parameters.isCfRuntime());
+    runTest(
+        builder ->
+            builder
+                .addHorizontallyMergedClassesInspector(
+                    inspector ->
+                        inspector
+                            .assertIsCompleteMergeGroup(
+                                classRef(NestHostA.class), classRef(NestHostA.NestMemberB.class))
+                            .assertIsCompleteMergeGroup(
+                                classRef(NestHostB.class), classRef(NestHostB.NestMemberB.class))
+                            .assertClassReferencesNotMerged(
+                                classRef(NestHostA.NestMemberA.class),
+                                classRef(NestHostB.NestMemberA.class)))
+                .addNoHorizontalClassMergingRule(
+                    examplesTypeName(NestHostA.NestMemberA.class),
+                    examplesTypeName(NestHostB.NestMemberA.class))
+                .addOptionsModification(
+                    options -> {
+                      options.testing.horizontalClassMergingTarget =
+                          (canditates, target) -> {
+                            Set<ClassReference> candidateClassReferences =
+                                Streams.stream(canditates)
+                                    .map(DexClass::getClassReference)
+                                    .collect(Collectors.toSet());
+                            if (candidateClassReferences.contains(classRef(NestHostA.class))) {
+                              assertEquals(
+                                  ImmutableSet.of(
+                                      classRef(NestHostA.class),
+                                      classRef(NestHostA.NestMemberB.class)),
+                                  candidateClassReferences);
+                            } else {
+                              assertEquals(
+                                  ImmutableSet.of(
+                                      classRef(NestHostB.class),
+                                      classRef(NestHostB.NestMemberB.class)),
+                                  candidateClassReferences);
+                            }
+                            return Iterables.find(
+                                canditates,
+                                candidate -> {
+                                  ClassReference classReference = candidate.getClassReference();
+                                  return classReference.equals(
+                                          classRef(NestHostA.NestMemberB.class))
+                                      || classReference.equals(
+                                          classRef(NestHostB.NestMemberB.class));
+                                });
+                          };
+                    }));
+  }
+
+  @Test
+  public void testMergeMemberAIntoNestHost() throws Exception {
+    assumeTrue(parameters.isCfRuntime());
+    runTest(
+        builder ->
+            builder
+                .addHorizontallyMergedClassesInspector(
+                    inspector ->
+                        inspector
+                            .assertIsCompleteMergeGroup(
+                                classRef(NestHostA.class), classRef(NestHostA.NestMemberA.class))
+                            .assertIsCompleteMergeGroup(
+                                classRef(NestHostB.class), classRef(NestHostB.NestMemberA.class))
+                            .assertClassReferencesNotMerged(
+                                classRef(NestHostA.NestMemberB.class),
+                                classRef(NestHostB.NestMemberB.class)))
+                .addNoHorizontalClassMergingRule(
+                    examplesTypeName(NestHostA.NestMemberB.class),
+                    examplesTypeName(NestHostB.NestMemberB.class))
+                .addOptionsModification(
+                    options -> {
+                      options.testing.horizontalClassMergingTarget =
+                          (canditates, target) -> {
+                            Set<ClassReference> candidateClassReferences =
+                                Streams.stream(canditates)
+                                    .map(DexClass::getClassReference)
+                                    .collect(Collectors.toSet());
+                            if (candidateClassReferences.contains(classRef(NestHostA.class))) {
+                              assertEquals(
+                                  ImmutableSet.of(
+                                      classRef(NestHostA.class),
+                                      classRef(NestHostA.NestMemberA.class)),
+                                  candidateClassReferences);
+                            } else {
+                              assertEquals(
+                                  ImmutableSet.of(
+                                      classRef(NestHostB.class),
+                                      classRef(NestHostB.NestMemberA.class)),
+                                  candidateClassReferences);
+                            }
+                            return Iterables.find(
+                                canditates,
+                                candidate -> {
+                                  ClassReference classReference = candidate.getClassReference();
+                                  return classReference.equals(classRef(NestHostA.class))
+                                      || classReference.equals(classRef(NestHostB.class));
+                                });
+                          };
+                    }));
+  }
+
+  private void runTest(ThrowableConsumer<R8FullTestBuilder> configuration) throws Exception {
+    testForR8(parameters.getBackend())
+        .addKeepMainRule(examplesTypeName(NestClassMergingTest.class))
+        .addExamplesProgramFiles(R.class)
+        .apply(configuration)
+        .enableInliningAnnotations()
+        .enableNeverClassInliningAnnotations()
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .run(parameters.getRuntime(), examplesTypeName(NestClassMergingTest.class))
+        .assertSuccessWithOutputLines(
+            "NestHostA$NestMemberA",
+            "NestHostA$NestMemberB",
+            "NestHostB$NestMemberA",
+            "NestHostB$NestMemberB");
+  }
+
+  private static ClassReference classRef(Class<? extends ExamplesClass> clazz) {
+    return examplesClassReference(clazz);
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/NestClassTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/NestClassTest.java
deleted file mode 100644
index 215ce7a..0000000
--- a/src/test/java/com/android/tools/r8/classmerging/horizontal/NestClassTest.java
+++ /dev/null
@@ -1,104 +0,0 @@
-// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
-// for details. All rights reserved. Use of this source code is governed by a
-// BSD-style license that can be found in the LICENSE file.
-
-package com.android.tools.r8.classmerging.horizontal;
-
-import static com.android.tools.r8.utils.codeinspector.Matchers.isAbsent;
-import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
-import static com.android.tools.r8.utils.codeinspector.Matchers.isPrivate;
-import static com.android.tools.r8.utils.codeinspector.Matchers.isStatic;
-import static org.hamcrest.MatcherAssert.assertThat;
-
-import com.android.tools.r8.Jdk9TestUtils;
-import com.android.tools.r8.TestParameters;
-import com.android.tools.r8.TestParametersCollection;
-import com.android.tools.r8.TestRuntime.CfVm;
-import com.android.tools.r8.classmerging.horizontal.NestClassTest.R.horizontalclassmerging.BasicNestHostHorizontalClassMerging;
-import com.android.tools.r8.classmerging.horizontal.NestClassTest.R.horizontalclassmerging.BasicNestHostHorizontalClassMerging2;
-import com.android.tools.r8.utils.ReflectiveBuildPathUtils.ExamplesClass;
-import com.android.tools.r8.utils.ReflectiveBuildPathUtils.ExamplesJava11RootPackage;
-import com.android.tools.r8.utils.ReflectiveBuildPathUtils.ExamplesPackage;
-import com.android.tools.r8.utils.codeinspector.ClassSubject;
-import com.android.tools.r8.utils.codeinspector.MethodSubject;
-import org.junit.Test;
-import org.junit.runners.Parameterized;
-
-public class NestClassTest extends HorizontalClassMergingTestBase {
-  public static class R extends ExamplesJava11RootPackage {
-    public static class horizontalclassmerging extends ExamplesPackage {
-      public static class BasicNestHostHorizontalClassMerging extends ExamplesClass {
-        public static class A extends ExamplesClass {}
-
-        public static class B extends ExamplesClass {}
-      }
-
-      public static class BasicNestHostHorizontalClassMerging2 extends ExamplesClass {
-        public static class A extends ExamplesClass {}
-
-        public static class B extends ExamplesClass {}
-      }
-    }
-  }
-
-  public NestClassTest(TestParameters parameters) {
-    super(parameters);
-  }
-
-  @Parameterized.Parameters(name = "{0}")
-  public static TestParametersCollection data() {
-    return getTestParameters().withCfRuntimesStartingFromIncluding(CfVm.JDK11).build();
-  }
-
-  @Test
-  public void testR8() throws Exception {
-    testForR8(parameters.getBackend())
-        .addKeepMainRule(examplesTypeName(BasicNestHostHorizontalClassMerging.class))
-        .addExamplesProgramFiles(R.class)
-        .applyIf(parameters.isCfRuntime(), Jdk9TestUtils.addJdk9LibraryFiles(temp))
-        .enableInliningAnnotations()
-        .enableNeverClassInliningAnnotations()
-        .compile()
-        .run(parameters.getRuntime(), examplesTypeName(BasicNestHostHorizontalClassMerging.class))
-        .assertSuccessWithOutputLines("1: a", "1: b", "2: a", "2: b")
-        .inspect(
-            codeInspector -> {
-              ClassSubject class1A =
-                  codeInspector.clazz(
-                      examplesTypeName(BasicNestHostHorizontalClassMerging.A.class));
-              ClassSubject class2A =
-                  codeInspector.clazz(
-                      examplesTypeName(BasicNestHostHorizontalClassMerging2.A.class));
-              ClassSubject class1 =
-                  codeInspector.clazz(examplesTypeName(BasicNestHostHorizontalClassMerging.class));
-              ClassSubject class2 =
-                  codeInspector.clazz(examplesTypeName(BasicNestHostHorizontalClassMerging2.class));
-              assertThat(class1, isPresent());
-              assertThat(class2, isPresent());
-              assertThat(class1A, isPresent());
-              assertThat(class2A, isPresent());
-
-              MethodSubject printClass1MethodSubject =
-                  class1.method("void", "print", String.class.getTypeName());
-              assertThat(printClass1MethodSubject, isPresent());
-              assertThat(printClass1MethodSubject, isPrivate());
-
-              MethodSubject printClass2MethodSubject =
-                  class2.method("void", "print", String.class.getTypeName());
-              assertThat(printClass2MethodSubject, isPresent());
-              assertThat(printClass2MethodSubject, isPrivate());
-              assertThat(printClass2MethodSubject, isStatic());
-
-              assertThat(
-                  codeInspector.clazz(
-                      examplesTypeName(BasicNestHostHorizontalClassMerging.B.class)),
-                  isAbsent());
-              assertThat(
-                  codeInspector.clazz(
-                      examplesTypeName(BasicNestHostHorizontalClassMerging2.B.class)),
-                  isAbsent());
-
-              // TODO(b/165517236): Explicitly check 1.B is merged into 1.A, and 2.B into 2.A.
-            });
-  }
-}
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/HorizontallyMergedClassesInspector.java b/src/test/java/com/android/tools/r8/utils/codeinspector/HorizontallyMergedClassesInspector.java
index e6fa548..ae22c5e 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/HorizontallyMergedClassesInspector.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/HorizontallyMergedClassesInspector.java
@@ -146,6 +146,11 @@
         Stream.of(classes).map(Reference::classFromClass).collect(Collectors.toList()));
   }
 
+  public HorizontallyMergedClassesInspector assertIsCompleteMergeGroup(
+      ClassReference... classReferences) {
+    return assertIsCompleteMergeGroup(Arrays.asList(classReferences));
+  }
+
   public HorizontallyMergedClassesInspector assertIsCompleteMergeGroup(String... typeNames) {
     return assertIsCompleteMergeGroup(
         Stream.of(typeNames).map(Reference::classFromTypeName).collect(Collectors.toList()));