[TraceReferences] Add tracing of enum annotation values

Bug: b/376016627
Change-Id: I8d47b290a649b40e76117a5afe2dfd6e477798d5
diff --git a/src/main/java/com/android/tools/r8/jar/CfApplicationWriter.java b/src/main/java/com/android/tools/r8/jar/CfApplicationWriter.java
index 6ccbf71..a90ac15 100644
--- a/src/main/java/com/android/tools/r8/jar/CfApplicationWriter.java
+++ b/src/main/java/com/android/tools/r8/jar/CfApplicationWriter.java
@@ -625,7 +625,11 @@
         DexField enumField = value.asDexValueEnum().getValue();
         // This must not be renamed, as the Java runtime will use Enum.valueOf to find the enum's
         // referenced in annotations. See b/236691999 for details.
-        assert getNamingLens().lookupName(enumField) == enumField.name;
+        assert getNamingLens().lookupName(enumField) == enumField.name
+            : "Enum field "
+                + enumField.name
+                + " renamed to "
+                + getNamingLens().lookupName(enumField);
         visitor.visitEnum(
             name,
             getNamingLens().lookupDescriptor(enumField.getType()).toString(),
diff --git a/src/main/java/com/android/tools/r8/tracereferences/KeepRuleFormatter.java b/src/main/java/com/android/tools/r8/tracereferences/KeepRuleFormatter.java
index e36b7b1..78580d9 100644
--- a/src/main/java/com/android/tools/r8/tracereferences/KeepRuleFormatter.java
+++ b/src/main/java/com/android/tools/r8/tracereferences/KeepRuleFormatter.java
@@ -23,7 +23,12 @@
       appendLine("# Missing class: " + tracedClass.getReference().getTypeName());
       return;
     }
-    append(allowObfuscation ? "-keep,allowobfuscation" : "-keep");
+    // Don't obfuscate enums as the Java runtime will use Enum.valueOf to find enum's referenced in
+    // annotations, see b/236691999.
+    append(
+        allowObfuscation && !tracedClass.getAccessFlags().isEnum()
+            ? "-keep,allowobfuscation"
+            : "-keep");
     if (tracedClass.getAccessFlags().isInterface()) {
       appendLine(
           " "
diff --git a/src/main/java/com/android/tools/r8/tracereferences/Tracer.java b/src/main/java/com/android/tools/r8/tracereferences/Tracer.java
index ed21df7..bb223ea 100644
--- a/src/main/java/com/android/tools/r8/tracereferences/Tracer.java
+++ b/src/main/java/com/android/tools/r8/tracereferences/Tracer.java
@@ -82,7 +82,8 @@
       // This iterates all annotations on the class, including on its methods and fields.
       clazz
           .annotations()
-          .forEach(dexAnnotation -> useCollector.registerAnnotation(dexAnnotation, classContext));
+          .forEach(
+              dexAnnotation -> useCollector.registerAnnotation(dexAnnotation, clazz, classContext));
     }
     consumer.finished(diagnostics);
   }
