Avoid super hierarchy traversals when visiting all classes

Change-Id: Ib8026a81fefb969394bef10115b393fe3de8bb90
diff --git a/src/main/java/com/android/tools/r8/shaking/RootSetUtils.java b/src/main/java/com/android/tools/r8/shaking/RootSetUtils.java
index de54647..2d65292 100644
--- a/src/main/java/com/android/tools/r8/shaking/RootSetUtils.java
+++ b/src/main/java/com/android/tools/r8/shaking/RootSetUtils.java
@@ -64,12 +64,14 @@
 import com.android.tools.r8.shaking.KeepInfo.Joiner;
 import com.android.tools.r8.threading.TaskCollection;
 import com.android.tools.r8.utils.ArrayUtils;
+import com.android.tools.r8.utils.BiPredicateUtils;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.LazyBox;
 import com.android.tools.r8.utils.MethodSignatureEquivalence;
 import com.android.tools.r8.utils.OriginWithPosition;
 import com.android.tools.r8.utils.PredicateSet;
 import com.android.tools.r8.utils.Reporter;
+import com.android.tools.r8.utils.SetUtils;
 import com.android.tools.r8.utils.StringDiagnostic;
 import com.android.tools.r8.utils.Timing;
 import com.android.tools.r8.utils.TraversalContinuation;
@@ -101,6 +103,7 @@
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
+import java.util.function.BiPredicate;
 import java.util.function.Consumer;
 import java.util.function.Predicate;
 import java.util.stream.Collectors;
@@ -213,11 +216,19 @@
       return this;
     }
 
-    // Process a class with the keep rule.
     private void process(
         DexClass clazz,
         ProguardConfigurationRule rule,
         ProguardIfRulePreconditionMatch ifRulePreconditionMatch) {
+      process(clazz, rule, ifRulePreconditionMatch, BiPredicateUtils.alwaysTrue());
+    }
+
+    // Process a class with the keep rule.
+    private void process(
+        DexClass clazz,
+        ProguardConfigurationRule rule,
+        ProguardIfRulePreconditionMatch ifRulePreconditionMatch,
+        BiPredicate<DexClass, DexClass> includeSuperclass) {
       if (!satisfyClassType(rule, clazz)) {
         return;
       }
@@ -248,14 +259,29 @@
         if (clazz.isNotProgramClass()) {
           return;
         }
-        switch (((ProguardKeepRule) rule).getType()) {
+        ProguardKeepRule keepRule = (ProguardKeepRule) rule;
+        switch (keepRule.getType()) {
           case KEEP_CLASS_MEMBERS:
             // Members mentioned at -keepclassmembers always depend on their holder.
             preconditionSupplier = ImmutableMap.of(definition -> true, clazz.asProgramClass());
             markMatchingVisibleMethods(
-                clazz, memberKeepRules, rule, preconditionSupplier, false, ifRulePreconditionMatch);
+                clazz,
+                memberKeepRules,
+                rule,
+                preconditionSupplier,
+                includeSuperclass,
+                false,
+                ifRulePreconditionMatch);
             markMatchingVisibleFields(
-                clazz, memberKeepRules, rule, preconditionSupplier, false, ifRulePreconditionMatch);
+                clazz,
+                memberKeepRules,
+                rule,
+                preconditionSupplier,
+                includeSuperclass,
+                false,
+                ifRulePreconditionMatch);
+            synthesizeMissingInterfaceMethodsForMemberRules(
+                clazz, memberKeepRules, keepRule, preconditionSupplier, ifRulePreconditionMatch);
             break;
           case KEEP_CLASSES_WITH_MEMBERS:
             if (!allRulesSatisfied(memberKeepRules, clazz)) {
@@ -277,9 +303,23 @@
               preconditionSupplier.put(alwaysTrue(), null);
             }
             markMatchingVisibleMethods(
-                clazz, memberKeepRules, rule, preconditionSupplier, false, ifRulePreconditionMatch);
+                clazz,
+                memberKeepRules,
+                rule,
+                preconditionSupplier,
+                includeSuperclass,
+                false,
+                ifRulePreconditionMatch);
             markMatchingVisibleFields(
-                clazz, memberKeepRules, rule, preconditionSupplier, false, ifRulePreconditionMatch);
+                clazz,
+                memberKeepRules,
+                rule,
+                preconditionSupplier,
+                includeSuperclass,
+                false,
+                ifRulePreconditionMatch);
+            synthesizeMissingInterfaceMethodsForMemberRules(
+                clazz, memberKeepRules, keepRule, preconditionSupplier, ifRulePreconditionMatch);
             break;
           case CONDITIONAL:
             throw new Unreachable("-if rule will be evaluated separately, not here.");
