[TraceReferences] Refactor handling of dalvik.annotation.Throws

Bug: b/376016627
Change-Id: I98f600620f12e5fa5de7db07e5954baa87a4a2b8
diff --git a/src/main/java/com/android/tools/r8/tracereferences/Tracer.java b/src/main/java/com/android/tools/r8/tracereferences/Tracer.java
index 8c300e3..ed21df7 100644
--- a/src/main/java/com/android/tools/r8/tracereferences/Tracer.java
+++ b/src/main/java/com/android/tools/r8/tracereferences/Tracer.java
@@ -27,7 +27,6 @@
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.DexTypeList;
 import com.android.tools.r8.graph.DexValue;
-import com.android.tools.r8.graph.DexValue.DexValueArray;
 import com.android.tools.r8.graph.FieldResolutionResult;
 import com.android.tools.r8.graph.MethodResolutionResult;
 import com.android.tools.r8.graph.MethodResolutionResult.SingleResolutionResult;
@@ -81,8 +80,9 @@
             useCollector.traceCode(method);
           });
       // This iterates all annotations on the class, including on its methods and fields.
-      clazz.forEachAnnotation(
-          dexAnnotation -> useCollector.registerAnnotation(dexAnnotation, classContext));
+      clazz
+          .annotations()
+          .forEach(dexAnnotation -> useCollector.registerAnnotation(dexAnnotation, classContext));
     }
     consumer.finished(diagnostics);
   }
