[TraceReferences] Add tracing of annotations

Bug: b/376016627
Change-Id: Ia63f2eaa1d96737d55f18a3bd17deed6615ad6c5
diff --git a/src/main/java/com/android/tools/r8/tracereferences/Tracer.java b/src/main/java/com/android/tools/r8/tracereferences/Tracer.java
index e3b63b4..53af34e 100644
--- a/src/main/java/com/android/tools/r8/tracereferences/Tracer.java
+++ b/src/main/java/com/android/tools/r8/tracereferences/Tracer.java
@@ -22,6 +22,7 @@
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexProto;
+import com.android.tools.r8.graph.DexString;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.DexTypeList;
 import com.android.tools.r8.graph.DexValue;
@@ -77,6 +78,9 @@
             useCollector.registerMethod(method);
             useCollector.traceCode(method);
           });
+      // This iterates all annotations on the class, including on its methods and fields.
+      clazz.forEachAnnotation(
+          dexAnnotation -> useCollector.registerAnnotation(dexAnnotation, classContext));
     }
     consumer.finished(diagnostics);
   }
@@ -95,6 +99,8 @@
     private final Set<FieldReference> missingFields = new HashSet<>();
     private final Set<MethodReference> missingMethods = new HashSet<>();
 
+    public final DexString dalvikAnnotationCodegenPrefix;
+
     UseCollector(
         AppView<? extends AppInfoWithClassHierarchy> appView,
         TraceReferencesConsumer consumer,
@@ -105,6 +111,7 @@
       this.consumer = consumer;
       this.diagnostics = diagnostics;
       this.targetPredicate = targetPredicate;
+      this.dalvikAnnotationCodegenPrefix = factory.createString("Ldalvik/annotation/codegen/");
     }
 
     AppView<? extends AppInfoWithClassHierarchy> appView() {
@@ -270,6 +277,38 @@
           });
     }
 
+    private void registerAnnotation(DexAnnotation annotation, DefinitionContext referencedFrom) {
+      DexType type = annotation.getAnnotationType();
+      assert type.isClassType();
+      if (type.isIdenticalTo(factory.annotationThrows)
+          || type.isIdenticalTo(factory.annotationDefault)
+          || type.isIdenticalTo(factory.annotationMethodParameters)
+          || type.isIdenticalTo(factory.annotationReachabilitySensitive)
+          || type.getDescriptor().startsWith(factory.dalvikAnnotationOptimizationPrefix)
+          || type.getDescriptor().startsWith(dalvikAnnotationCodegenPrefix)) {
+        // The remaining system annotations
+        //   dalvik.annotation.EnclosingClass
+        //   dalvik.annotation.EnclosingMethod
+        //   dalvik.annotation.InnerClass
+        //   dalvik.annotation.MemberClasses
+        //   dalvik.annotation.Signature
+        //   dalvik.annotation.NestHost (*)
+        //   dalvik.annotation.NestMembers (*)
+        //   dalvik.annotation.Record (*)
+        //   dalvik.annotation.PermittedSubclasses (*)
+        // are not added as annotations in the DexParser.
+        //
+        // (*) Not officially supported and documented.
+        return;
+      }
+      assert !type.getDescriptor().startsWith(factory.dalvikAnnotationPrefix)
+          : "Unexpected annotation with prefix "
+              + factory.dalvikAnnotationPrefix
+              + ": "
+              + type.getDescriptor();
+      addClassType(type, referencedFrom);
+    }
+
     class MethodUseCollector extends UseRegistry<ProgramMethod> {
 
       private final DefinitionContext referencedFrom;
diff --git a/src/test/java/com/android/tools/r8/partial/PartialCompilationClassAnnotatedInD8Test.java b/src/test/java/com/android/tools/r8/partial/PartialCompilationClassAnnotatedInD8Test.java
index beef080..be3c872 100644
--- a/src/test/java/com/android/tools/r8/partial/PartialCompilationClassAnnotatedInD8Test.java
+++ b/src/test/java/com/android/tools/r8/partial/PartialCompilationClassAnnotatedInD8Test.java
@@ -3,13 +3,13 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.partial;
 
-import static org.hamcrest.CoreMatchers.containsString;
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
 
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.ToolHelper.DexVm;
-import com.android.tools.r8.ToolHelper.DexVm.Version;
 import com.android.tools.r8.utils.AndroidApiLevel;
 import com.android.tools.r8.utils.StringUtils;
 import java.lang.annotation.Retention;
@@ -68,16 +68,8 @@
         .setR8PartialConfiguration(
             builder -> builder.includeAll().excludeClasses(AnnotatedClass.class))
         .run(parameters.getRuntime(), Main.class, getClass().getTypeName())
-        // TODO(b/376016627): Should succeed with EXPECTED_RESULT.
-        .applyIf(
-            parameters.isDexRuntimeVersionOlderThanOrEqual(Version.DEFAULT),
-            r ->
-                r.assertStderrMatches(
-                    containsString(
-                        "Unable to resolve java.lang.Class<"
-                            + AnnotatedClass.class.getTypeName()
-                            + "> annotation class 2")))
-        .assertSuccessWithOutputLines("0");
+        .inspect(inspector -> assertThat(inspector.clazz(Annotation.class), isPresent()))
+        .assertSuccessWithOutput(EXPECTED_OUTPUT);
   }
 
   // Compiled with R8