@@ -184,7 +185,7 @@
     private void handleMemberResolution(
         DexMember<?, ?> reference,
         DexClassAndMember<?, ?> member,
-        ProgramMethod context,
+        DexProgramClass context,
         DefinitionContext referencedFrom) {
       DexClass holder = member.getHolder();
       assert isTargetType(holder.getType());
@@ -196,10 +197,10 @@
     }
 
     private void ensurePackageAccessToMember(
-        DexClassAndMember<?, ?> member, ProgramMethod context) {
+        DexClassAndMember<?, ?> member, DexProgramClass context) {
       if (member.getAccessFlags().isPackagePrivateOrProtected()) {
         if (member.getAccessFlags().isPackagePrivate()
-            || !appInfo().isSubtype(context.getHolder(), member.getHolder())) {
+            || !appInfo().isSubtype(context, member.getHolder())) {
           consumer.acceptPackage(
               Reference.packageFromString(member.getHolderType().getPackageName()), diagnostics);
         }
@@ -217,7 +218,7 @@
       //   class.
       TracedMethodImpl tracedMethod = new TracedMethodImpl(method.getDefinition(), referencedFrom);
       consumer.acceptMethod(tracedMethod, diagnostics);
-      ensurePackageAccessToMember(method, context);
+      ensurePackageAccessToMember(method, context.getHolder());
     }
 
     private <R, T extends TracedReference<R, ?>> void collectMissing(
@@ -247,7 +248,9 @@
       addType(field.getType(), referencedFrom);
       field
           .getAnnotations()
-          .forEach(dexAnnotation -> registerAnnotation(dexAnnotation, referencedFrom));
+          .forEach(
+              dexAnnotation ->
+                  registerAnnotation(dexAnnotation, field.getHolder(), referencedFrom));
     }
 
     private void registerMethod(ProgramMethod method) {
@@ -256,10 +259,14 @@
       addType(method.getReturnType(), referencedFrom);
       method
           .getAnnotations()
-          .forEach(dexAnnotation -> registerAnnotation(dexAnnotation, referencedFrom));
+          .forEach(
+              dexAnnotation ->
+                  registerAnnotation(dexAnnotation, method.getHolder(), referencedFrom));
       method
           .getParameterAnnotations()
-          .forEachAnnotation(dexAnnotation -> registerAnnotation(dexAnnotation, referencedFrom));
+          .forEachAnnotation(
+              dexAnnotation ->
+                  registerAnnotation(dexAnnotation, method.getHolder(), referencedFrom));
     }
 
     private void traceCode(ProgramMethod method) {
@@ -287,7 +294,8 @@
           });
     }
 
-    private void registerAnnotation(DexAnnotation annotation, DefinitionContext referencedFrom) {
+    private void registerAnnotation(
+        DexAnnotation annotation, DexProgramClass context, DefinitionContext referencedFrom) {
       DexType type = annotation.getAnnotationType();
       assert type.isClassType();
       if (type.isIdenticalTo(factory.annotationMethodParameters)
@@ -317,13 +325,16 @@
                 element -> {
                   assert element.getValue().isDexValueAnnotation();
                   registerEncodedAnnotation(
-                      element.getValue().asDexValueAnnotation().getValue(), referencedFrom);
+                      element.getValue().asDexValueAnnotation().getValue(),
+                      context,
+                      referencedFrom);
                 });
         return;
       }
       if (type.isIdenticalTo(factory.annotationThrows)) {
         assert referencedFrom.isMethodContext();
-        registerDexValue(annotation.annotation.elements[0].value.asDexValueArray(), referencedFrom);
+        registerDexValue(
+            annotation.annotation.elements[0].value.asDexValueArray(), context, referencedFrom);
         return;
       }
       assert !type.getDescriptor().startsWith(factory.dalvikAnnotationPrefix)
@@ -331,11 +342,13 @@
               + factory.dalvikAnnotationPrefix
               + ": "
               + type.getDescriptor();
-      registerEncodedAnnotation(annotation.getAnnotation(), referencedFrom);
+      registerEncodedAnnotation(annotation.getAnnotation(), context, referencedFrom);
     }
 
     void registerEncodedAnnotation(
-        DexEncodedAnnotation annotation, DefinitionContext referencedFrom) {
+        DexEncodedAnnotation annotation,
+        DexProgramClass context,
+        DefinitionContext referencedFrom) {
       addClassType(
           annotation.getType(),
           referencedFrom,
@@ -351,21 +364,47 @@
                     }
                   }
                   // Handle the argument values passed to the annotation "method".
-                  registerDexValue(element.getValue(), referencedFrom);
+                  registerDexValue(element.getValue(), context, referencedFrom);
                 });
           });
     }
 