@@ -245,20 +245,21 @@
     private void registerField(ProgramField field) {
       DefinitionContext referencedFrom = DefinitionContextUtils.create(field);
       addType(field.getType(), referencedFrom);
+      field
+          .getAnnotations()
+          .forEach(dexAnnotation -> registerAnnotation(dexAnnotation, referencedFrom));
     }
 
     private void registerMethod(ProgramMethod method) {
       DefinitionContext referencedFrom = DefinitionContextUtils.create(method);
       addTypes(method.getParameters(), referencedFrom);
       addType(method.getReturnType(), referencedFrom);
-      for (DexAnnotation annotation : method.getAnnotations().getAnnotations()) {
-        if (annotation.getAnnotationType().isIdenticalTo(factory.annotationThrows)) {
-          DexValueArray dexValues = annotation.annotation.elements[0].value.asDexValueArray();
-          for (DexValue dexValType : dexValues.getValues()) {
-            addType(dexValType.asDexValueType().value, referencedFrom);
-          }
-        }
-      }
+      method
+          .getAnnotations()
+          .forEach(dexAnnotation -> registerAnnotation(dexAnnotation, referencedFrom));
+      method
+          .getParameterAnnotations()
+          .forEachAnnotation(dexAnnotation -> registerAnnotation(dexAnnotation, referencedFrom));
     }
 
     private void traceCode(ProgramMethod method) {
@@ -289,8 +290,7 @@
     private void registerAnnotation(DexAnnotation annotation, DefinitionContext referencedFrom) {
       DexType type = annotation.getAnnotationType();
       assert type.isClassType();
-      if (type.isIdenticalTo(factory.annotationThrows)
-          || type.isIdenticalTo(factory.annotationMethodParameters)
+      if (type.isIdenticalTo(factory.annotationMethodParameters)
           || type.isIdenticalTo(factory.annotationReachabilitySensitive)
           || type.getDescriptor().startsWith(factory.dalvikAnnotationOptimizationPrefix)
           || type.getDescriptor().startsWith(dalvikAnnotationCodegenPrefix)) {
@@ -310,6 +310,7 @@
         return;
       }
       if (type.isIdenticalTo(factory.annotationDefault)) {
+        assert referencedFrom.isClassContext();
         annotation
             .getAnnotation()
             .forEachElement(
@@ -320,6 +321,11 @@
                 });
         return;
       }
+      if (type.isIdenticalTo(factory.annotationThrows)) {
+        assert referencedFrom.isMethodContext();
+        registerDexValue(annotation.annotation.elements[0].value.asDexValueArray(), referencedFrom);
+        return;
+      }
       assert !type.getDescriptor().startsWith(factory.dalvikAnnotationPrefix)
           : "Unexpected annotation with prefix "
               + factory.dalvikAnnotationPrefix
diff --git a/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesAnnotationReferencesInDexTest.java b/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesAnnotationReferencesInDexTest.java
index 4ecfb4c..60fedb5 100644
--- a/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesAnnotationReferencesInDexTest.java
+++ b/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesAnnotationReferencesInDexTest.java
@@ -5,6 +5,7 @@
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 import com.android.tools.r8.DiagnosticsHandler;
@@ -12,8 +13,8 @@
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.ToolHelper;
+import com.android.tools.r8.diagnostic.DefinitionContext;
 import com.android.tools.r8.references.ClassReference;
-import com.android.tools.r8.references.MethodReference;
 import com.android.tools.r8.references.Reference;
 import com.android.tools.r8.utils.AndroidApiLevel;
 import com.android.tools.r8.utils.StringUtils;
@@ -23,8 +24,8 @@
 import java.lang.annotation.RetentionPolicy;
 import java.lang.annotation.Target;
 import java.nio.file.Path;
-import java.util.HashSet;
-import java.util.Set;
+import java.util.HashMap;
+import java.util.Map;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -43,13 +44,14 @@
 
   static class Consumer implements TraceReferencesConsumer {
 
-    Set<ClassReference> tracedTypes = new HashSet<>();
-    Set<MethodReference> tracedMethods = new HashSet<>();
+    Map<ClassReference, DefinitionContext> tracedTypes = new HashMap<>();
 
     @Override
     public void acceptType(TracedClass tracedClass, DiagnosticsHandler handler) {
       assertFalse(tracedClass.isMissingDefinition());
-      tracedTypes.add(tracedClass.getReference());
+      DefinitionContext prev =
+          tracedTypes.put(tracedClass.getReference(), tracedClass.getReferencedFromContext());
+      assert prev == null;
     }
 
     @Override
@@ -87,7 +89,26 @@
             Reference.classFromClass(FieldAnnotation.class),
             Reference.classFromClass(MethodAnnotation.class),
             Reference.classFromClass(ParameterAnnotation.class)),
-        consumer.tracedTypes);
+        consumer.tracedTypes.keySet());
+    assertTrue(
+        consumer.tracedTypes.get(Reference.classFromClass(ClassAnnotation.class)).isClassContext());
+    assertTrue(
+        consumer
+            .tracedTypes
+            .get(Reference.classFromClass(ConstructorAnnotation.class))
+            .isMethodContext());
+    assertTrue(
+        consumer.tracedTypes.get(Reference.classFromClass(FieldAnnotation.class)).isFieldContext());
+    assertTrue(
+        consumer
+            .tracedTypes
+            .get(Reference.classFromClass(MethodAnnotation.class))
+            .isMethodContext());
+    assertTrue(
+        consumer
+            .tracedTypes
+            .get(Reference.classFromClass(ParameterAnnotation.class))
+            .isMethodContext());
   }
 
   private void testGeneratedKeepRules(Path sourceDex) throws Exception {
diff --git a/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesThrowsInDexTest.java b/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesThrowsInDexTest.java
new file mode 100644
index 0000000..c7d81cc
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesThrowsInDexTest.java
@@ -0,0 +1,136 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.tracereferences;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.fail;
+
+import com.android.tools.r8.DiagnosticsHandler;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.ToolHelper;
+import com.android.tools.r8.references.ClassReference;
+import com.android.tools.r8.references.MethodReference;
+import com.android.tools.r8.references.Reference;
+import com.android.tools.r8.utils.AndroidApiLevel;
+import com.android.tools.r8.utils.StringUtils;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import java.nio.file.Path;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class TraceReferencesThrowsInDexTest extends TestBase {
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withNoneRuntime().build();
+  }
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  private static List<Class<?>> SOURCE_CLASSES = ImmutableList.of(Source.class);
+
+  static class Consumer implements TraceReferencesConsumer {
+
+    Set<ClassReference> tracedTypes = new HashSet<>();
+    Set<MethodReference> tracedMethods = new HashSet<>();
+
+    @Override
+    public void acceptType(TracedClass tracedClass, DiagnosticsHandler handler) {
+      assertFalse(tracedClass.isMissingDefinition());
+      tracedTypes.add(tracedClass.getReference());
+    }
+
+    @Override
+    public void acceptField(TracedField tracedField, DiagnosticsHandler handler) {
+      fail();
+    }
+
+    @Override
+    public void acceptMethod(TracedMethod tracedMethod, DiagnosticsHandler handler) {
+      assertFalse(tracedMethod.isMissingDefinition());
+      tracedMethods.add(tracedMethod.getReference());
+    }
+  }
+
+  private void runTest(Path sourceDex, TraceReferencesConsumer consumer) throws Exception {
+    testForTraceReferences()
+        .addLibraryFiles(ToolHelper.getAndroidJar(AndroidApiLevel.LATEST))
+        .addSourceFiles(sourceDex)
+        .addTargetClasses(E1.class, E2.class, E3.class)
+        .setConsumer(consumer)
+        .trace();
+  }
+
+  private void test(Path sourceDex) throws Exception {
+    Consumer consumer = new Consumer();
+    runTest(sourceDex, consumer);
+    assertEquals(
+        ImmutableSet.of(
+            Reference.classFromClass(E1.class),
+            Reference.classFromClass(E2.class),
+            Reference.classFromClass(E3.class)),
+        consumer.tracedTypes);
+  }
+
+  private void testGeneratedKeepRules(Path sourceDex) throws Exception {
+    StringBuilder keepRulesBuilder = new StringBuilder();
+    runTest(
+        sourceDex,
+        TraceReferencesKeepRules.builder()
+            .setOutputConsumer((string, handler) -> keepRulesBuilder.append(string))
+            .build());
+    String expected =
+        StringUtils.lines(
+            "-keep class " + E1.class.getTypeName() + " {",
+            "}",
+            "-keep class " + E2.class.getTypeName() + " {",
+            "}",
+            "-keep class " + E3.class.getTypeName() + " {",
+            "}");
+    assertEquals(expected, keepRulesBuilder.toString());
+  }
+
+  @Test
+  public void testDexArchive() throws Throwable {
+    Path archive = testForD8(Backend.DEX).addProgramClasses(SOURCE_CLASSES).compile().writeToZip();
+    test(archive);
+    testGeneratedKeepRules(archive);
+  }
+
+  @Test
+  public void testDexFile() throws Throwable {
+    Path dex =
+        testForD8(Backend.DEX)
+            .addProgramClasses(SOURCE_CLASSES)
+            .compile()
+            .writeToDirectory()
+            .resolve("classes.dex");
+    test(dex);
+    testGeneratedKeepRules(dex);
+  }
+
+  public static class E1 extends Exception {}
+
+  public static class E2 extends Exception {}
+
+  public static class E3 extends Exception {}
+
+  static class Source {
+    public void m1() throws E1 {}
+
+    public void m2() throws E2, E3 {}
+  }
+}