diff --git a/src/test/java/com/android/tools/r8/partial/PartialCompilationFieldAnnotatedInD8Test.java b/src/test/java/com/android/tools/r8/partial/PartialCompilationFieldAnnotatedInD8Test.java
index ea13e25..e0140e0 100644
--- a/src/test/java/com/android/tools/r8/partial/PartialCompilationFieldAnnotatedInD8Test.java
+++ b/src/test/java/com/android/tools/r8/partial/PartialCompilationFieldAnnotatedInD8Test.java
@@ -3,7 +3,8 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.partial;
 
-import static org.hamcrest.CoreMatchers.containsString;
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
 
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
@@ -69,16 +70,8 @@
         .setR8PartialConfiguration(
             builder -> builder.includeAll().excludeClasses(AnnotatedClass.class))
         .run(parameters.getRuntime(), Main.class, getClass().getTypeName())
-        // TODO(b/376016627): Should succeed with EXPECTED_RESULT.
-        .applyIf(
-            parameters.isDexRuntimeVersionOlderThanOrEqual(Version.DEFAULT),
-            r ->
-                r.assertStderrMatches(
-                    containsString(
-                        "Unable to resolve java.lang.Class<"
-                            + AnnotatedClass.class.getTypeName()
-                            + "> annotation class 2")))
-        .assertSuccessWithOutputLines("0");
+        .inspect(inspector -> assertThat(inspector.clazz(Annotation.class), isPresent()))
+        .assertSuccessWithOutput(EXPECTED_OUTPUT);
   }
 
   // Compiled with R8
diff --git a/src/test/java/com/android/tools/r8/partial/PartialCompilationMethodAnnotatedInD8Test.java b/src/test/java/com/android/tools/r8/partial/PartialCompilationMethodAnnotatedInD8Test.java
index c847694..2d0f94b 100644
--- a/src/test/java/com/android/tools/r8/partial/PartialCompilationMethodAnnotatedInD8Test.java
+++ b/src/test/java/com/android/tools/r8/partial/PartialCompilationMethodAnnotatedInD8Test.java
@@ -3,7 +3,8 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.partial;
 
-import static org.hamcrest.CoreMatchers.containsString;
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
 
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
@@ -70,16 +71,8 @@
         .setR8PartialConfiguration(
             builder -> builder.includeAll().excludeClasses(AnnotatedClass.class))
         .run(parameters.getRuntime(), Main.class, getClass().getTypeName())
