Account for classpath classes in -if rule evaluator

Fixes: b/389598240
Change-Id: I6351b7262f15f6d31f7d91067f1286606a93511d
diff --git a/src/main/java/com/android/tools/r8/shaking/IfRuleEvaluator.java b/src/main/java/com/android/tools/r8/shaking/IfRuleEvaluator.java
index 058e951..6867016 100644
--- a/src/main/java/com/android/tools/r8/shaking/IfRuleEvaluator.java
+++ b/src/main/java/com/android/tools/r8/shaking/IfRuleEvaluator.java
@@ -14,7 +14,6 @@
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexType;
-import com.android.tools.r8.graph.SubtypingInfo;
 import com.android.tools.r8.shaking.RootSetUtils.ConsequentRootSetBuilder;
 import com.android.tools.r8.shaking.RootSetUtils.RootSetBuilder;
 import com.android.tools.r8.shaking.ifrules.MaterializedSubsequentRulesOptimizer;
@@ -23,7 +22,7 @@
 import com.android.tools.r8.utils.Pair;
 import com.android.tools.r8.utils.Timing;
 import com.android.tools.r8.utils.UncheckedExecutionException;
-import com.android.tools.r8.utils.collections.ProgramMethodSet;
+import com.android.tools.r8.utils.collections.DexClassAndMethodSet;
 import com.google.common.base.Equivalence.Wrapper;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Sets;
@@ -41,21 +40,18 @@
 
   private final AppView<? extends AppInfoWithClassHierarchy> appView;
   private final DexItemFactory factory;
-  private final SubtypingInfo subtypingInfo;
   private final Enqueuer enqueuer;
   private final ConsequentRootSetBuilder rootSetBuilder;
   private final TaskCollection<?> tasks;
 
   IfRuleEvaluator(
       AppView<? extends AppInfoWithClassHierarchy> appView,
-      SubtypingInfo subtypingInfo,
       Enqueuer enqueuer,
       ConsequentRootSetBuilder rootSetBuilder,
       TaskCollection<?> tasks) {
     assert tasks.isEmpty();
     this.appView = appView;
     this.factory = appView.dexItemFactory();
-    this.subtypingInfo = subtypingInfo;
     this.enqueuer = enqueuer;
     this.rootSetBuilder = rootSetBuilder;
     this.tasks = tasks;
@@ -101,11 +97,10 @@
     if (classKind == ClassKind.PROGRAM) {
       return ifRule.relevantCandidatesForRule(
           appView,
-          subtypingInfo,
+          enqueuer.getSubtypingInfo(),
           (Iterable<DexProgramClass>) classesWithNewlyLiveMembers,
           isEffectivelyLive);
     }
-    assert classKind == ClassKind.LIBRARY;
     return classesWithNewlyLiveMembers;
   }
 
@@ -361,7 +356,7 @@
               .collect(Collectors.toList());
       // Member rules are combined as AND logic: if found unsatisfied member rule, this
       // combination of live members is not a good fit.
-      ProgramMethodSet methodsSatisfyingRule = ProgramMethodSet.create();
+      DexClassAndMethodSet methodsSatisfyingRule = DexClassAndMethodSet.create();
       boolean satisfied =
           memberKeepRules.stream()
               .allMatch(
@@ -372,7 +367,7 @@
                     DexClassAndMethod methodSatisfyingRule =
                         rootSetBuilder.getMethodSatisfyingRule(memberRule, methodsInCombination);
                     if (methodSatisfyingRule != null) {
-                      methodsSatisfyingRule.add(methodSatisfyingRule.asProgramMethod());
+                      methodsSatisfyingRule.add(methodSatisfyingRule);
                       return true;
                     }
                     return false;
diff --git a/src/main/java/com/android/tools/r8/shaking/IfRuleEvaluatorFactory.java b/src/main/java/com/android/tools/r8/shaking/IfRuleEvaluatorFactory.java
index 6075570..fa32a02 100644
--- a/src/main/java/com/android/tools/r8/shaking/IfRuleEvaluatorFactory.java
+++ b/src/main/java/com/android/tools/r8/shaking/IfRuleEvaluatorFactory.java
@@ -18,7 +18,6 @@
 import com.android.tools.r8.graph.ProgramDefinition;
 import com.android.tools.r8.graph.ProgramField;
 import com.android.tools.r8.graph.ProgramMethod;
-import com.android.tools.r8.graph.SubtypingInfo;
 import com.android.tools.r8.graph.analysis.EnqueuerAnalysisCollection;
 import com.android.tools.r8.graph.analysis.FixpointEnqueuerAnalysis;
 import com.android.tools.r8.graph.analysis.NewlyLiveClassEnqueuerAnalysis;
@@ -33,6 +32,7 @@
 import com.google.common.base.Equivalence.Wrapper;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Sets;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.IdentityHashMap;
@@ -145,26 +145,50 @@
     return false;
   }
 
