[KeepAnno] Interpret unbound consequences as independent targets

The consequences should be treated as a disjunction unless they are part
of an explicit binding (that is used). This is consistent with the
conservative rule extraction.

Fixes: b/336270931
Bug: b/323816623
Change-Id: I1e23c8ff2732fc01ea6349d8d2e151cc6be03f1b
diff --git a/src/main/java/com/android/tools/r8/shaking/rules/KeepAnnotationMatcher.java b/src/main/java/com/android/tools/r8/shaking/rules/KeepAnnotationMatcher.java
index fa7fb7d..4df970e 100644
--- a/src/main/java/com/android/tools/r8/shaking/rules/KeepAnnotationMatcher.java
+++ b/src/main/java/com/android/tools/r8/shaking/rules/KeepAnnotationMatcher.java
@@ -41,6 +41,7 @@
 import com.android.tools.r8.shaking.KeepInfo.Joiner;
 import com.android.tools.r8.shaking.MinimumKeepInfoCollection;
 import com.android.tools.r8.threading.ThreadingModule;
+import com.android.tools.r8.utils.BooleanBox;
 import com.android.tools.r8.utils.ListUtils;
 import com.android.tools.r8.utils.ThreadUtils;
 import com.google.common.collect.ImmutableList;
@@ -302,7 +303,11 @@
     }
 
     private void foundMatch() {
-      callback.accept(assignment.createMatch(schema));
+      MatchResult match = assignment.createMatch(schema);
+      // We might not have found any matching consequences and this is not an actual match.
+      if (!match.consequences.isEmpty()) {
+        callback.accept(match);
+      }
     }
 
     private void findMatchingClass(int classIndex) {
@@ -319,20 +324,20 @@
                 .createType(classPattern.getClassNamePattern().getExactDescriptor());
         DexProgramClass clazz = DexProgramClass.asProgramClassOrNull(appInfo.definitionFor(type));
         if (clazz == null) {
-          // No valid match, so the rule is discarded. This should likely be a diagnostics info.
+          continueWithNoClass(classIndex);
           return;
         }
         if (!predicates.matchesClass(clazz, classPattern, appInfo)) {
-          // Invalid match for this class.
+          continueWithNoClass(classIndex);
           return;
         }
         continueWithClass(classIndex, clazz);
-      } else {
-        // TODO(b/323816623): This repeated iteration on all classes must be avoided.
-        for (DexProgramClass clazz : appInfo.classes()) {
-          if (predicates.matchesClass(clazz, classPattern, appInfo)) {
-            continueWithClass(classIndex, clazz);
-          }
+        return;
+      }
+      // TODO(b/323816623): This repeated iteration on all classes must be avoided.
+      for (DexProgramClass clazz : appInfo.classes()) {
+        if (predicates.matchesClass(clazz, classPattern, appInfo)) {
+          continueWithClass(classIndex, clazz);
         }
       }
     }
@@ -343,6 +348,13 @@
       findMatchingMember(0, classMemberIndexList, clazz, classIndex + 1);
     }
 
+    private void continueWithNoClass(int classIndex) {
+      if (schema.isOptionalClass(classIndex)) {
+        assignment.setClass(classIndex, null);
+        findMatchingClass(classIndex + 1);
+      }
+    }
+
     private void findMatchingMember(
         int memberInHolderIndex,
         IntList memberIndexTranslation,
@@ -355,10 +367,13 @@
       }
       int memberIndex = memberIndexTranslation.getInt(memberInHolderIndex);
       KeepMemberItemPattern memberItemPattern = schema.members.get(memberIndex);
+      BooleanBox didContinue = new BooleanBox(false);
       Consumer<ProgramDefinition> continueWithMember =