-        // TODO(b/376016627): Should succeed with EXPECTED_RESULT.
-        .applyIf(
-            parameters.isDexRuntimeVersionOlderThanOrEqual(Version.DEFAULT),
-            r ->
-                r.assertStderrMatches(
-                    containsString(
-                        "Unable to resolve java.lang.Class<"
-                            + AnnotatedClass.class.getTypeName()
-                            + "> annotation class 2")))
-        .assertSuccessWithOutputLines("0");
+        .inspect(inspector -> assertThat(inspector.clazz(Annotation.class), isPresent()))
+        .assertSuccessWithOutput(EXPECTED_OUTPUT);
   }
 
   @Retention(RetentionPolicy.RUNTIME)
diff --git a/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesMissingAnnotationReferencesInDexTest.java b/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesMissingAnnotationReferencesInDexTest.java
new file mode 100644
index 0000000..99ca76c
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesMissingAnnotationReferencesInDexTest.java
@@ -0,0 +1,137 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.tracereferences;
+
+import static com.android.tools.r8.tracereferences.TraceReferencesTestUtils.collectMissingClassReferences;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import com.android.tools.r8.CompilationFailedException;
+import com.android.tools.r8.DiagnosticsHandler;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.ToolHelper;
+import com.android.tools.r8.references.ClassReference;
+import com.android.tools.r8.references.Reference;
+import com.android.tools.r8.utils.AndroidApiLevel;
+import com.google.common.collect.ImmutableSet;
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+import java.nio.file.Path;
+import java.util.HashSet;
+import java.util.Set;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class TraceReferencesMissingAnnotationReferencesInDexTest extends TestBase {
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withNoneRuntime().build();
+  }
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  static class MissingReferencesConsumer implements TraceReferencesConsumer {
+
+    Set<ClassReference> missingTypes = new HashSet<>();
+
+    @Override
+    public void acceptType(TracedClass tracedClass, DiagnosticsHandler handler) {
+      assertTrue(tracedClass.isMissingDefinition());
+      missingTypes.add(tracedClass.getReference());
+    }
+
+    @Override
+    public void acceptField(TracedField tracedField, DiagnosticsHandler handler) {
+      fail();
+    }
+
+    @Override
+    public void acceptMethod(TracedMethod tracedMethod, DiagnosticsHandler handler) {
+      fail();
+    }
+  }
+
+  private void missingClassReferenced(Path sourceDex) {
+    Set<ClassReference> expectedMissingClasses =
+        ImmutableSet.of(
+            Reference.classFromClass(ClassAnnotation.class),
+            Reference.classFromClass(FieldAnnotation.class),
+            Reference.classFromClass(MethodAnnotation.class),
+            Reference.classFromClass(ConstructorAnnotation.class),
+            Reference.classFromClass(ParameterAnnotation.class));
+
+    MissingReferencesConsumer consumer = new MissingReferencesConsumer();
+    assertThrows(
+        CompilationFailedException.class,
+        () ->
+            testForTraceReferences()
+                .addLibraryFiles(ToolHelper.getAndroidJar(AndroidApiLevel.LATEST))
+                .addSourceFiles(sourceDex)
+                .setConsumer(new TraceReferencesCheckConsumer(consumer))
+                .traceWithExpectedDiagnostics(
+                    diagnostics ->
+                        assertEquals(
+                            expectedMissingClasses, collectMissingClassReferences(diagnostics))));
+
+    assertEquals(expectedMissingClasses, consumer.missingTypes);
+  }
+
+  @Test
+  public void missingClassReferencedInDexArchive() throws Throwable {
+    missingClassReferenced(
+        testForD8(Backend.DEX).addProgramClasses(Source.class).compile().writeToZip());
+  }
+
+  @Test
+  public void missingClassReferencedInDexFile() throws Throwable {
+    missingClassReferenced(
+        testForD8(Backend.DEX)
+            .addProgramClasses(Source.class)
+            .compile()
+            .writeToDirectory()
+            .resolve("classes.dex"));
+  }
+
+  @Retention(RetentionPolicy.RUNTIME)
+  @Target(ElementType.TYPE)
+  public @interface ClassAnnotation {}
+
+  @Retention(RetentionPolicy.RUNTIME)
+  @Target(ElementType.FIELD)
+  public @interface FieldAnnotation {}
+
+  @Retention(RetentionPolicy.RUNTIME)
+  @Target(ElementType.METHOD)
+  public @interface MethodAnnotation {}
+
+  @Retention(RetentionPolicy.RUNTIME)
+  @Target(ElementType.CONSTRUCTOR)
+  public @interface ConstructorAnnotation {}
+
+  @Retention(RetentionPolicy.RUNTIME)
+  @Target(ElementType.PARAMETER)
+  public @interface ParameterAnnotation {}
+
+  @ClassAnnotation
+  static class Source {
+    @FieldAnnotation public static int field;
+
+    @ConstructorAnnotation
+    public Source() {}
+
+    @MethodAnnotation
+    public static void source(@ParameterAnnotation int param) {}
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesMissingReferencesInDexTest.java b/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesMissingReferencesInDexTest.java
index 74786ad..c0f8d1e 100644
--- a/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesMissingReferencesInDexTest.java
+++ b/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesMissingReferencesInDexTest.java
@@ -3,25 +3,33 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.tracereferences;
 
+import static com.android.tools.r8.DiagnosticsMatcher.diagnosticType;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
 
 import com.android.tools.r8.CompilationFailedException;
-import com.android.tools.r8.DiagnosticsChecker;
 import com.android.tools.r8.DiagnosticsHandler;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.ToolHelper;
+import com.android.tools.r8.diagnostic.MissingDefinitionsDiagnostic;
+import com.android.tools.r8.references.ClassReference;
+import com.android.tools.r8.references.FieldReference;
+import com.android.tools.r8.references.MethodReference;
 import com.android.tools.r8.references.Reference;
 import com.android.tools.r8.utils.AndroidApiLevel;
+import com.google.common.collect.ImmutableSet;
 import java.io.IOException;
 import java.nio.file.Path;
+import java.util.HashSet;
+import java.util.Set;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
 import org.junit.runners.Parameterized.Parameters;
 
 @RunWith(Parameterized.class)
@@ -31,63 +39,75 @@
     return getTestParameters().withNoneRuntime().build();
   }
 
