Prevent publicizing package private methods that is overridden

This CL implements a check when iterating over the method pool
collection to find if an overriden package private method will have a
different package name. Only if all package private methods in the
hierarchy have the same package we allow to publicize them.

Bug: 181328496
Bug: 172496438
Bug: 150589374
Bug: 172254047
Change-Id: I08112e83ba45912114f4eed445e5fc9a1cd65d3c
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/MemberPoolCollection.java b/src/main/java/com/android/tools/r8/ir/optimize/MemberPoolCollection.java
index 4f699e6..bd0376b 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/MemberPoolCollection.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/MemberPoolCollection.java
@@ -12,6 +12,7 @@
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.Timing;
+import com.android.tools.r8.utils.WorkList;
 import com.google.common.base.Equivalence;
 import com.google.common.base.Equivalence.Wrapper;
 import java.util.ArrayDeque;
@@ -26,6 +27,7 @@
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
+import java.util.function.BiFunction;
 import java.util.function.Predicate;
 
 // Per-class collection of member signatures.
@@ -158,14 +160,16 @@
 
   public static class MemberPool<T> {
 
-    private Equivalence<T> equivalence;
+    private final DexClass clazz;
+    private final Equivalence<T> equivalence;
     private MemberPool<T> superType;
     private final Set<MemberPool<T>> interfaces = new HashSet<>();
     private final Set<MemberPool<T>> subTypes = new HashSet<>();
     private final Set<Wrapper<T>> memberPool = new HashSet<>();
 
-    MemberPool(Equivalence<T> equivalence) {
+    MemberPool(Equivalence<T> equivalence, DexClass clazz) {
       this.equivalence = equivalence;
+      this.clazz = clazz;
     }
 
     synchronized void linkSupertype(MemberPool<T> superType) {
@@ -193,36 +197,74 @@
     }
 
     public boolean hasSeen(Wrapper<T> member) {
-      return hasSeenAbove(member, true) || hasSeenStrictlyBelow(member);
+      return fold(member, false, true, (t, ignored) -> true);
     }
 
     public boolean hasSeenDirectly(Wrapper<T> member) {
-      return memberPool.contains(member);
+      return here(member, false, (t, ignored) -> true);
     }
 
     public boolean hasSeenStrictlyAbove(Wrapper<T> member) {
-      return hasSeenAbove(member, false);
-    }
-
-    private boolean hasSeenAbove(Wrapper<T> member, boolean inclusive) {
-      if (inclusive && hasSeenDirectly(member)) {
-        return true;
-      }
-      return (superType != null && superType.hasSeenAbove(member, true))
-          || interfaces.stream().anyMatch(itf -> itf.hasSeenAbove(member, true));
+      return above(member, false, false, true, (t, ignored) -> true);
     }
 
     public boolean hasSeenStrictlyBelow(Wrapper<T> member) {
-      return hasSeenBelow(member, false);
+      return below(member, false, true, (t, ignored) -> true);
     }
 
-    private boolean hasSeenBelow(Wrapper<T> member, boolean inclusive) {
-      if (inclusive
-          && (hasSeenDirectly(member)
-              || interfaces.stream().anyMatch(itf -> itf.hasSeenAbove(member, true)))) {
-        return true;
+    private <S> S above(
+        Wrapper<T> member,
+        boolean inclusive,
+        S value,
+        S terminator,
+        BiFunction<DexClass, S, S> accumulator) {
+      WorkList<MemberPool<T>> workList = WorkList.newIdentityWorkList(this);
+      while (workList.hasNext()) {
+        MemberPool<T> next = workList.next();
+        if (inclusive) {
+          value = next.here(member, value, accumulator);
+          if (value == terminator) {
+            return value;
+          }
+        }
+        inclusive = true;
+        if (next.superType != null) {
+          workList.addIfNotSeen(next.superType);
+        }
+        workList.addIfNotSeen(next.interfaces);
       }
-      return subTypes.stream().anyMatch(subType -> subType.hasSeenBelow(member, true));
+      return value;
+    }
+
+    private <S> S here(Wrapper<T> member, S value, BiFunction<DexClass, S, S> accumulator) {
+      if (memberPool.contains(member)) {
+        return accumulator.apply(clazz, value);
+      }
+      return value;
+    }
+
+    public <S> S below(
+        Wrapper<T> member, S value, S terminator, BiFunction<DexClass, S, S> accumulator) {
+      WorkList<MemberPool<T>> workList = WorkList.newIdentityWorkList(this.subTypes);
+      while (workList.hasNext()) {
+        MemberPool<T> next = workList.next();
+        value = next.here(member, value, accumulator);
+        if (value == terminator) {
+          return value;
+        }
+        workList.addIfNotSeen(next.interfaces);
+        workList.addIfNotSeen(next.subTypes);
+      }
+      return value;
+    }
+
+    public <S> S fold(
+        Wrapper<T> member, S initialValue, S terminator, BiFunction<DexClass, S, S> accumulator) {
+      S value = above(member, true, initialValue, terminator, accumulator);
+      if (value == terminator) {
+        return value;
+      }
+      return below(member, initialValue, terminator, accumulator);
     }
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/MethodPoolCollection.java b/src/main/java/com/android/tools/r8/ir/optimize/MethodPoolCollection.java
index 593e827..74a4324 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/MethodPoolCollection.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/MethodPoolCollection.java
@@ -54,7 +54,7 @@
   Runnable computeMemberPoolForClass(DexClass clazz) {
     return () -> {
       MemberPool<DexMethod> methodPool =
-          memberPools.computeIfAbsent(clazz, k -> new MemberPool<>(equivalence));
+          memberPools.computeIfAbsent(clazz, k -> new MemberPool<>(equivalence, k));
       clazz.forEachMethod(
           encodedMethod -> {
             if (methodTester.test(encodedMethod)) {
@@ -65,7 +65,8 @@
         DexClass superClazz = appView.definitionFor(clazz.superType);
         if (superClazz != null) {
           MemberPool<DexMethod> superPool =
-              memberPools.computeIfAbsent(superClazz, k -> new MemberPool<>(equivalence));
+              memberPools.computeIfAbsent(
+                  superClazz, k -> new MemberPool<>(equivalence, superClazz));
           superPool.linkSubtype(methodPool);
           methodPool.linkSupertype(superPool);
         }
@@ -75,7 +76,7 @@
           DexClass subClazz = appView.definitionFor(subtype);
           if (subClazz != null) {
             MemberPool<DexMethod> childPool =
-                memberPools.computeIfAbsent(subClazz, k -> new MemberPool<>(equivalence));
+                memberPools.computeIfAbsent(subClazz, k -> new MemberPool<>(equivalence, subClazz));
             methodPool.linkSubtype(childPool);
             childPool.linkInterface(methodPool);
           }
diff --git a/src/main/java/com/android/tools/r8/optimize/ClassAndMemberPublicizer.java b/src/main/java/com/android/tools/r8/optimize/ClassAndMemberPublicizer.java
index a675776..dd67f3a 100644
--- a/src/main/java/com/android/tools/r8/optimize/ClassAndMemberPublicizer.java
+++ b/src/main/java/com/android/tools/r8/optimize/ClassAndMemberPublicizer.java
@@ -167,14 +167,16 @@
 
     if (accessFlags.isPackagePrivate()) {
       // If we publicize a package private method we have to ensure there is no overrides of it. We
-      // could potentially publicize a method if it only has package-private overrides, but for know
-      // we just check if it is seen below.
-      // Note that we will not publize private methods if there exists a package-private override,
-      // and there is therefore no need to check the hierarchy above.
+      // could potentially publicize a method if it only has package-private overrides.
+      // TODO(b/182136236): See if we can break the hierarchy for clusters.
       MemberPool<DexMethod> memberPool = methodPoolCollection.get(method.getHolder());
       Wrapper<DexMethod> methodKey = MethodSignatureEquivalence.get().wrap(method.getReference());
-      if (memberPool.hasSeenStrictlyBelow(methodKey)
-          && appView.options().enablePackagePrivateAwarePublicization) {
+      if (memberPool.below(
+          methodKey,
+          false,
+          true,
+          (clazz, ignored) ->
+              !method.getContextType().getPackageName().equals(clazz.getType().getPackageName()))) {
         return false;
       }
       doPublicize(method);
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index 4d95e6b..7a99160 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -270,8 +270,6 @@
   public boolean encodeChecksums = false;
   public BiPredicate<String, Long> dexClassChecksumFilter = (name, checksum) -> true;
   public boolean cfToCfDesugar = false;
-  // TODO(b/172496438): Temporarily enable publicizing package-private overrides.
-  public boolean enablePackagePrivateAwarePublicization = false;
 
   public int callGraphLikelySpuriousCallEdgeThreshold = 50;
 
diff --git a/src/test/java/com/android/tools/r8/accessrelaxation/PackagePrivateOverridePublicizerBottomTest.java b/src/test/java/com/android/tools/r8/accessrelaxation/PackagePrivateOverridePublicizerBottomTest.java
index 3c71dbe..6dc8328 100644
--- a/src/test/java/com/android/tools/r8/accessrelaxation/PackagePrivateOverridePublicizerBottomTest.java
+++ b/src/test/java/com/android/tools/r8/accessrelaxation/PackagePrivateOverridePublicizerBottomTest.java
@@ -4,6 +4,10 @@
 
 package com.android.tools.r8.accessrelaxation;
 
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertTrue;
+
 import com.android.tools.r8.NeverClassInline;
 import com.android.tools.r8.NeverInline;
 import com.android.tools.r8.NoVerticalClassMerging;
@@ -12,6 +16,8 @@
 import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.TestRunResult;
 import com.android.tools.r8.utils.DescriptorUtils;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -58,8 +64,17 @@
         .enableNeverClassInliningAnnotations()
         .allowAccessModification()
         .run(parameters.getRuntime(), Main.class)
-        // TODO(b/181328496): This should be EXPECTED.
-        .assertSuccessWithOutputLines(EXPECTED_ART_4);
+        // TODO(b/182185057): This is an error in the devirtualizer
+        .assertSuccessWithOutputLines(EXPECTED_ART_4)
+        .inspect(
+            inspector -> {
+              ClassSubject subViewModelSubject =
+                  inspector.clazz(DescriptorUtils.descriptorToJavaType(NEW_DESCRIPTOR));
+              assertThat(subViewModelSubject, isPresent());
+              MethodSubject clearSubject = subViewModelSubject.uniqueMethodWithName("clear");
+              assertThat(clearSubject, isPresent());
+              assertTrue(clearSubject.isPublic());
+            });
   }
 
   private byte[] getSubViewModelInAnotherPackage() throws Exception {
diff --git a/src/test/java/com/android/tools/r8/accessrelaxation/PackagePrivateOverridePublicizerTest.java b/src/test/java/com/android/tools/r8/accessrelaxation/PackagePrivateOverridePublicizerTest.java
index 5a9da04..ccd0b21 100644
--- a/src/test/java/com/android/tools/r8/accessrelaxation/PackagePrivateOverridePublicizerTest.java
+++ b/src/test/java/com/android/tools/r8/accessrelaxation/PackagePrivateOverridePublicizerTest.java
@@ -23,7 +23,6 @@
 
   private final TestParameters parameters;
   private final String[] EXPECTED = new String[] {"SubViewModel.clear()", "ViewModel.clear()"};
-  private final String[] R8_OUT = new String[] {"SubViewModel.clear()", "SubViewModel.clear()"};
 
   @Parameters(name = "{0}")
   public static TestParametersCollection data() {
@@ -52,8 +51,7 @@
         .enableNeverClassInliningAnnotations()
         .allowAccessModification()
         .run(parameters.getRuntime(), Main.class)
-        // TODO(b/172496438): This should be EXPECTED.
-        .assertSuccessWithOutputLines(R8_OUT);
+        .apply(this::assertSuccessOutput);
   }
 
   private void assertSuccessOutput(TestRunResult<?> result) {