Nest attributes should be updated

    In the Cf backend, nest attributes should be updated
    when classes are removed (tree shaking, class merge, etc.)
    - Nest host stop referencing dead nest members.
    - Nest member claims nest ownership if nest host is dead.

Bug: 130716228
Change-Id: I316a31cbd71c8499a07777791aa294b8d46955a8
diff --git a/src/main/java/com/android/tools/r8/graph/DexClass.java b/src/main/java/com/android/tools/r8/graph/DexClass.java
index 331699a..fb96f76 100644
--- a/src/main/java/com/android/tools/r8/graph/DexClass.java
+++ b/src/main/java/com/android/tools/r8/graph/DexClass.java
@@ -61,7 +61,7 @@
   /** InnerClasses table. If this class is an inner class, it will have an entry here. */
   private final List<InnerClassAttribute> innerClasses;
 
-  private final NestHostClassAttribute nestHost;
+  private NestHostClassAttribute nestHost;
   private final List<NestMemberClassAttribute> nestMembers;
 
   public DexAnnotationSet annotations;
@@ -824,19 +824,31 @@
   }
 
   public boolean isInANest() {
-    assert nestMembers != null;
-    return !(nestMembers.isEmpty()) || (nestHost != null);
+    return isNestHost() || isNestMember();
+  }
+
+  public void clearNestHost() {
+    nestHost = null;
+  }
+
+  public void setNestHost(DexType type) {
+    assert type != null;
+    this.nestHost = new NestHostClassAttribute(type);
   }
 
   public boolean isNestHost() {
     return !nestMembers.isEmpty();
   }
 
+  public boolean isNestMember() {
+    return nestHost != null;
+  }
+
   public DexType getNestHost() {
-    assert isInANest();
-    if (nestHost != null) {
+    if (isNestMember()) {
       return nestHost.getNestHost();
     }
+    assert isNestHost();
     return type;
   }
 
diff --git a/src/main/java/com/android/tools/r8/graph/JarClassFileReader.java b/src/main/java/com/android/tools/r8/graph/JarClassFileReader.java
index df420b0..b6e7513 100644
--- a/src/main/java/com/android/tools/r8/graph/JarClassFileReader.java
+++ b/src/main/java/com/android/tools/r8/graph/JarClassFileReader.java
@@ -263,7 +263,6 @@
     public void visitNestHost(String nestHost) {
       assert this.nestHost == null && nestMembers.isEmpty();
       DexType nestHostType = application.getTypeFromName(nestHost);
-      // TODO anonymous classes b/130716158
       this.nestHost = new NestHostClassAttribute(nestHostType);
     }
 
@@ -271,7 +270,6 @@
     public void visitNestMember(String nestMember) {
       assert nestHost == null;
       DexType nestMemberType = application.getTypeFromName(nestMember);
-      // TODO anonymous classes b/130716158
       nestMembers.add(new NestMemberClassAttribute(nestMemberType));
     }
 
diff --git a/src/main/java/com/android/tools/r8/jar/CfApplicationWriter.java b/src/main/java/com/android/tools/r8/jar/CfApplicationWriter.java
index fc81ded..f08f81a 100644
--- a/src/main/java/com/android/tools/r8/jar/CfApplicationWriter.java
+++ b/src/main/java/com/android/tools/r8/jar/CfApplicationWriter.java
@@ -168,6 +168,8 @@
 
     for (NestMemberClassAttribute entry : clazz.getNestMembersClassAttributes()) {
       entry.write(writer, namingLens);
+      assert clazz.getNestHostClassAttribute() == null
+          : "A nest host cannot also be a nest member.";
     }
 
     for (InnerClassAttribute entry : clazz.getInnerClasses()) {
diff --git a/src/main/java/com/android/tools/r8/shaking/TreePruner.java b/src/main/java/com/android/tools/r8/shaking/TreePruner.java
index 1671197..24dcae4 100644
--- a/src/main/java/com/android/tools/r8/shaking/TreePruner.java
+++ b/src/main/java/com/android/tools/r8/shaking/TreePruner.java
@@ -14,6 +14,7 @@
 import com.android.tools.r8.graph.EnclosingMethodAttribute;
 import com.android.tools.r8.graph.InnerClassAttribute;
 import com.android.tools.r8.graph.KeyedDexItem;
+import com.android.tools.r8.graph.NestMemberClassAttribute;
 import com.android.tools.r8.graph.PresortedComparable;
 import com.android.tools.r8.logging.Log;
 import com.android.tools.r8.utils.InternalOptions;
@@ -132,9 +133,48 @@
     }
     clazz.removeInnerClasses(this::isAttributeReferencingPrunedType);
     clazz.removeEnclosingMethod(this::isAttributeReferencingPrunedItem);
+    rewriteNestAttributes(clazz);
     usagePrinter.visited();
   }
 
