[keepanno] Add test for instanceOf type pattern

Bug: b/349268994
Change-Id: If0a0731f26d58a049d6a259b50f718dd3d340f60
diff --git a/src/main/java/com/android/tools/r8/shaking/rules/MaterializedConditionalRule.java b/src/main/java/com/android/tools/r8/shaking/rules/MaterializedConditionalRule.java
index cbae8fe..b6c85d0 100644
--- a/src/main/java/com/android/tools/r8/shaking/rules/MaterializedConditionalRule.java
+++ b/src/main/java/com/android/tools/r8/shaking/rules/MaterializedConditionalRule.java
@@ -26,23 +26,8 @@
   }
 
   public boolean pruneItems(PrunedItems prunedItems) {
-    for (DexReference precondition : preconditions) {
-      if (precondition.isDexType()) {
-        if (prunedItems.getRemovedClasses().contains(precondition.asDexType())) {
-          return true;
-        }
-      } else if (precondition.isDexField()) {
-        if (prunedItems.getRemovedFields().contains(precondition.asDexField())) {
-          return true;
-        }
-      } else {
-        assert precondition.isDexMethod();
-        if (prunedItems.getRemovedMethods().contains(precondition.asDexMethod())) {
-          return true;
-        }
-      }
-    }
-    // Preconditions are in place, so trim down consequences.
+    // Preconditions cannot be pruned as they reference "original" program references which may be
+    // in inlined positions even when the items themselves are "pruned".
     consequences.pruneItems(prunedItems);
     return consequences.isEmpty();
   }
diff --git a/src/test/java/com/android/tools/r8/keepanno/KeepAnnoTestBuilder.java b/src/test/java/com/android/tools/r8/keepanno/KeepAnnoTestBuilder.java
index 27b74ee..3a4a3c4 100644
--- a/src/test/java/com/android/tools/r8/keepanno/KeepAnnoTestBuilder.java
+++ b/src/test/java/com/android/tools/r8/keepanno/KeepAnnoTestBuilder.java
@@ -28,6 +28,8 @@
 import com.android.tools.r8.utils.DescriptorUtils;
 import com.android.tools.r8.utils.InternalOptions;
 import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
@@ -66,6 +68,12 @@
     return keepAnnoParams.parameters();
   }
 
+  public KeepAnnoTestBuilder addInnerClasses(Class<?> clazz) throws IOException {
+    return addProgramFiles(new ArrayList<>(ToolHelper.getClassFilesForInnerClasses(clazz)));
+  }
+
+  public abstract KeepAnnoTestBuilder addProgramFiles(List<Path> programFiles) throws IOException;
+
   public abstract KeepAnnoTestBuilder addProgramClasses(List<Class<?>> programClasses)
       throws IOException;
 
@@ -148,6 +156,12 @@
     }
 
     @Override
+    public KeepAnnoTestBuilder addProgramFiles(List<Path> programFiles) {
+      builder.addProgramFiles(programFiles);
+      return this;
+    }
+
+    @Override
     public KeepAnnoTestBuilder addProgramClasses(List<Class<?>> programClasses) {
       builder.addProgramClasses(programClasses);
       return this;
@@ -230,6 +244,14 @@
     }
 
     @Override
