Fix unsafe state pruning in access modifier traversal

Fixes: b/314984596
Change-Id: Ida1dc01c4355cc50ee561100e29ed639bf4de01f
diff --git a/src/main/java/com/android/tools/r8/graph/ImmediateProgramSubtypingInfo.java b/src/main/java/com/android/tools/r8/graph/ImmediateProgramSubtypingInfo.java
index 8b9a38b..f7be3cf 100644
--- a/src/main/java/com/android/tools/r8/graph/ImmediateProgramSubtypingInfo.java
+++ b/src/main/java/com/android/tools/r8/graph/ImmediateProgramSubtypingInfo.java
@@ -126,4 +126,8 @@
   public List<DexProgramClass> getSubclasses(DexProgramClass clazz) {
     return immediateSubtypes.getOrDefault(clazz, Collections.emptyList());
   }
+
+  public boolean hasSubclasses(DexProgramClass clazz) {
+    return !getSubclasses(clazz).isEmpty();
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/optimize/accessmodification/AccessModifierTraversal.java b/src/main/java/com/android/tools/r8/optimize/accessmodification/AccessModifierTraversal.java
index e07407b..50c7bf7 100644
--- a/src/main/java/com/android/tools/r8/optimize/accessmodification/AccessModifierTraversal.java
+++ b/src/main/java/com/android/tools/r8/optimize/accessmodification/AccessModifierTraversal.java
@@ -9,14 +9,12 @@
 
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexProgramClass;
-import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.optimize.argumentpropagation.utils.DepthFirstTopDownClassHierarchyTraversal;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.shaking.KeepClassInfo;
 import com.android.tools.r8.utils.InternalOptions;
-import com.android.tools.r8.utils.MapUtils;
 import com.android.tools.r8.utils.collections.DexMethodSignatureMap;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Sets;
@@ -31,7 +29,7 @@
   private final AccessModifier accessModifier;
   private final AccessModifierNamingState namingState;
 
-  private final Map<DexType, TraversalState> states = new IdentityHashMap<>();
+  private final Map<DexProgramClass, TraversalState> states = new IdentityHashMap<>();
 
   AccessModifierTraversal(
       AppView<AppInfoWithLiveness> appView,
@@ -57,29 +55,13 @@
     // TODO(b/279126633): Store a top down traversal state for the current class, which contains the
     //  protected and public method signatures when traversing downwards to enable publicizing of
     //  package private methods with illegal overrides.
-    states.put(clazz.getType(), TopDownTraversalState.empty());
+    states.put(clazz, TopDownTraversalState.empty());
   }
 
   /** Called during backtracking when all subclasses of {@param clazz} have been processed. */
   @Override
   public void prune(DexProgramClass clazz) {
-    // Remove the traversal state since all subclasses have now been processed.
-    states.remove(clazz.getType());
-
-    // Remove and join the bottom up traversal states of the subclasses.
-    KeepClassInfo keepInfo = appView.getKeepInfo(clazz);
-    InternalOptions options = appView.options();
-    BottomUpTraversalState state =
-        new BottomUpTraversalState(
-            !keepInfo.isMinificationAllowed(options) && !keepInfo.isShrinkingAllowed(options));
-    forEachSubClass(
-        clazz,
-        subclass -> {
-          BottomUpTraversalState subState =
-              MapUtils.removeOrDefault(states, subclass.getType(), BottomUpTraversalState.empty())
-                  .asBottomUpTraversalState();
-          state.add(subState);
-        });
+    BottomUpTraversalState state = getOrCreateBottomUpTraversalState(clazz);
 
     // Apply access modification to the class and its members.
     accessModifier.processClass(clazz, namingState, state);
@@ -88,19 +70,48 @@
     clazz.forEachProgramVirtualMethod(state::addMethod);
 
     // Store the bottom up traversal state for the current class.
-    if (state.isEmpty()) {
-      states.remove(clazz.getType());
-    } else {
-      states.put(clazz.getType(), state);
+    if (!state.isEmpty()) {
+      immediateSubtypingInfo.forEachImmediateProgramSuperClass(
+          clazz,
+          superClass -> {
+            BottomUpTraversalState superState = getOrCreateBottomUpTraversalState(superClass);
+            superState.add(state);
+          });
     }
+
+    // Done processing the current class and all subclasses.
+    states.remove(clazz);
   }
 
-  abstract static class TraversalState {
+  private BottomUpTraversalState getOrCreateBottomUpTraversalState(DexProgramClass clazz) {
+    TraversalState traversalState = states.get(clazz);
+    if (traversalState == null || traversalState.isTopDownTraversalState()) {
+      KeepClassInfo keepInfo = appView.getKeepInfo(clazz);
+      InternalOptions options = appView.options();
+      BottomUpTraversalState newState =
+          new BottomUpTraversalState(
+              !keepInfo.isMinificationAllowed(options) && !keepInfo.isShrinkingAllowed(options));
+      states.put(clazz, newState);
+      return newState;
+    }
+    assert traversalState.isBottomUpTraversalState();
+    return traversalState.asBottomUpTraversalState();
+  }
+
+  private abstract static class TraversalState {
+
+    boolean isBottomUpTraversalState() {
+      return false;
+    }
 
     BottomUpTraversalState asBottomUpTraversalState() {
       return null;
     }
 
+    boolean isTopDownTraversalState() {
+      return false;
+    }
+
     TopDownTraversalState asTopDownTraversalState() {
       return null;
     }
@@ -108,7 +119,7 @@
 
   // TODO(b/279126633): Collect the protected and public method signatures when traversing downwards
   //  to enable publicizing of package private methods with illegal overrides.
-  static class TopDownTraversalState extends TraversalState {
+  private static class TopDownTraversalState extends TraversalState {
 
     private static final TopDownTraversalState EMPTY = new TopDownTraversalState();
 
@@ -117,12 +128,13 @@
     }
 
     @Override
-    TopDownTraversalState asTopDownTraversalState() {
-      return this;
+    boolean isTopDownTraversalState() {
+      return true;
     }
 
-    boolean isEmpty() {
-      return true;
+    @Override
+    TopDownTraversalState asTopDownTraversalState() {
+      return this;
     }
   }
 
@@ -136,20 +148,29 @@
     // The set of non-private virtual methods below the current class.
     DexMethodSignatureMap<Set<String>> nonPrivateVirtualMethods;
 
-    BottomUpTraversalState(boolean isKept) {
+    private BottomUpTraversalState(boolean isKept) {
       this(DexMethodSignatureMap.create());
       this.isKeptOrHasKeptSubclass = isKept;
     }
 
-    BottomUpTraversalState(DexMethodSignatureMap<Set<String>> packagePrivateMethods) {
+    private BottomUpTraversalState(DexMethodSignatureMap<Set<String>> packagePrivateMethods) {
       this.nonPrivateVirtualMethods = packagePrivateMethods;
     }
 
+    static BottomUpTraversalState asBottomUpTraversalStateOrNull(TraversalState traversalState) {
+      return (BottomUpTraversalState) traversalState;
+    }
+
     static BottomUpTraversalState empty() {
       return EMPTY;
     }
 
     @Override
+    boolean isBottomUpTraversalState() {
+      return true;
+    }
+
+    @Override
     BottomUpTraversalState asBottomUpTraversalState() {
       return this;
     }
diff --git a/src/test/java/com/android/tools/r8/accessrelaxation/PackagePrivateOverrideWithInterfacePublicizerTest.java b/src/test/java/com/android/tools/r8/accessrelaxation/PackagePrivateOverrideWithInterfacePublicizerTest.java
index 45778a9..8e3c527 100644
--- a/src/test/java/com/android/tools/r8/accessrelaxation/PackagePrivateOverrideWithInterfacePublicizerTest.java
+++ b/src/test/java/com/android/tools/r8/accessrelaxation/PackagePrivateOverrideWithInterfacePublicizerTest.java
@@ -7,7 +7,6 @@
 
 import com.android.tools.r8.NeverClassInline;
 import com.android.tools.r8.NeverInline;
-import com.android.tools.r8.R8TestRunResult;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
@@ -65,15 +64,10 @@
   }
 
   private void assertSuccessOutput(TestRunResult<?> result) {
-    if (result instanceof R8TestRunResult) {
-      // TODO(b/314984596): Should pass with expected output.
-      result.assertSuccessWithOutputLines("SubViewModel.clear()", "SubViewModel.clear()");
+    if (parameters.isDexRuntime() && parameters.getDexRuntimeVersion().isDalvik()) {
+      result.assertFailureWithErrorThatMatches(containsString("overrides final"));
     } else {
-      if (parameters.isDexRuntime() && parameters.getDexRuntimeVersion().isDalvik()) {
-        result.assertFailureWithErrorThatMatches(containsString("overrides final"));
-      } else {
-        result.assertSuccessWithOutputLines("SubViewModel.clear()", "ViewModel.clear()");
-      }
+      result.assertSuccessWithOutputLines("SubViewModel.clear()", "ViewModel.clear()");
     }
   }