-  private static final AndroidApiLevel minApi = AndroidApiLevel.B;
-
-  public TraceReferencesMissingReferencesInDexTest(TestParameters parameters) {
-    parameters.assertNoneRuntime();
-  }
+  @Parameter(0)
+  public TestParameters parameters;
 
   static class MissingReferencesConsumer implements TraceReferencesConsumer {
 
-    boolean acceptTypeCalled;
-    boolean acceptFieldCalled;
-    boolean acceptMethodCalled;
+    Set<ClassReference> missingTypes = new HashSet<>();
+    Set<FieldReference> missingFields = new HashSet<>();
+    Set<MethodReference> missingMethods = new HashSet<>();
+
+    boolean hasMissingTypes() {
+      return missingTypes.size() > 0;
+    }
+
+    boolean hasMissingFields() {
+      return missingFields.size() > 0;
+    }
+
+    boolean hasMissingMethods() {
+      return missingMethods.size() > 0;
+    }
 
     @Override
     public void acceptType(TracedClass tracedClass, DiagnosticsHandler handler) {
-      acceptTypeCalled = true;
-      assertEquals(Reference.classFromClass(Target.class), tracedClass.getReference());
       assertTrue(tracedClass.isMissingDefinition());
+      missingTypes.add(tracedClass.getReference());
     }
 
     @Override
     public void acceptField(TracedField tracedField, DiagnosticsHandler handler) {
-      acceptFieldCalled = true;
+      assertTrue(tracedField.isMissingDefinition());
+      missingFields.add(tracedField.getReference());
       assertEquals(
           Reference.classFromClass(Target.class), tracedField.getReference().getHolderClass());
       assertEquals("field", tracedField.getReference().getFieldName());
-      assertTrue(tracedField.isMissingDefinition());
     }
 
     @Override
     public void acceptMethod(TracedMethod tracedMethod, DiagnosticsHandler handler) {
-      acceptMethodCalled = true;
+      assertTrue(tracedMethod.isMissingDefinition());
+      missingMethods.add(tracedMethod.getReference());
       assertEquals(
           Reference.classFromClass(Target.class), tracedMethod.getReference().getHolderClass());
       assertEquals("target", tracedMethod.getReference().getMethodName());
-      assertTrue(tracedMethod.isMissingDefinition());
     }
   }
 
   private void missingClassReferenced(Path sourceDex) {
-    DiagnosticsChecker diagnosticsChecker = new DiagnosticsChecker();
+    Set<ClassReference> expectedMissingClasses =
+        ImmutableSet.of(Reference.classFromClass(Target.class));
+
     MissingReferencesConsumer consumer = new MissingReferencesConsumer();
+    assertThrows(
+        CompilationFailedException.class,
+        () ->
+            testForTraceReferences()
+                .addLibraryFiles(ToolHelper.getAndroidJar(AndroidApiLevel.LATEST))
+                .addSourceFiles(sourceDex)
+                .setConsumer(new TraceReferencesCheckConsumer(consumer))
+                .traceWithExpectedDiagnostics(
+                    diagnostics ->
+                        assertEquals(
+                            expectedMissingClasses,
+                            TraceReferencesTestUtils.filterAndCollectMissingClassReferences(
+                                diagnostics))));
 
-    try {
-      TraceReferences.run(
-          TraceReferencesCommand.builder(diagnosticsChecker)
-              .addLibraryFiles(ToolHelper.getAndroidJar(AndroidApiLevel.P))
-              .addSourceFiles(sourceDex)
-              .setConsumer(new TraceReferencesCheckConsumer(consumer))
-              .build());
-      fail("Expected compilation to fail");
-    } catch (CompilationFailedException e) {
-      // Expected.
-    }
-
-    assertTrue(consumer.acceptTypeCalled);
-    assertTrue(consumer.acceptFieldCalled);
-    assertTrue(consumer.acceptMethodCalled);
+    assertTrue(consumer.hasMissingTypes());
+    assertEquals(expectedMissingClasses, consumer.missingTypes);
+    assertTrue(consumer.hasMissingFields());
+    assertTrue(consumer.hasMissingMethods());
   }
 
   @Test