+    public KeepAnnoTestBuilder addProgramFiles(List<Path> programFiles) throws IOException {
+      for (Path programFile : programFiles) {
+        extractAndAdd(Files.readAllBytes(programFile));
+      }
+      return this;
+    }
+
+    @Override
     public KeepAnnoTestBuilder addProgramClasses(List<Class<?>> programClasses) throws IOException {
       for (Class<?> programClass : programClasses) {
         extractAndAdd(ToolHelper.getClassAsBytes(programClass));
@@ -320,6 +342,14 @@
     }
 
     @Override
+    public KeepAnnoTestBuilder addProgramFiles(List<Path> programFiles) throws IOException {
+      List<String> rules = KeepAnnoTestUtils.extractRulesFromFiles(programFiles, extractorOptions);
+      builder.addProgramFiles(programFiles);
+      builder.addKeepRules(rules);
+      return this;
+    }
+
+    @Override
     public KeepAnnoTestBuilder addProgramClasses(List<Class<?>> programClasses) throws IOException {
       List<String> rules = KeepAnnoTestUtils.extractRules(programClasses, extractorOptions);
       builder.addProgramClasses(programClasses);
@@ -372,6 +402,14 @@
     }
 
     @Override
+    public KeepAnnoTestBuilder addProgramFiles(List<Path> programFiles) throws IOException {
+      List<String> rules = KeepAnnoTestUtils.extractRulesFromFiles(programFiles, extractorOptions);
+      builder.addProgramFiles(programFiles);
+      builder.addKeepRules(rules);
+      return this;
+    }
+
+    @Override
     public KeepAnnoTestBuilder addProgramClasses(List<Class<?>> programClasses) throws IOException {
       List<String> rules = KeepAnnoTestUtils.extractRules(programClasses, extractorOptions);
       builder.addProgramClasses(programClasses);
diff --git a/src/test/java/com/android/tools/r8/keepanno/KeepTypePatternWithInstanceOfTest.java b/src/test/java/com/android/tools/r8/keepanno/KeepTypePatternWithInstanceOfTest.java
new file mode 100644
index 0000000..ce7c1d7
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/keepanno/KeepTypePatternWithInstanceOfTest.java
@@ -0,0 +1,192 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.keepanno;
+
+import static org.junit.Assert.assertThrows;
+
+import com.android.tools.r8.TestShrinkerBuilder;
+import com.android.tools.r8.keepanno.annotations.ClassNamePattern;
+import com.android.tools.r8.keepanno.annotations.InstanceOfPattern;
+import com.android.tools.r8.keepanno.annotations.KeepTarget;
+import com.android.tools.r8.keepanno.annotations.TypePattern;
+import com.android.tools.r8.keepanno.annotations.UsesReflection;
+import com.android.tools.r8.utils.AndroidApiLevel;
+import com.android.tools.r8.utils.StringUtils;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+
+@RunWith(Parameterized.class)
+public class KeepTypePatternWithInstanceOfTest extends KeepAnnoTestBase {
+
+  @Parameter public KeepAnnoParameters parameters;
+
+  @Parameterized.Parameters(name = "{0}")
+  public static List<KeepAnnoParameters> data() {
+    return createParameters(
+        getTestParameters()
+            .withDefaultRuntimes()
+            .withApiLevel(AndroidApiLevel.B)
+            .enableApiLevelsForCf()
+            .build());
+  }
+
+  @Test
+  public void test() throws Exception {
+    if (parameters.isReference() || parameters.isNativeR8()) {
+      testForKeepAnno(parameters)
+          .addInnerClasses(getClass())
+          .addKeepMainRule(TestClass.class)
+          .run(TestClass.class)
+          .assertSuccessWithOutput(getExpectedResult(parameters.isReference() ? 9 : 4));
+    } else {
+      // It's either Unimplemented or CompilationFailedException.
+      assertThrows(
+          Exception.class,
+          () ->
+              testForKeepAnno(parameters)
+                  .addInnerClasses(getClass())
+                  .addKeepMainRule(TestClass.class)
+                  .run(TestClass.class));
+    }
+  }
+
+  private String getExpectedResult(int numMethods) {
+    return StringUtils.lines(
+        "Num methods: " + numMethods,
+        "5",
+        "8",
+        "9",
+        "6",
+        "1",
+        "2",
+        "3",
+        "4",
+        "5",
+        "6",
+        "7",
+        "8",
+        "9");
+  }
+
+  private void clearClassMerging(TestShrinkerBuilder<?, ?, ?, ?, ?> sb) {
+    sb.addOptionsModification(
+        opt -> {
+          opt.horizontalClassMergerOptions().disable();
+          opt.getVerticalClassMergerOptions().disable();
+        });
+  }
+
+  public static class Top {}
+
+  public static class Top1 extends Top {}
+
+  public static class Sub1 extends Top1 {}
+
+  public static class Subb1 extends Top1 {}
+
+  public static class Top2 extends Top {}
+
+  public static class Sub2 extends Top2 {}
+
+  public static class Subb2 extends Top2 {}
+
+  public static class A {
+
+    public void foo(Top1 a, Top2 b) {
+      System.out.println("1");
+    }
+
+    public void foo(Sub1 a, Top2 b) {
+      System.out.println("2");
+    }
+
+    public void foo(Subb1 a, Top2 b) {
+      System.out.println("3");
+    }
+
+    public void foo(Top1 a, Sub2 b) {
+      System.out.println("4");
+    }
+
+    public void foo(Sub1 a, Sub2 b) {
+      System.out.println("5");
+    }
+
+    public void foo(Subb1 a, Sub2 b) {
+      System.out.println("6");
+    }
+
+    public void foo(Top1 a, Subb2 b) {
+      System.out.println("7");
+    }
+
+    public void foo(Sub1 a, Subb2 b) {
+      System.out.println("8");
+    }
+
+    public void foo(Subb1 a, Subb2 b) {
+      System.out.println("9");
+    }
+  }
+
+  static class TestClass {
+
+    public static void main(String[] args)
+        throws NoSuchMethodException, InvocationTargetException, IllegalAccessException {
+      System.out.println("Num methods: " + A.class.getDeclaredMethods().length);
+      for (Method method : reflectiveMethods()) {
+        method.invoke(new A(), null, null);
+      }
+      // Make sure all the methods are live, they can be inlined.
+      runAll();
+    }
+
+    private static void runAll() {
+      A a = new A();
+      a.foo(new Top1(), new Top2());
+      a.foo(new Sub1(), new Top2());
+      a.foo(new Subb1(), new Top2());
+      a.foo(new Top1(), new Sub2());
+      a.foo(new Sub1(), new Sub2());
+      a.foo(new Subb1(), new Sub2());
+      a.foo(new Top1(), new Subb2());
+      a.foo(new Sub1(), new Subb2());
+      a.foo(new Subb1(), new Subb2());
+    }
+
+    @UsesReflection(
+        @KeepTarget(
+            classConstant = A.class,
+            methodName = "foo",
+            methodParameterTypePatterns = {
+              @TypePattern(
+                  instanceOfPattern =
+                      @InstanceOfPattern(
+                          inclusive = false,
+                          classNamePattern = @ClassNamePattern(constant = Top1.class))),
+              @TypePattern(
+                  instanceOfPattern =
+                      @InstanceOfPattern(
+                          inclusive = false,
+                          classNamePattern = @ClassNamePattern(constant = Top2.class)))
+            }))
+    public static Method[] reflectiveMethods() throws NoSuchMethodException {
+      return new Method[] {
+        A.class.getDeclaredMethod(getFoo(), Sub1.class, Sub2.class),
+        A.class.getDeclaredMethod(getFoo(), Sub1.class, Subb2.class),
+        A.class.getDeclaredMethod(getFoo(), Subb1.class, Subb2.class),
+        A.class.getDeclaredMethod(getFoo(), Subb1.class, Sub2.class)
+      };
+    }
+
+    private static String getFoo() {
+      return System.currentTimeMillis() > 0 ? "foo" : "bar";
+    }
+  }
+}
diff --git a/src/test/testbase/java/com/android/tools/r8/keepanno/KeepAnnoTestUtils.java b/src/test/testbase/java/com/android/tools/r8/keepanno/KeepAnnoTestUtils.java
index 2e21529..93384e8 100644
--- a/src/test/testbase/java/com/android/tools/r8/keepanno/KeepAnnoTestUtils.java
+++ b/src/test/testbase/java/com/android/tools/r8/keepanno/KeepAnnoTestUtils.java
@@ -61,6 +61,21 @@
     return archive;
   }
 
+  public static List<String> extractRulesFromFiles(
+      List<Path> inputFiles, KeepRuleExtractorOptions extractorOptions) {
+    return extractRulesFromBytes(
+        ListUtils.map(
+            inputFiles,
+            path -> {
+              try {
+                return Files.readAllBytes(path);
+              } catch (IOException e) {
+                throw new RuntimeException(e);
+              }
+            }),
+        extractorOptions);
+  }
+
   public static List<String> extractRules(
       List<Class<?>> inputClasses, KeepRuleExtractorOptions extractorOptions) {
     return extractRulesFromBytes(