-  public ConsequentRootSet applyActiveIfRulesToLibraryClasses(Enqueuer enqueuer, Timing timing)
+  public void applyActiveIfRulesToClasspathClasses(
+      ConsequentRootSetBuilder consequentRootSetBuilder, Enqueuer enqueuer, Timing timing)
       throws ExecutionException {
-    timing.begin("Apply if rules to library classes");
-    SubtypingInfo subtypingInfo = enqueuer.getSubtypingInfo();
-    ConsequentRootSetBuilder consequentRootSetBuilder =
-        ConsequentRootSet.builder(appView, enqueuer, subtypingInfo);
+    try (Timing t = timing.begin("Apply if rules to classpath classes")) {
+      applyActiveIfRulesToNonProgramClass(
+          consequentRootSetBuilder,
+          enqueuer,
+          appView.app().asDirect().classpathClasses(),
+          ClassKind.CLASSPATH,
+          timing);
+    }
+  }
+
+  public void applyActiveIfRulesToLibraryClasses(
+      ConsequentRootSetBuilder consequentRootSetBuilder, Enqueuer enqueuer, Timing timing)
+      throws ExecutionException {
+    if (!appView.testing().applyIfRulesToLibrary) {
+      return;
+    }
+    try (Timing t = timing.begin("Apply if rules to library classes")) {
+      applyActiveIfRulesToNonProgramClass(
+          consequentRootSetBuilder,
+          enqueuer,
+          appView.app().asDirect().libraryClasses(),
+          ClassKind.LIBRARY,
+          timing);
+    }
+  }
+
+  private <T extends DexClass> void applyActiveIfRulesToNonProgramClass(
+      ConsequentRootSetBuilder consequentRootSetBuilder,
+      Enqueuer enqueuer,
+      Collection<T> classes,
+      ClassKind<T> classKind,
+      Timing timing)
+      throws ExecutionException {
     IfRuleEvaluator evaluator =
-        new IfRuleEvaluator(appView, subtypingInfo, enqueuer, consequentRootSetBuilder, tasks);
+        new IfRuleEvaluator(appView, enqueuer, consequentRootSetBuilder, tasks);
     evaluator.processActiveIfRulesWithMembers(
-        activeIfRulesWithMembers,
-        ClassKind.LIBRARY,
-        appView.app().asDirect().libraryClasses(),
-        alwaysTrue());
+        activeIfRulesWithMembers, classKind, classes, alwaysTrue());
     evaluator.processActiveIfRulesWithoutMembers(
         activeIfRulesWithoutMembers,
-        newIdentityHashMapFromCollection(
-            appView.app().asDirect().libraryClasses(), DexClass::getType, Function.identity()),
+        newIdentityHashMapFromCollection(classes, DexClass::getType, Function.identity()),
         timing);
-    timing.end();
-    return consequentRootSetBuilder.buildConsequentRootSet();
   }
 
   @Override