@@ -95,7 +115,6 @@
     missingClassReferenced(
         testForD8(Backend.DEX)
             .addProgramClasses(Source.class)
-            .setMinApi(minApi)
             .compile()
             .writeToZip());
   }
@@ -105,31 +124,28 @@
     missingClassReferenced(
         testForD8(Backend.DEX)
             .addProgramClasses(Source.class)
-            .setMinApi(minApi)
             .compile()
             .writeToDirectory()
             .resolve("classes.dex"));
   }
 
   private void missingFieldAndMethodReferenced(Path sourceDex) {
-    DiagnosticsChecker diagnosticsChecker = new DiagnosticsChecker();
     MissingReferencesConsumer consumer = new MissingReferencesConsumer();
+    assertThrows(
+        CompilationFailedException.class,
+        () ->
+            testForTraceReferences()
+                .addLibraryFiles(ToolHelper.getAndroidJar(AndroidApiLevel.LATEST))
+                .addSourceFiles(sourceDex)
+                .setConsumer(new TraceReferencesCheckConsumer(consumer))
+                .traceWithExpectedDiagnostics(
+                    diagnostics ->
+                        diagnostics.assertAllDiagnosticsMatch(
+                            diagnosticType(MissingDefinitionsDiagnostic.class))));
 
-    try {
-      TraceReferences.run(
-          TraceReferencesCommand.builder(diagnosticsChecker)
-              .addLibraryFiles(ToolHelper.getAndroidJar(AndroidApiLevel.P))
-              .addSourceFiles(sourceDex)
-              .setConsumer(new TraceReferencesCheckConsumer(consumer))
-              .build());
-      fail("Expected compilation to fail");
-    } catch (CompilationFailedException e) {
-      // Expected.
-    }
-
-    assertFalse(consumer.acceptTypeCalled);
-    assertTrue(consumer.acceptFieldCalled);
-    assertTrue(consumer.acceptMethodCalled);
+    assertFalse(consumer.hasMissingTypes());
+    assertTrue(consumer.hasMissingFields());
+    assertTrue(consumer.hasMissingMethods());
   }
 
   @Test