+  private void rewriteNestAttributes(DexProgramClass clazz) {
+    if (!clazz.isInANest() || !isTypeLive(clazz.type)) {
+      return;
+    }
+    if (clazz.isNestHost()) {
+      clearDeadNestMembers(clazz);
+    } else {
+      assert clazz.isNestMember();
+      if (!isTypeLive(clazz.getNestHost())) {
+        claimNestOwnership(clazz);
+      }
+    }
+  }
+
+  private boolean isTypeLive(DexType type) {
+    return appView.appInfo().liveTypes.contains(type);
+  }
+
+  private void clearDeadNestMembers(DexClass nestHost) {
+    nestHost
+        .getNestMembersClassAttributes()
+        .removeIf(nestMemberAttr -> !isTypeLive(nestMemberAttr.getNestMember()));
+  }
+
+  private void claimNestOwnership(DexClass newHost) {
+    DexClass previousHost = appView.definitionFor(newHost.getNestHost());
+    assert previousHost != null;
+    newHost.clearNestHost();
+    for (NestMemberClassAttribute attr : previousHost.getNestMembersClassAttributes()) {
+      if (attr.getNestMember() != newHost.type && isTypeLive(attr.getNestMember())) {
+        DexClass nestMember = appView.definitionFor(attr.getNestMember());
+        assert nestMember != null;
+        nestMember.setNestHost(newHost.type);
+        newHost.getNestMembersClassAttributes().add(new NestMemberClassAttribute(nestMember.type));
+      }
+    }
+  }
+
   private boolean isAttributeReferencingPrunedItem(EnclosingMethodAttribute attr) {
     AppInfoWithLiveness appInfo = appView.appInfo();
     return
diff --git a/src/test/examplesJava11/nestHostExample/BasicNestHostClassMerging.java b/src/test/examplesJava11/nestHostExample/BasicNestHostClassMerging.java
new file mode 100644
index 0000000..0f01e9a
--- /dev/null
+++ b/src/test/examplesJava11/nestHostExample/BasicNestHostClassMerging.java
@@ -0,0 +1,32 @@
+package nestHostExample;
+
+public class BasicNestHostClassMerging {
+
+  private String field = "Outer";
+
+  public static class MiddleOuter extends BasicNestHostClassMerging {
+
+    private String field = "Middle";
+
+    public static void main(String[] args) {
+      System.out.println(new InnerMost().getFields());
+    }
+  }
+
+  public static class MiddleInner extends MiddleOuter {
+    private String field = "Inner";
+  }
+
+  public static class InnerMost extends MiddleInner {
+
+    public String getFields() {
+      return ((BasicNestHostClassMerging) this).field
+          + ((MiddleOuter) this).field
+          + ((MiddleInner) this).field;
+    }
+  }
+
+  public static void main(String[] args) {
+    System.out.println(new InnerMost().getFields());
+  }
+}
diff --git a/src/test/examplesJava11/nestHostExample/BasicNestHostTreePruning.java b/src/test/examplesJava11/nestHostExample/BasicNestHostTreePruning.java
new file mode 100644
index 0000000..7c0ca32
--- /dev/null
+++ b/src/test/examplesJava11/nestHostExample/BasicNestHostTreePruning.java
@@ -0,0 +1,24 @@
+package nestHostExample;
+
+public class BasicNestHostTreePruning {
+
+  private String field = "NotPruned";
+
+  public static class NotPruned extends BasicNestHostTreePruning {
+
+    public String getFields() {
+      return ((BasicNestHostTreePruning) this).field;
+    }
+  }
+
+  public static class Pruned {
+
+    public static void main(String[] args) {
+      System.out.println("NotPruned");
+    }
+  }
+
+  public static void main(String[] args) {
+    System.out.println(new NotPruned().getFields());
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/desugar/nestaccesscontrol/NestAccessControlTestUtils.java b/src/test/java/com/android/tools/r8/desugar/nestaccesscontrol/NestAccessControlTestUtils.java
index 216580e..a6acd2e 100644
--- a/src/test/java/com/android/tools/r8/desugar/nestaccesscontrol/NestAccessControlTestUtils.java
+++ b/src/test/java/com/android/tools/r8/desugar/nestaccesscontrol/NestAccessControlTestUtils.java
@@ -36,6 +36,13 @@
           "BasicNestHostWithAnonymousInnerClass",
           "BasicNestHostWithAnonymousInnerClass$1",
           "BasicNestHostWithAnonymousInnerClass$InterfaceForAnonymousClass",
+          "BasicNestHostClassMerging",
+          "BasicNestHostClassMerging$MiddleInner",
+          "BasicNestHostClassMerging$MiddleOuter",
+          "BasicNestHostClassMerging$InnerMost",
+          "BasicNestHostTreePruning",
+          "BasicNestHostTreePruning$Pruned",
+          "BasicNestHostTreePruning$NotPruned",
           "NestHostExample",
           "NestHostExample$NestMemberInner",
           "NestHostExample$NestMemberInner$NestMemberInnerInner",
@@ -50,12 +57,13 @@
   public static final ImmutableList<String> NEST_IDS =
       ImmutableList.of("fields", "methods", "constructors", "anonymous", "all");
   public static final ImmutableMap<String, String> MAIN_CLASSES =
-      ImmutableMap.of(
-          "fields", "BasicNestHostWithInnerClassFields",
-          "methods", "BasicNestHostWithInnerClassMethods",
-          "constructors", "BasicNestHostWithInnerClassConstructors",
-          "anonymous", "BasicNestHostWithAnonymousInnerClass",
-          "all", "NestHostExample");
+      ImmutableMap.<String, String>builder()
+          .put("fields", "BasicNestHostWithInnerClassFields")
+          .put("methods", "BasicNestHostWithInnerClassMethods")
+          .put("constructors", "BasicNestHostWithInnerClassConstructors")
+          .put("anonymous", "BasicNestHostWithAnonymousInnerClass")
+          .put("all", "NestHostExample")
+          .build();
   public static final String ALL_RESULT_LINE =
       String.join(
           ", ",
@@ -84,19 +92,23 @@
             "nest2Method"
           });
   public static final ImmutableMap<String, String> EXPECTED_RESULTS =
-      ImmutableMap.of(
-          "fields",
+      ImmutableMap.<String, String>builder()
+          .put(
+              "fields",
               StringUtils.lines(
-                  "RWnestFieldRWRWnestFieldRWRWnestFieldnoBridge", "RWfieldRWRWfieldRWRWnestField"),
-          "methods",
+                  "RWnestFieldRWRWnestFieldRWRWnestFieldnoBridge", "RWfieldRWRWfieldRWRWnestField"))
+          .put(
+              "methods",
               StringUtils.lines(
                   "nestMethodstaticNestMethodstaticNestMethodnoBridge",
-                  "hostMethodstaticHostMethodstaticNestMethod"),
-          "constructors", StringUtils.lines("field", "nest1SField", "1"),
-          "anonymous",
+                  "hostMethodstaticHostMethodstaticNestMethod"))
+          .put("constructors", StringUtils.lines("field", "nest1SField", "1"))
+          .put(
+              "anonymous",
               StringUtils.lines(
-                  "fieldstaticFieldstaticFieldhostMethodstaticHostMethodstaticHostMethod"),
-          "all",
+                  "fieldstaticFieldstaticFieldhostMethodstaticHostMethodstaticHostMethod"))
+          .put(
+              "all",
               StringUtils.lines(
                   ALL_RESULT_LINE,
                   ALL_RESULT_LINE,
@@ -105,7 +117,8 @@
                   "staticInterfaceMethodstaticStaticInterfaceMethod",
                   "staticInterfaceMethodstaticStaticInterfaceMethod",
                   "staticInterfaceMethodstaticStaticInterfaceMethod",
-                  "staticInterfaceMethodstaticStaticInterfaceMethod"));
+                  "staticInterfaceMethodstaticStaticInterfaceMethod"))
+          .build();
 
   public static String getMainClass(String id) {
     return PACKAGE_NAME + MAIN_CLASSES.get(id);
@@ -116,8 +129,12 @@
   }
 
   public static List<Path> classesOfNest(String nestID) {
+    return classesMatching(MAIN_CLASSES.get(nestID));
+  }
+
+  public static List<Path> classesMatching(String matcher) {
     return CLASS_NAMES.stream()
-        .filter(name -> containsString(MAIN_CLASSES.get(nestID)).matches(name))
+        .filter(name -> containsString(matcher).matches(name))
         .map(name -> CLASSES_PATH.resolve(name + CLASS_EXTENSION))
         .collect(toList());
   }