-    private void registerDexValue(DexValue value, DefinitionContext referencedFrom) {
+    private void registerDexValue(
+        DexValue value, DexProgramClass context, DefinitionContext referencedFrom) {
       if (value.isDexValueType()) {
         addType(value.asDexValueType().getValue(), referencedFrom);
+      } else if (value.isDexValueEnum()) {
+        DexField field = value.asDexValueEnum().value;
+        handleRewrittenFieldReference(field, context, referencedFrom);
       } else if (value.isDexValueArray()) {
         for (DexValue elementValue : value.asDexValueArray().getValues()) {
-          registerDexValue(elementValue, referencedFrom);
+          registerDexValue(elementValue, context, referencedFrom);
         }
       }
     }
 
+    private void handleRewrittenFieldReference(
+        DexField field, DexProgramClass context, DefinitionContext referencedFrom) {
+      addType(field.getHolderType(), referencedFrom);
+      addType(field.getType(), referencedFrom);
+      FieldResolutionResult resolutionResult = appInfo().resolveField(field);
+      if (resolutionResult.hasSuccessfulResolutionResult()) {
+        resolutionResult.forEachSuccessfulFieldResolutionResult(
+            singleResolutionResult -> {
+              DexClassAndField resolvedField = singleResolutionResult.getResolutionPair();
+              if (isTargetType(resolvedField.getHolderType())) {
+                handleMemberResolution(field, resolvedField, context, referencedFrom);
+                TracedFieldImpl tracedField = new TracedFieldImpl(resolvedField, referencedFrom);
+                consumer.acceptField(tracedField, diagnostics);
+              }
+            });
+      } else {
+        TracedFieldImpl tracedField = new TracedFieldImpl(field, referencedFrom);
+        collectMissingField(tracedField);
+        consumer.acceptField(tracedField, diagnostics);
+      }
+    }
+
     class MethodUseCollector extends UseRegistry<ProgramMethod> {
 
       private final DefinitionContext referencedFrom;
@@ -498,7 +537,8 @@
           assert resolvedMethod.getReference().match(method)
               || resolvedMethod.getHolder().isSignaturePolymorphicMethod(definition, factory);
           if (isTargetType(resolvedMethod.getHolderType())) {
-            handleMemberResolution(method, resolvedMethod, getContext(), referencedFrom);
+            handleMemberResolution(
+                method, resolvedMethod, getContext().getHolder(), referencedFrom);
             TracedMethodImpl tracedMethod = new TracedMethodImpl(definition, referencedFrom);
             consumer.acceptMethod(tracedMethod, diagnostics);
           }
@@ -515,7 +555,7 @@
       public void registerInitClass(DexType clazz) {
         DexType rewrittenClass = graphLens().lookupType(clazz);
         DexField clinitField = appView.initClassLens().getInitClassField(rewrittenClass);
-        handleRewrittenFieldReference(clinitField);
+        handleRewrittenFieldReference(clinitField, getContext().getHolder(), referencedFrom);
       }
 
       @Override
@@ -540,29 +580,8 @@
 
       private void handleFieldAccess(DexField field) {
         FieldLookupResult lookupResult = graphLens().lookupFieldResult(field);
-        handleRewrittenFieldReference(lookupResult.getReference());
-      }
-
-      private void handleRewrittenFieldReference(DexField field) {
-        addType(field.getHolderType(), referencedFrom);
-        addType(field.getType(), referencedFrom);
-
-        FieldResolutionResult resolutionResult = appInfo().resolveField(field);
-        if (resolutionResult.hasSuccessfulResolutionResult()) {
-          resolutionResult.forEachSuccessfulFieldResolutionResult(
-              singleResolutionResult -> {
-                DexClassAndField resolvedField = singleResolutionResult.getResolutionPair();
-                if (isTargetType(resolvedField.getHolderType())) {
-                  handleMemberResolution(field, resolvedField, getContext(), referencedFrom);
-                  TracedFieldImpl tracedField = new TracedFieldImpl(resolvedField, referencedFrom);
-                  consumer.acceptField(tracedField, diagnostics);
-                }
-              });
-        } else {
-          TracedFieldImpl tracedField = new TracedFieldImpl(field, referencedFrom);
-          collectMissingField(tracedField);
-          consumer.acceptField(tracedField, diagnostics);
-        }
+        handleRewrittenFieldReference(
+            lookupResult.getReference(), getContext().getHolder(), referencedFrom);
       }
 
       // Type references.
diff --git a/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesAnnotationValuesReferencesInDexTest.java b/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesAnnotationClassConstantValuesReferencesInDexTest.java
similarity index 69%
rename from src/test/java/com/android/tools/r8/tracereferences/TraceReferencesAnnotationValuesReferencesInDexTest.java
rename to src/test/java/com/android/tools/r8/tracereferences/TraceReferencesAnnotationClassConstantValuesReferencesInDexTest.java
index b19958b..e56d1d3 100644
--- a/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesAnnotationValuesReferencesInDexTest.java
+++ b/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesAnnotationClassConstantValuesReferencesInDexTest.java
@@ -5,7 +5,6 @@
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.fail;
 
 import com.android.tools.r8.DiagnosticsHandler;
 import com.android.tools.r8.TestBase;
@@ -13,18 +12,19 @@
 import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.ToolHelper;
 import com.android.tools.r8.references.ClassReference;
+import com.android.tools.r8.references.FieldReference;
 import com.android.tools.r8.references.MethodReference;
 import com.android.tools.r8.references.Reference;
 import com.android.tools.r8.utils.AndroidApiLevel;
 import com.android.tools.r8.utils.StringUtils;
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableSet;
 import java.lang.annotation.Retention;
 import java.lang.annotation.RetentionPolicy;
 import java.nio.file.Path;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
+import java.util.stream.Collectors;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -32,7 +32,7 @@
 import org.junit.runners.Parameterized.Parameters;
 
 @RunWith(Parameterized.class)
-public class TraceReferencesAnnotationValuesReferencesInDexTest extends TestBase {
+public class TraceReferencesAnnotationClassConstantValuesReferencesInDexTest extends TestBase {
 
   @Parameters(name = "{0}")
   public static TestParametersCollection data() {
@@ -42,16 +42,35 @@
   @Parameter(0)
   public TestParameters parameters;
 
-  private static List<Class<?>> SOURCE_CLASSES =
+  private static final List<Class<?>> SOURCE_CLASSES =
       ImmutableList.of(
           Source.class,
           SourceAnnotationWithClassConstant.class,
-          SourceAnnotationWithClassConstantArray.class);
+          SourceAnnotationWithClassConstantArray.class,
+          SourceAnnotationWithEnum.class,
+          SourceAnnotationWithEnumArray.class);
+
+  private static final List<Class<?>> TARGET_CLASSES =
+      ImmutableList.of(
+          TargetAnnotationWithInt.class,
+          TargetAnnotationWithLongArray.class,
+          TargetAnnotationWithClassConstant.class,
+          TargetAnnotationWithClassConstantArray.class,
+          TargetAnnotationWithEnum.class,
+          TargetAnnotationWithEnumArray.class,
+          A.class,
+          B.class,
+          C.class,
+          D.class,
+          E.class,
+          F.class,
+          G.class);
 
   static class Consumer implements TraceReferencesConsumer {
 
     Set<ClassReference> tracedTypes = new HashSet<>();
     Set<MethodReference> tracedMethods = new HashSet<>();
+    Set<FieldReference> tracedFields = new HashSet<>();
 
     @Override
     public void acceptType(TracedClass tracedClass, DiagnosticsHandler handler) {
@@ -61,7 +80,8 @@
 
     @Override
     public void acceptField(TracedField tracedField, DiagnosticsHandler handler) {
-      fail();
+      assertFalse(tracedField.isMissingDefinition());
+      tracedFields.add(tracedField.getReference());
     }
 
     @Override
@@ -75,17 +95,7 @@
     testForTraceReferences()
         .addLibraryFiles(ToolHelper.getAndroidJar(AndroidApiLevel.LATEST))
         .addSourceFiles(sourceDex)
-        .addTargetClasses(
-            TargetAnnotationWithInt.class,
-            TargetAnnotationWithLongArray.class,
-            TargetAnnotationWithClassConstant.class,
-            TargetAnnotationWithClassConstantArray.class,
-            A.class,
-            B.class,
-            C.class,
-            D.class,
-            E.class,
-            F.class)
+        .addTargetClasses(TARGET_CLASSES)
         .setConsumer(consumer)
         .trace();
   }
@@ -94,17 +104,7 @@
     Consumer consumer = new Consumer();
     runTest(sourceDex, consumer);
     assertEquals(
-        ImmutableSet.of(
-            Reference.classFromClass(TargetAnnotationWithInt.class),
-            Reference.classFromClass(TargetAnnotationWithLongArray.class),
-            Reference.classFromClass(TargetAnnotationWithClassConstant.class),
-            Reference.classFromClass(TargetAnnotationWithClassConstantArray.class),
-            Reference.classFromClass(A.class),
-            Reference.classFromClass(B.class),
-            Reference.classFromClass(C.class),
-            Reference.classFromClass(D.class),
-            Reference.classFromClass(E.class),
-            Reference.classFromClass(F.class)),
+        TARGET_CLASSES.stream().map(Reference::classFromClass).collect(Collectors.toSet()),
         consumer.tracedTypes);
   }
 
@@ -129,12 +129,26 @@
             "}",
             "-keep class " + F.class.getTypeName() + " {",
             "}",
+            "-keep enum " + G.class.getTypeName() + " {",
+            "  " + G.class.getTypeName() + " FIVE;",
+            "  " + G.class.getTypeName() + " FOUR;",
+            "  " + G.class.getTypeName() + " ONE;",
+            "  " + G.class.getTypeName() + " SIX;",
+            "  " + G.class.getTypeName() + " THREE;",
+            "  " + G.class.getTypeName() + " TWO;",
+            "}",
             "-keep @interface " + TargetAnnotationWithClassConstant.class.getTypeName() + " {",
             "  public java.lang.Class value();",
             "}",
             "-keep @interface " + TargetAnnotationWithClassConstantArray.class.getTypeName() + " {",
             "  public java.lang.Class[] value();",
             "}",
+            "-keep @interface " + TargetAnnotationWithEnum.class.getTypeName() + " {",
+            "  public " + G.class.getTypeName() + " value();",
+            "}",
+            "-keep @interface " + TargetAnnotationWithEnumArray.class.getTypeName() + " {",
+            "  public " + G.class.getTypeName() + "[] value();",
+            "}",
             "-keep @interface " + TargetAnnotationWithInt.class.getTypeName() + " {",
             "  public int value();",
             "}",
@@ -175,6 +189,15 @@
 
   public class F {}
 
+  public enum G {
+    ONE,
+    TWO,
+    THREE,
+    FOUR,
+    FIVE,
+    SIX
+  }
+
   @Retention(RetentionPolicy.RUNTIME)
   public @interface TargetAnnotationWithInt {
     int value() default 0;
@@ -196,6 +219,16 @@
   }
 
   @Retention(RetentionPolicy.RUNTIME)
+  public @interface TargetAnnotationWithEnum {
+    G value();
+  }
+
+  @Retention(RetentionPolicy.RUNTIME)
+  public @interface TargetAnnotationWithEnumArray {
+    G[] value();
+  }
+
+  @Retention(RetentionPolicy.RUNTIME)
   public @interface SourceAnnotationWithClassConstant {
     Class<?> value() default D.class;
   }
@@ -205,11 +238,25 @@
     Class<?>[] value() default {E.class, F.class};
   }
 
+  @Retention(RetentionPolicy.RUNTIME)
+  public @interface SourceAnnotationWithEnum {
+    G value() default G.FOUR;
+  }
+
+  @Retention(RetentionPolicy.RUNTIME)
+  public @interface SourceAnnotationWithEnumArray {
+    G[] value() default {G.FIVE, G.SIX};
+  }
+
   @TargetAnnotationWithInt(1)
   @TargetAnnotationWithLongArray({2L, 3L})
   @TargetAnnotationWithClassConstant(A.class)
   @TargetAnnotationWithClassConstantArray({B.class, C.class})
+  @TargetAnnotationWithEnum(G.ONE)
+  @TargetAnnotationWithEnumArray({G.TWO, G.THREE})
   @SourceAnnotationWithClassConstant
   @SourceAnnotationWithClassConstantArray
+  @SourceAnnotationWithEnum
+  @SourceAnnotationWithEnumArray
   static class Source {}
 }