-          m ->
-              continueWithMember(
-                  m, memberIndex, memberInHolderIndex + 1, memberIndexTranslation, nextClassIndex);
+          m -> {
+            didContinue.isTrue();
+            continueWithMember(
+                m, memberIndex, memberInHolderIndex + 1, memberIndexTranslation, nextClassIndex);
+          };
       memberItemPattern
           .getMemberPattern()
           .match(
@@ -382,6 +397,11 @@
               methodPattern ->
                   holder.forEachProgramMethodMatching(
                       m -> predicates.matchesMethod(m, methodPattern), continueWithMember));
+      if (didContinue.isFalse()) {
+        // No match for the member pattern existed, continue with empty member.
+        continueWithNoMember(
+            memberIndex, memberInHolderIndex + 1, memberIndexTranslation, holder, nextClassIndex);
+      }
     }
 
     private void continueWithMember(
@@ -403,6 +423,18 @@
           definition.getContextClass(),
           nextClassIndex);
     }
+
+    private void continueWithNoMember(
+        int memberIndex,
+        int nextMemberInHolderIndex,
+        IntList memberIndexTranslation,
+        DexProgramClass holder,
+        int nextClassIndex) {
+      if (schema.isOptionalMember(memberIndex, nextClassIndex - 1)) {
+        assignment.setMember(memberIndex, null);
+        findMatchingMember(nextMemberInHolderIndex, memberIndexTranslation, holder, nextClassIndex);
+      }
+    }
   }
 
   /**
@@ -418,18 +450,25 @@
     final List<KeepClassItemPattern> classes = new ArrayList<>();
     final List<KeepMemberItemPattern> members = new ArrayList<>();
     final List<IntList> classMembers = new ArrayList<>();
+    final IntList boundClasses = new IntArrayList();
     final IntList preconditions = new IntArrayList();
     final IntList consequences = new IntArrayList();
     final List<KeepConstraints> constraints = new ArrayList<>();
+    int preconditionClassesCount = -1;
+    int preconditionMembersCount = -1;
 
     public NormalizedSchema(KeepDeclaration declaration) {
       this.declaration = declaration;
       declaration.match(
           edge -> {
             edge.getPreconditions().forEach(this::addPrecondition);
+            preconditionClassesCount = classes.size();
+            preconditionMembersCount = members.size();
             edge.getConsequences().forEachTarget(this::addConsequence);
           },
           check -> {
+            preconditionClassesCount = 0;
+            preconditionMembersCount = 0;
             consequences.add(defineItemPattern(check.getItemPattern()));
           });
     }
@@ -439,6 +478,14 @@
       return declaration.asKeepEdge().getBindings().get(symbol).getItem();
     }
 
+    public boolean isOptionalClass(int classIndex) {
+      return classIndex >= preconditionClassesCount && !boundClasses.contains(classIndex);
+    }
+
+    public boolean isOptionalMember(int memberIndex, int classIndex) {
+      return memberIndex >= preconditionMembersCount && isOptionalClass(classIndex);
+    }
+
     public static boolean isClassKeyReference(int keyRef) {
       return keyRef >= 0;
     }
@@ -470,7 +517,14 @@
 
     private int defineBindingReference(KeepBindingReference reference) {
       return symbolToKey.computeIfAbsent(
-          reference.getName(), symbol -> defineItemPattern(getItemForBinding(symbol)));
+          reference.getName(),
+          symbol -> {
+            int bindingId = defineItemPattern(getItemForBinding(symbol));
+            if (isClassKeyReference(bindingId)) {
+              boundClasses.add(bindingId);
+            }
+            return bindingId;
+          });
     }
 
     private int defineItemPattern(KeepItemPattern item) {
@@ -539,17 +593,17 @@
           schema.declaration,
           schema.preconditions.isEmpty()
               ? Collections.emptyList()
-              : getItemList(schema.preconditions),
-          getItemList(schema.consequences),
+              : getItemList(schema.preconditions, false),
+          getItemList(schema.consequences, true),
           getConstraints(schema));
     }
 
-    private List<ProgramDefinition> getItemList(IntList indexReferences) {
+    private List<ProgramDefinition> getItemList(IntList indexReferences, boolean allowUnset) {
       assert !indexReferences.isEmpty();
       List<ProgramDefinition> definitions = new ArrayList<>(indexReferences.size());
       for (int i = 0; i < indexReferences.size(); i++) {
         ProgramDefinition item = getItemForReference(indexReferences.getInt(i));
-        assert item != null || hasEmptyMembers;
+        assert item != null || hasEmptyMembers || allowUnset;
         if (item != null) {
           definitions.add(item);
         }
diff --git a/src/test/java/com/android/tools/r8/keepanno/KeepConjunctiveBindingsTest.java b/src/test/java/com/android/tools/r8/keepanno/KeepConjunctiveBindingsTest.java
new file mode 100644
index 0000000..7c230d0
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/keepanno/KeepConjunctiveBindingsTest.java
@@ -0,0 +1,108 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.keepanno;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isAbsent;
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import com.android.tools.r8.SingleTestRunResult;
+import com.android.tools.r8.keepanno.annotations.KeepBinding;
+import com.android.tools.r8.keepanno.annotations.KeepEdge;
+import com.android.tools.r8.keepanno.annotations.KeepTarget;
+import com.android.tools.r8.utils.AndroidApiLevel;
+import com.android.tools.r8.utils.StringUtils;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.google.common.collect.ImmutableList;
+import java.lang.reflect.Field;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+
+@RunWith(Parameterized.class)
+public class KeepConjunctiveBindingsTest extends KeepAnnoTestBase {
+
+  static final String EXPECTED = StringUtils.lines("Hello, world");
+
+  @Parameter public KeepAnnoParameters parameters;
+
+  @Parameterized.Parameters(name = "{0}")
+  public static List<KeepAnnoParameters> data() {
+    return createParameters(
+        getTestParameters().withDefaultRuntimes().withApiLevel(AndroidApiLevel.B).build());
+  }
+
+  @Test
+  public void test() throws Exception {
+    SingleTestRunResult<?> result =
+        testForKeepAnno(parameters)
+            .addProgramClasses(getInputClasses())
+            .addKeepMainRule(TestClass.class)
+            .setExcludedOuterClass(getClass())
+            .allowUnusedProguardConfigurationRules()
+            .run(TestClass.class);
+    if (parameters.isReference()) {
+      result.assertSuccessWithOutput(EXPECTED);
+    } else if (parameters.isPG()) {
+      // PG will make the field private and result in access error.
+      result.assertFailureWithErrorThatThrows(IllegalAccessException.class);
+    } else {
+      // R8 will remove the field.
+      result.assertSuccessWithOutput("").inspect(this::checkOutput);
+    }
+  }
+
+  public List<Class<?>> getInputClasses() {
+    return ImmutableList.of(TestClass.class, A.class);
+  }
+
+  private void checkOutput(CodeInspector inspector) {
+    assertThat(inspector.clazz(A.class), isPresent());
+    assertThat(inspector.clazz(A.class).uniqueFieldWithOriginalName("fieldA"), isAbsent());
+    assertThat(inspector.clazz(A.class).uniqueFieldWithOriginalName("fieldB"), isAbsent());
+  }
+
+  static class A {
+
+    public String fieldA = "Hello, world";
+    public Integer fieldB = 42;
+
+    @KeepEdge(
+        bindings = {
+          @KeepBinding(bindingName = "A", classConstant = A.class),
+          @KeepBinding(
+              bindingName = "StringFields",
+              classFromBinding = "A",
+              fieldTypeConstant = String.class),
+          @KeepBinding(
+              bindingName = "NonMatchingFields",
+              classFromBinding = "A",
+              fieldType = "some.NonExistingClass"),
+        },
+        consequences = {
+          // The bindings on A defines the required structure of A on input, thus the binding will
+          // fail to find the required match when used. (This mirrors -keepclasseswithmembers).
+          // Contrast this with the test in KeepDisjunctiveConsequencesTest.
+          @KeepTarget(classFromBinding = "A"),
+          @KeepTarget(memberFromBinding = "StringFields"),
+          @KeepTarget(memberFromBinding = "NonMatchingFields")
+        })
+    public void foo() throws Exception {
+      for (Field field : getClass().getDeclaredFields()) {
+        if (field.getType().equals(String.class)) {
+          System.out.println(field.get(this));
+        }
+      }
+    }
+  }
+
+  static class TestClass {
+
+    public static void main(String[] args) throws Exception {
+      new A().foo();
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/keepanno/KeepDisjunctiveConsequencesTest.java b/src/test/java/com/android/tools/r8/keepanno/KeepDisjunctiveConsequencesTest.java
new file mode 100644
index 0000000..cf0b2a2
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/keepanno/KeepDisjunctiveConsequencesTest.java
@@ -0,0 +1,93 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.keepanno;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isAbsent;
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import com.android.tools.r8.keepanno.annotations.KeepConstraint;
+import com.android.tools.r8.keepanno.annotations.KeepTarget;
+import com.android.tools.r8.keepanno.annotations.UsesReflection;
+import com.android.tools.r8.utils.AndroidApiLevel;
+import com.android.tools.r8.utils.StringUtils;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.google.common.collect.ImmutableList;
+import java.lang.reflect.Field;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+
+@RunWith(Parameterized.class)
+public class KeepDisjunctiveConsequencesTest extends KeepAnnoTestBase {
+
+  static final String EXPECTED = StringUtils.lines("Hello, world");
+
+  @Parameter public KeepAnnoParameters parameters;
+
+  @Parameterized.Parameters(name = "{0}")
+  public static List<KeepAnnoParameters> data() {
+    return createParameters(
+        getTestParameters().withDefaultRuntimes().withApiLevel(AndroidApiLevel.B).build());
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForKeepAnno(parameters)
+        .addProgramClasses(getInputClasses())
+        .addKeepMainRule(TestClass.class)
+        .setExcludedOuterClass(getClass())
+        .run(TestClass.class)
+        .assertSuccessWithOutput(EXPECTED)
+        .applyIf(parameters.isShrinker(), r -> r.inspect(this::checkOutput));
+  }
+
+  public List<Class<?>> getInputClasses() {
+    return ImmutableList.of(TestClass.class, A.class);
+  }
+
+  private void checkOutput(CodeInspector inspector) {
+    assertThat(inspector.clazz(A.class), isPresent());
+    assertThat(inspector.clazz(A.class).uniqueFieldWithOriginalName("fieldA"), isPresent());
+    assertThat(
+        inspector.clazz(A.class).uniqueFieldWithOriginalName("fieldB"),
+        parameters.isPG() ? isPresent() : isAbsent());
+  }
+
+  static class A {
+
+    public String fieldA = "Hello, world";
+    public Integer fieldB = 42;
+
+    @UsesReflection({
+      @KeepTarget(
+          classConstant = A.class,
+          fieldTypeConstant = String.class,
+          constraints = {KeepConstraint.LOOKUP, KeepConstraint.FIELD_GET}),
+      // This target does not match anything, but that should not cause the above target to be
+      // ignored.
+      // Contrast this with the test in KeepConjunctiveBindingsTest.
+      @KeepTarget(
+          classConstant = A.class,
+          fieldType = "some.NonExistentClass",
+          constraints = {KeepConstraint.LOOKUP, KeepConstraint.FIELD_GET})
+    })
+    public void foo() throws Exception {
+      for (Field field : getClass().getDeclaredFields()) {
+        if (field.getType().equals(String.class)) {
+          System.out.println(field.get(this));
+        }
+      }
+    }
+  }
+
+  static class TestClass {
+
+    public static void main(String[] args) throws Exception {
+      new A().foo();
+    }
+  }
+}