@@ -172,9 +196,12 @@
       Enqueuer enqueuer, EnqueuerWorklist worklist, ExecutorService executorService, Timing timing)
       throws ExecutionException {
     boolean isFirstFixpoint = setSeenFixpoint();
-    if (isFirstFixpoint && appView.testing().applyIfRulesToLibrary) {
-      ConsequentRootSet consequentRootSet = applyActiveIfRulesToLibraryClasses(enqueuer, timing);
-      enqueuer.addConsequentRootSet(consequentRootSet);
+    if (isFirstFixpoint) {
+      ConsequentRootSetBuilder consequentRootSetBuilder =
+          ConsequentRootSet.builder(appView, enqueuer);
+      applyActiveIfRulesToClasspathClasses(consequentRootSetBuilder, enqueuer, timing);
+      applyActiveIfRulesToLibraryClasses(consequentRootSetBuilder, enqueuer, timing);
+      enqueuer.addConsequentRootSet(consequentRootSetBuilder.buildConsequentRootSet());
     }
     if (!shouldProcessActiveIfRulesWithMembers(isFirstFixpoint)
         && !shouldProcessActiveIfRulesWithoutMembers(isFirstFixpoint)) {
@@ -201,11 +228,10 @@
 
   private ConsequentRootSet processActiveIfRules(
       Enqueuer enqueuer, boolean isFirstFixpoint, Timing timing) throws ExecutionException {
-    SubtypingInfo subtypingInfo = enqueuer.getSubtypingInfo();
     ConsequentRootSetBuilder consequentRootSetBuilder =
-        ConsequentRootSet.builder(appView, enqueuer, subtypingInfo);
+        ConsequentRootSet.builder(appView, enqueuer);
     IfRuleEvaluator evaluator =
-        new IfRuleEvaluator(appView, subtypingInfo, enqueuer, consequentRootSetBuilder, tasks);
+        new IfRuleEvaluator(appView, enqueuer, consequentRootSetBuilder, tasks);
     timing.begin("If rules with members");
     if (shouldProcessActiveIfRulesWithMembers(isFirstFixpoint)) {
       processActiveIfRulesWithMembers(evaluator, isFirstFixpoint);
diff --git a/src/main/java/com/android/tools/r8/shaking/ProguardIfRulePreconditionMatch.java b/src/main/java/com/android/tools/r8/shaking/ProguardIfRulePreconditionMatch.java
index cb97db1..81919d3 100644
--- a/src/main/java/com/android/tools/r8/shaking/ProguardIfRulePreconditionMatch.java
+++ b/src/main/java/com/android/tools/r8/shaking/ProguardIfRulePreconditionMatch.java
@@ -3,23 +3,24 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.shaking;
 
+import com.android.tools.r8.graph.Definition;
 import com.android.tools.r8.graph.DexClass;
-import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.graph.DexClassAndMethod;
 import com.android.tools.r8.shaking.RootSetUtils.ConsequentRootSetBuilder;
-import com.android.tools.r8.utils.collections.ProgramMethodSet;
+import com.android.tools.r8.utils.collections.DexClassAndMethodSet;
 
 public class ProguardIfRulePreconditionMatch {
 
   private final ProguardIfRule ifRule;
   private final DexClass classMatch;
-  private final ProgramMethodSet methodsMatch;
+  private final DexClassAndMethodSet methodsMatch;
 
   public ProguardIfRulePreconditionMatch(ProguardIfRule ifRule, DexClass classMatch) {
-    this(ifRule, classMatch, ProgramMethodSet.empty());
+    this(ifRule, classMatch, DexClassAndMethodSet.empty());
   }
 
   public ProguardIfRulePreconditionMatch(
-      ProguardIfRule ifRule, DexClass classMatch, ProgramMethodSet methodsMatch) {
+      ProguardIfRule ifRule, DexClass classMatch, DexClassAndMethodSet methodsMatch) {
     this.ifRule = ifRule;
     this.classMatch = classMatch;
     this.methodsMatch = methodsMatch;
@@ -50,13 +51,18 @@
   }
 
   private void disallowMethodOptimizationsForReevaluation(ConsequentRootSetBuilder rootSetBuilder) {
-    for (ProgramMethod method : methodsMatch) {
-      rootSetBuilder
-          .getDependentMinimumKeepInfo()
-          .getOrCreateUnconditionalMinimumKeepInfoFor(method.getReference())
-          .asMethodJoiner()
-          .disallowClassInlining()
-          .disallowInlining();
+    if (classMatch.isProgramClass()) {
+      for (DexClassAndMethod method : methodsMatch) {
+        assert method.isProgramMethod();
+        rootSetBuilder
+            .getDependentMinimumKeepInfo()
+            .getOrCreateUnconditionalMinimumKeepInfoFor(method.getReference())
+            .asMethodJoiner()
+            .disallowClassInlining()
+            .disallowInlining();
+      }
+    } else {
+      assert methodsMatch.stream().noneMatch(Definition::isProgramMethod);
     }
   }
 }
diff --git a/src/main/java/com/android/tools/r8/shaking/RootSetUtils.java b/src/main/java/com/android/tools/r8/shaking/RootSetUtils.java
index c1befbf..f6c9b5b 100644
--- a/src/main/java/com/android/tools/r8/shaking/RootSetUtils.java
+++ b/src/main/java/com/android/tools/r8/shaking/RootSetUtils.java
@@ -2394,10 +2394,8 @@
     }
 
     static ConsequentRootSetBuilder builder(
-        AppView<? extends AppInfoWithClassHierarchy> appView,
-        Enqueuer enqueuer,
-        SubtypingInfo subtypingInfo) {
-      return new ConsequentRootSetBuilder(appView, enqueuer, subtypingInfo);
+        AppView<? extends AppInfoWithClassHierarchy> appView, Enqueuer enqueuer) {
+      return new ConsequentRootSetBuilder(appView, enqueuer, enqueuer.getSubtypingInfo());
     }
   }
 
diff --git a/src/test/java/com/android/tools/r8/partial/PartialCompilationIfD8ClassPresentTest.java b/src/test/java/com/android/tools/r8/partial/PartialCompilationIfD8ClassPresentTest.java
new file mode 100644
index 0000000..aae826f
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/partial/PartialCompilationIfD8ClassPresentTest.java
@@ -0,0 +1,56 @@
+// Copyright (c) 2025, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.partial;
+
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class PartialCompilationIfD8ClassPresentTest extends TestBase {
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  @Test
+  public void test() throws Exception {
+    parameters.assumeR8PartialTestParameters();
+    testForR8Partial(parameters.getBackend())
+        .addR8IncludedClasses(Main.class)
+        .addR8ExcludedClasses(ExcludedClass.class)
+        .addKeepRules(
+            "-if class **Excluded** { public static void foo(); }",
+            "-keep class " + Main.class.getTypeName() + " {",
+            "  public static void main(java.lang.String[]);",
+            "}")
+        .setMinApi(parameters)
+        .compile()
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("Hello, world!");
+  }
+
+  static class Main {
+
+    public static void main(String[] args) {
+      ExcludedClass.foo();
+    }
+  }
+
+  static class ExcludedClass {
+
+    public static void foo() {
+      System.out.println("Hello, world!");
+    }
+  }
+}