@@ -300,25 +340,25 @@
           || rule instanceof ProguardWhyAreYouKeepingRule) {
         markClass(clazz, rule, ifRulePreconditionMatch);
         markMatchingVisibleMethods(
-            clazz, memberKeepRules, rule, null, true, ifRulePreconditionMatch);
+            clazz, memberKeepRules, rule, null, includeSuperclass, true, ifRulePreconditionMatch);
         markMatchingVisibleFields(
-            clazz, memberKeepRules, rule, null, true, ifRulePreconditionMatch);
+            clazz, memberKeepRules, rule, null, includeSuperclass, true, ifRulePreconditionMatch);
       } else if (rule instanceof ProguardAssumeMayHaveSideEffectsRule) {
         markMatchingVisibleMethods(
-            clazz, memberKeepRules, rule, null, true, ifRulePreconditionMatch);
+            clazz, memberKeepRules, rule, null, includeSuperclass, true, ifRulePreconditionMatch);
         markMatchingOverriddenMethods(
             clazz, memberKeepRules, rule, null, true, ifRulePreconditionMatch);
         markMatchingVisibleFields(
-            clazz, memberKeepRules, rule, null, true, ifRulePreconditionMatch);
+            clazz, memberKeepRules, rule, null, includeSuperclass, true, ifRulePreconditionMatch);
       } else if (rule instanceof ProguardAssumeNoSideEffectRule
           || rule instanceof ProguardAssumeValuesRule) {
         if (assumeInfoCollectionBuilder != null) {
           markMatchingVisibleMethods(
-              clazz, memberKeepRules, rule, null, true, ifRulePreconditionMatch);
+              clazz, memberKeepRules, rule, null, includeSuperclass, true, ifRulePreconditionMatch);
           markMatchingOverriddenMethods(
               clazz, memberKeepRules, rule, null, true, ifRulePreconditionMatch);
           markMatchingVisibleFields(
-              clazz, memberKeepRules, rule, null, true, ifRulePreconditionMatch);
+              clazz, memberKeepRules, rule, null, includeSuperclass, true, ifRulePreconditionMatch);
         }
       } else if (rule instanceof NoFieldTypeStrengtheningRule
           || rule instanceof NoRedundantFieldLoadEliminationRule) {
@@ -345,9 +385,9 @@
         }
       } else if (rule instanceof NoValuePropagationRule) {
         markMatchingVisibleMethods(
-            clazz, memberKeepRules, rule, null, true, ifRulePreconditionMatch);
+            clazz, memberKeepRules, rule, null, includeSuperclass, true, ifRulePreconditionMatch);
         markMatchingVisibleFields(
-            clazz, memberKeepRules, rule, null, true, ifRulePreconditionMatch);
+            clazz, memberKeepRules, rule, null, includeSuperclass, true, ifRulePreconditionMatch);
       } else if (rule instanceof ProguardIdentifierNameStringRule) {
         markMatchingFields(clazz, memberKeepRules, rule, null, ifRulePreconditionMatch);
         markMatchingMethods(clazz, memberKeepRules, rule, null, ifRulePreconditionMatch);
@@ -377,10 +417,15 @@
 
       tasks.submit(
           () -> {
-            for (DexProgramClass clazz :
-                rule.relevantCandidatesForRule(
-                    appView, subtypingInfo, application.classes(), alwaysTrue())) {
-              process(clazz, rule, ifRulePreconditionMatch);
+            Collection<DexProgramClass> allCandidates = application.classes();
+            Iterable<DexProgramClass> relevantCandidates =
+                rule.relevantCandidatesForRule(appView, subtypingInfo, allCandidates, alwaysTrue());
+            for (DexProgramClass clazz : relevantCandidates) {
+              process(
+                  clazz,
+                  rule,
+                  ifRulePreconditionMatch,
+                  getIncludeSuperclassPredicate(allCandidates, relevantCandidates));
             }
             if (rule.applyToNonProgramClasses()) {
               for (DexLibraryClass clazz : application.libraryClasses()) {
@@ -390,6 +435,21 @@
           });
     }
 
+    private static BiPredicate<DexClass, DexClass> getIncludeSuperclassPredicate(
+        Collection<DexProgramClass> allCandidates, Iterable<DexProgramClass> relevantCandidates) {
+      if (relevantCandidates == allCandidates) {
+        return BiPredicateUtils.alwaysFalse();
+      } else {
+        Set<DexProgramClass> relevantCandidatesSet =
+            (relevantCandidates instanceof Set<?>)
+                ? (Set<DexProgramClass>) relevantCandidates
+                : SetUtils.newIdentityHashSet(relevantCandidates);
+        return (clazz, superclass) ->
+            !superclass.isProgramClass()
+                || !relevantCandidatesSet.contains(superclass.asProgramClass());
+      }
+    }
+
     public RootSet build(ExecutorService executorService) throws ExecutionException {
       application.timing.begin("Build root set...");
       try {
@@ -532,22 +592,19 @@
         Collection<ProguardMemberRule> memberKeepRules,
         ProguardConfigurationRule rule,
         Map<Predicate<DexDefinition>, DexProgramClass> preconditionSupplier,
+        BiPredicate<DexClass, DexClass> includeSuperclass,
         boolean includeLibraryClasses,
         ProguardIfRulePreconditionMatch ifRulePreconditionMatch) {
       Set<Wrapper<DexMethod>> methodsMarked =
           options.forceProguardCompatibility ? null : new HashSet<>();
-      Deque<DexClass> worklist = new ArrayDeque<>();
-      worklist.add(clazz);
-      while (!worklist.isEmpty()) {
-        DexClass currentClass = worklist.pop();
-        if (!includeLibraryClasses && currentClass.isNotProgramClass()) {
-          break;
-        }
+      DexClass currentClass = clazz;
+      while (includeLibraryClasses || currentClass.isProgramClass()) {
         // In compat mode traverse all direct methods in the hierarchy.
+        boolean isClass = currentClass == clazz;
         currentClass.forEachClassMethodMatching(
             method ->
                 method.belongsToVirtualPool()
-                    || currentClass == clazz
+                    || isClass
                     || (method.isStatic() && !method.isPrivate() && !method.isInitializer())
                     || options.forceProguardCompatibility,
             method -> {
@@ -561,17 +618,26 @@
                   precondition,
                   ifRulePreconditionMatch);
             });
-        if (currentClass.superType != null) {
-          DexClass dexClass = application.definitionFor(currentClass.superType);
-          if (dexClass != null) {
-            worklist.add(dexClass);
+        if (currentClass.hasSuperType()) {
+          DexClass superclass = application.definitionFor(currentClass.getSuperType());
+          if (superclass != null && includeSuperclass.test(currentClass, superclass)) {
+            currentClass = superclass;
+            continue;
           }
         }
+        break;
       }
+    }
+
+    private void synthesizeMissingInterfaceMethodsForMemberRules(
+        DexClass clazz,
+        Collection<ProguardMemberRule> memberKeepRules,
+        ProguardKeepRule rule,
+        Map<Predicate<DexDefinition>, DexProgramClass> preconditionSupplier,
+        ProguardIfRulePreconditionMatch ifRulePreconditionMatch) {
       // TODO(b/143643942): Generalize the below approach to also work for subtyping hierarchies in
       //  fullmode.
       if (clazz.isProgramClass()
-          && rule.isProguardKeepRule()
           && !rule.asProguardKeepRule().getModifiers().allowsShrinking
           && !isMainDexRootSetBuilder()) {
         new SynthesizeMissingInterfaceMethodsForMemberRules(
@@ -766,19 +832,24 @@
         Collection<ProguardMemberRule> memberKeepRules,
         ProguardConfigurationRule rule,
         Map<Predicate<DexDefinition>, DexProgramClass> preconditionSupplier,
+        BiPredicate<DexClass, DexClass> includeSuperclass,
         boolean includeLibraryClasses,
         ProguardIfRulePreconditionMatch ifRulePreconditionMatch) {
-      while (clazz != null) {
-        if (!includeLibraryClasses && clazz.isNotProgramClass()) {
-          return;
-        }
+      while (includeLibraryClasses || clazz.isProgramClass()) {
         clazz.forEachClassField(
             field -> {
               DexProgramClass precondition =
                   testAndGetPrecondition(field.getDefinition(), preconditionSupplier);
               markField(field, memberKeepRules, rule, precondition, ifRulePreconditionMatch);
             });
-        clazz = clazz.superType == null ? null : application.definitionFor(clazz.superType);
+        if (clazz.hasSuperType()) {
+          DexClass superclass = appView.definitionFor(clazz.getSuperType());
+          if (superclass != null && includeSuperclass.test(clazz, superclass)) {
+            clazz = superclass;
+            continue;
+          }
+        }
+        break;
       }
     }