@@ -138,7 +154,6 @@
         testForD8(Backend.DEX)
             .addProgramClasses(Source.class)
             .addProgramClassFileData(getClassWithTargetRemoved())
-            .setMinApi(minApi)
             .compile()
             .writeToZip());
   }
@@ -149,7 +164,6 @@
         testForD8(Backend.DEX)
             .addProgramClasses(Source.class)
             .addProgramClassFileData(getClassWithTargetRemoved())
-            .setMinApi(minApi)
             .compile()
             .writeToDirectory()
             .resolve("classes.dex"));
diff --git a/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesTestUtils.java b/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesTestUtils.java
new file mode 100644
index 0000000..5169ced
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/tracereferences/TraceReferencesTestUtils.java
@@ -0,0 +1,35 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.tracereferences;
+
+import com.android.tools.r8.TestDiagnosticMessages;
+import com.android.tools.r8.diagnostic.MissingDefinitionInfo;
+import com.android.tools.r8.diagnostic.MissingDefinitionsDiagnostic;
+import com.android.tools.r8.references.ClassReference;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+public class TraceReferencesTestUtils {
+
+  public static Set<ClassReference> collectMissingClassReferences(
+      TestDiagnosticMessages diagnostic) {
+    return diagnostic
+        .assertSingleErrorDiagnosticType(MissingDefinitionsDiagnostic.class)
+        .getMissingDefinitions()
+        .stream()
+        .map(info -> info.asMissingClass().getClassReference())
+        .collect(Collectors.toSet());
+  }
+
+  public static Set<ClassReference> filterAndCollectMissingClassReferences(
+      TestDiagnosticMessages diagnostic) {
+    return diagnostic
+        .assertSingleErrorDiagnosticType(MissingDefinitionsDiagnostic.class)
+        .getMissingDefinitions()
+        .stream()
+        .filter(MissingDefinitionInfo::isMissingClass)
+        .map(info -> info.asMissingClass().getClassReference())
+        .collect(Collectors.toSet());
+  }
+}
diff --git a/src/test/testbase/java/com/android/tools/r8/TestDiagnosticMessages.java b/src/test/testbase/java/com/android/tools/r8/TestDiagnosticMessages.java
index 54f35a3..d20d678 100644
--- a/src/test/testbase/java/com/android/tools/r8/TestDiagnosticMessages.java
+++ b/src/test/testbase/java/com/android/tools/r8/TestDiagnosticMessages.java
@@ -5,6 +5,7 @@
 package com.android.tools.r8;
 
 import static com.android.tools.r8.DiagnosticsMatcher.diagnosticMessage;
+import static com.android.tools.r8.DiagnosticsMatcher.diagnosticType;
 import static org.hamcrest.CoreMatchers.not;
 import static org.junit.Assert.assertEquals;
 