diff --git a/src/test/java/com/android/tools/r8/desugar/nestaccesscontrol/NestAttributesUpdateTest.java b/src/test/java/com/android/tools/r8/desugar/nestaccesscontrol/NestAttributesUpdateTest.java
new file mode 100644
index 0000000..c250934
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/desugar/nestaccesscontrol/NestAttributesUpdateTest.java
@@ -0,0 +1,115 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.desugar.nestaccesscontrol;
+
+import static com.android.tools.r8.desugar.nestaccesscontrol.NestAccessControlTestUtils.PACKAGE_NAME;
+import static com.android.tools.r8.desugar.nestaccesscontrol.NestAccessControlTestUtils.classesMatching;
+import static junit.framework.TestCase.assertNotNull;
+import static junit.framework.TestCase.assertSame;
+import static junit.framework.TestCase.assertTrue;
+
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.TestRuntime.CfVm;
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.NestMemberClassAttribute;
+import com.android.tools.r8.utils.StringUtils;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.FoundClassSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class NestAttributesUpdateTest extends TestBase {
+
+  public NestAttributesUpdateTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  private final TestParameters parameters;
+
+  private final String MERGING_OUTER_CLASS = "BasicNestHostClassMerging";
+  private final String PRUNING_OUTER_CLASS = "BasicNestHostTreePruning";
+  private final String MERGING_EXPECTED_RESULT = StringUtils.lines("OuterMiddleInner");
+  private final String PRUNING_EXPECTED_RESULT = StringUtils.lines("NotPruned");
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters()
+        .withCfRuntimesStartingFromIncluding(CfVm.JDK11)
+        .withAllApiLevels()
+        .build();
+  }
+
+  @Test
+  public void testClassMergingNestMemberRemoval() throws Exception {
+    testNestAttributesCorrect(MERGING_OUTER_CLASS, MERGING_OUTER_CLASS, MERGING_EXPECTED_RESULT);
+  }
+
+  @Test
+  public void testClassMergingNestHostRemoval() throws Exception {
+    testNestAttributesCorrect(
+        MERGING_OUTER_CLASS + "$MiddleOuter", MERGING_OUTER_CLASS, MERGING_EXPECTED_RESULT);
+  }
+
+  @Test
+  public void testTreePruningNestMemberRemoval() throws Exception {
+    testNestAttributesCorrect(PRUNING_OUTER_CLASS, PRUNING_OUTER_CLASS, PRUNING_EXPECTED_RESULT);
+  }
+
+  @Test
+  public void testTreePruningNestHostRemoval() throws Exception {
+    testNestAttributesCorrect(
+        PRUNING_OUTER_CLASS + "$Pruned", PRUNING_OUTER_CLASS, PRUNING_EXPECTED_RESULT);
+  }
+
+  public void testNestAttributesCorrect(
+      String mainClassName, String outerNestName, String expectedResult) throws Exception {
+    String actualMainClassName = PACKAGE_NAME + mainClassName;
+    testForR8(parameters.getBackend())
+        .noMinification()
+        .addKeepMainRule(actualMainClassName)
+        .setMinApi(parameters.getApiLevel())
+        .addProgramFiles(classesMatching(outerNestName))
+        .addOptionsModification(options -> options.enableNestBasedAccessDesugaring = true)
+        .compile()
+        .inspect(this::assertNestAttributesCorrect)
+        .run(parameters.getRuntime(), actualMainClassName)
+        .assertSuccessWithOutput(expectedResult);
+  }
+
+  private void assertNestAttributesCorrect(CodeInspector inspector) {
+    for (FoundClassSubject classSubject : inspector.allClasses()) {
+      DexClass clazz = classSubject.getDexClass();
+      if (clazz.isInANest()) {
+        if (clazz.isNestHost()) {
+          // All members are present with the clazz as host
+          for (NestMemberClassAttribute attr : clazz.getNestMembersClassAttributes()) {
+            String memberName = attr.getNestMember().getName();
+            ClassSubject inner = inspector.clazz(PACKAGE_NAME + memberName);
+            assertNotNull(
+                "The nest member " + memberName + " of " + clazz.type.getName() + " is missing",
+                inner.getDexClass());
+            assertSame(inner.getDexClass().getNestHost(), clazz.type);
+          }
+        } else {
+          // Nest host is present and with the clazz as member
+          String hostName = clazz.getNestHost().getName();
+          ClassSubject host = inspector.clazz(PACKAGE_NAME + hostName);
+          assertNotNull(
+              "The nest host " + hostName + " of " + clazz.type.getName() + " is missing",
+              host.getDexClass());
+          assertTrue(
+              host.getDexClass().getNestMembersClassAttributes().stream()
+                  .anyMatch(attr -> attr.getNestMember() == clazz.type));
+        }
+      }
+    }
+  }
+}