@@ -200,4 +201,10 @@
   public final TestDiagnosticMessages assertNoErrorsMatch(Matcher<Diagnostic> matcher) {
     return assertAllErrorsMatch(not(matcher));
   }
+
+  @SuppressWarnings("unchecked")
+  public final <T extends Diagnostic> T assertSingleErrorDiagnosticType(Class<T> diagnosticType) {
+    return (T)
+        assertErrorsMatch(diagnosticType(diagnosticType)).assertErrorsCount(1).getErrors().get(0);
+  }
 }
diff --git a/src/test/testbase/java/com/android/tools/r8/TraceReferencesTestBuilder.java b/src/test/testbase/java/com/android/tools/r8/TraceReferencesTestBuilder.java
index aac444f..d86cbbd 100644
--- a/src/test/testbase/java/com/android/tools/r8/TraceReferencesTestBuilder.java
+++ b/src/test/testbase/java/com/android/tools/r8/TraceReferencesTestBuilder.java
@@ -3,10 +3,15 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8;
 
+import com.android.tools.r8.TestCompilerBuilder.DiagnosticsConsumer;
 import com.android.tools.r8.tracereferences.TraceReferences;
 import com.android.tools.r8.tracereferences.TraceReferencesCommand;
+import com.android.tools.r8.tracereferences.TraceReferencesConsumer;
 import com.android.tools.r8.utils.ZipUtils.ZipBuilder;
 import java.io.IOException;
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.Collection;
 
 public class TraceReferencesTestBuilder {
 
@@ -16,12 +21,44 @@
 
   public TraceReferencesTestBuilder(TestState state) {
     this.builder =
-        TraceReferencesCommand.builder()
+        TraceReferencesCommand.builder(state.getDiagnosticsHandler())
             .addLibraryFiles(ToolHelper.getMostRecentAndroidJar())
             .setConsumer(inspector);
     this.state = state;
   }
 
+  public TraceReferencesTestBuilder addLibraryFiles(Collection<Path> files) {
+    builder.addLibraryFiles(files);
+    return this;
+  }
+
+  public TraceReferencesTestBuilder addLibraryFiles(Path... files) {
+    return addLibraryFiles(Arrays.asList(files));
+  }
+
+  public TraceReferencesTestBuilder addSourceFiles(Collection<Path> files) {
+    builder.addSourceFiles(files);
+    return this;
+  }
+
+  public TraceReferencesTestBuilder addSourceFiles(Path... files) {
+    return addSourceFiles(Arrays.asList(files));
+  }
+
+  public TraceReferencesTestBuilder addTargetFiles(Collection<Path> files) {
+    builder.addTargetFiles(files);
+    return this;
+  }
+
+  public TraceReferencesTestBuilder addTargetFiles(Path... files) {
+    return addTargetFiles(Arrays.asList(files));
+  }
+
+  public TraceReferencesTestBuilder setConsumer(TraceReferencesConsumer consumer) {
+    builder.setConsumer(consumer);
+    return this;
+  }
+
   public TraceReferencesTestBuilder addInnerClassesAsSourceClasses(Class<?> clazz)
       throws IOException {
     builder.addSourceFiles(
@@ -46,4 +83,17 @@
     TraceReferences.run(builder.build());
     return new TraceReferencesTestResult(inspector);
   }
+
+  public <E extends Exception> TraceReferencesTestResult traceWithExpectedDiagnostics(
+      DiagnosticsConsumer<E> diagnosticsConsumer) throws CompilationFailedException, E {
+    TestDiagnosticMessages diagnosticsHandler = state.getDiagnosticsMessages();
+    try {
+      TraceReferencesTestResult result = trace();
+      diagnosticsConsumer.accept(diagnosticsHandler);
+      return result;
+    } catch (CompilationFailedException e) {
+      diagnosticsConsumer.accept(diagnosticsHandler);
+      throw e;
+    }
+  }
 }