Reland "Add callback API to obtain the compilation unit for desugaring."

This reverts commit 0187b2b54c6f670dfa2cf659cdddd18f9bbedab4.
Also contains fix for incorrectly assuming only program-input origins
are in the dependent set (see b/167562221).

Fixes: b/226055157
Change-Id: I59376d46ce250bdfebc44decf2783c314e77c43e
diff --git a/src/main/java/com/android/tools/r8/DesugarGraphConsumer.java b/src/main/java/com/android/tools/r8/DesugarGraphConsumer.java
index 2210490..fd553b3 100644
--- a/src/main/java/com/android/tools/r8/DesugarGraphConsumer.java
+++ b/src/main/java/com/android/tools/r8/DesugarGraphConsumer.java
@@ -10,6 +10,29 @@
 public interface DesugarGraphConsumer {
 
   /**
+   * Callback indicating that the {@code node} is a program input which is part of the current
+   * compilation unit for desugaring.
+   *
+   * <p>Note: this callback is guaranteed to be called on every *program-input* origin that could be
+   * passed as a {@code dependent} in a callback to {@code accept(Orign dependent, Origin
+   * dependency)}. It is also guaranteed to be called before any such call. In effect, this callback
+   * will receive the complete set of program-input origins for the compilation unit that is being
+   * desugared and it can reliably be used to remove any existing and potentially stale edges
+   * pertaining to those origins from a dependency graph maintained in the client.
+   *
+   * <p>Note: this will not receive a callback for classpath origins.
+   *
+   * <p>Note: this callback may be called on multiple threads.
+   *
+   * <p>Note: this callback places no guarantees on order of calls or on duplicate calls.
+   *
+   * @param node Origin of code that is part of the program input in the compilation unit.
+   */
+  default void acceptProgramNode(Origin node) {
+    // Default behavior ignores the node callbacks.
+  }
+
+  /**
    * Callback indicating that code originating from {@code dependency} is needed to correctly
    * desugar code originating from {@code dependent}.
    *
diff --git a/src/main/java/com/android/tools/r8/graph/JarClassFileReader.java b/src/main/java/com/android/tools/r8/graph/JarClassFileReader.java
index afbda8d..8e5ec1f 100644
--- a/src/main/java/com/android/tools/r8/graph/JarClassFileReader.java
+++ b/src/main/java/com/android/tools/r8/graph/JarClassFileReader.java
@@ -106,6 +106,12 @@
       }
     }
 
+    if (classKind == ClassKind.PROGRAM
+        && application.options.isDesugaring()
+        && application.options.desugarGraphConsumer != null) {
+      application.options.desugarGraphConsumer.acceptProgramNode(origin);
+    }
+
     ClassReader reader = new ClassReader(bytes);
 
     int parsingOptions = SKIP_FRAMES | SKIP_CODE;
diff --git a/src/test/java/com/android/tools/r8/compilerapi/CompilerApiTestCollection.java b/src/test/java/com/android/tools/r8/compilerapi/CompilerApiTestCollection.java
index 1ca2f2b..2b37eed 100644
--- a/src/test/java/com/android/tools/r8/compilerapi/CompilerApiTestCollection.java
+++ b/src/test/java/com/android/tools/r8/compilerapi/CompilerApiTestCollection.java
@@ -9,6 +9,7 @@
 
 import com.android.tools.r8.ToolHelper;
 import com.android.tools.r8.compilerapi.assertionconfiguration.AssertionConfigurationTest;
+import com.android.tools.r8.compilerapi.desugardependencies.DesugarDependenciesTest;
 import com.android.tools.r8.compilerapi.inputdependencies.InputDependenciesTest;
 import com.android.tools.r8.compilerapi.mapid.CustomMapIdTest;
 import com.android.tools.r8.compilerapi.mockdata.MockClass;
@@ -36,7 +37,9 @@
 
   private static final List<Class<? extends CompilerApiTest>> CLASSES_PENDING_BINARY_COMPATIBILITY =
       ImmutableList.of(
-          AssertionConfigurationTest.ApiTest.class, InputDependenciesTest.ApiTest.class);
+          AssertionConfigurationTest.ApiTest.class,
+          InputDependenciesTest.ApiTest.class,
+          DesugarDependenciesTest.ApiTest.class);
 
   private final TemporaryFolder temp;
 
diff --git a/src/test/java/com/android/tools/r8/compilerapi/desugardependencies/DesugarDependenciesTest.java b/src/test/java/com/android/tools/r8/compilerapi/desugardependencies/DesugarDependenciesTest.java
new file mode 100644
index 0000000..d361457
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/compilerapi/desugardependencies/DesugarDependenciesTest.java
@@ -0,0 +1,86 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.compilerapi.desugardependencies;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import com.android.tools.r8.D8;
+import com.android.tools.r8.D8Command;
+import com.android.tools.r8.DesugarGraphConsumer;
+import com.android.tools.r8.DexIndexedConsumer;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.compilerapi.CompilerApiTest;
+import com.android.tools.r8.compilerapi.CompilerApiTestRunner;
+import com.android.tools.r8.origin.Origin;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.junit.Test;
+
+public class DesugarDependenciesTest extends CompilerApiTestRunner {
+
+  public DesugarDependenciesTest(TestParameters parameters) {
+    super(parameters);
+  }
+
+  @Override
+  public Class<? extends CompilerApiTest> binaryTestClass() {
+    return ApiTest.class;
+  }
+
+  @Test
+  public void testDesugarDependencies() throws Exception {
+    ApiTest test = new ApiTest(ApiTest.PARAMETERS);
+    runTest(test::run);
+  }
+
+  private interface Runner {
+    void run() throws Exception;
+  }
+
+  private void runTest(Runner test) throws Exception {
+    test.run();
+  }
+
+  public static class ApiTest extends CompilerApiTest {
+
+    public ApiTest(Object parameters) {
+      super(parameters);
+    }
+
+    public void run() throws Exception {
+      D8.run(
+          D8Command.builder()
+              .addClassProgramData(getBytesForClass(getMockClass()), Origin.unknown())
+              .addLibraryFiles(getJava8RuntimeJar())
+              .setProgramConsumer(DexIndexedConsumer.emptyConsumer())
+              .setDesugarGraphConsumer(
+                  new DesugarGraphConsumer() {
+                    private final Map<Origin, Origin> desugaringUnit = new ConcurrentHashMap<>();
+
+                    @Override
+                    public void acceptProgramNode(Origin node) {
+                      desugaringUnit.put(node, node);
+                    }
+
+                    @Override
+                    public void accept(Origin dependent, Origin dependency) {
+                      assertTrue(desugaringUnit.containsKey(dependent));
+                    }
+
+                    @Override
+                    public void finished() {
+                      // Input unit contains just the mock class.
+                      assertEquals(1, desugaringUnit.size());
+                    }
+                  })
+              .build());
+    }
+
+    @Test
+    public void testDesugarDependencies() throws Exception {
+      run();
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/desugar/graph/CompilationDependentSetTest.java b/src/test/java/com/android/tools/r8/desugar/graph/CompilationDependentSetTest.java
new file mode 100644
index 0000000..58fafe6
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/desugar/graph/CompilationDependentSetTest.java
@@ -0,0 +1,97 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.desugar.graph;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import com.android.tools.r8.D8TestBuilder;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.origin.Origin;
+import com.android.tools.r8.utils.AndroidApiLevel;
+import com.google.common.collect.ImmutableSet;
+import java.nio.file.Path;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class CompilationDependentSetTest extends TestBase {
+
+  public interface I {
+    // Emtpy.
+  }
+
+  public static class A implements I {
+    // Empty.
+  }
+
+  public static class B {
+    // Empty.
+  }
+
+  public static class TestClass {
+
+    public static void main(String[] args) {
+      System.out.println("Hello World!");
+    }
+  }
+
+  // Test runner follows.
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimes().withAllApiLevels().build();
+  }
+
+  private final TestParameters parameters;
+
+  public CompilationDependentSetTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void test() throws Exception {
+    if (parameters.isCfRuntime()) {
+      testForJvm()
+          .addProgramClasses(I.class, A.class, B.class, TestClass.class)
+          .run(parameters.getRuntime(), TestClass.class)
+          .assertSuccessWithOutputLines("Hello World!");
+    } else {
+      Path dexInputForB =
+          testForD8()
+              .addProgramClasses(B.class)
+              .setMinApi(parameters.getApiLevel())
+              .compile()
+              .writeToZip();
+
+      D8TestBuilder builder = testForD8();
+      DesugarGraphTestConsumer consumer = new DesugarGraphTestConsumer();
+      builder.getBuilder().setDesugarGraphConsumer(consumer);
+      Origin originI = DesugarGraphUtils.addClassWithOrigin(I.class, builder);
+      Origin originA = DesugarGraphUtils.addClassWithOrigin(A.class, builder);
+      Origin originTestClass = DesugarGraphUtils.addClassWithOrigin(TestClass.class, builder);
+      builder
+          .addProgramFiles(dexInputForB)
+          .setMinApi(parameters.getApiLevel())
+          .run(parameters.getRuntime(), TestClass.class)
+          .assertSuccessWithOutputLines("Hello World!");
+      // If API level indicates desugaring is needed check the edges are reported.
+      if (parameters.getApiLevel().getLevel() < AndroidApiLevel.N.getLevel()) {
+        assertTrue(consumer.contains(originI, originA));
+        assertEquals(1, consumer.totalEdgeCount());
+      } else {
+        assertEquals(0, consumer.totalEdgeCount());
+      }
+      // Regardless of API the potential inputs are reported.
+      // Note that the DEX input is not a desugaring candidate and thus not included in the unit.
+      assertEquals(
+          ImmutableSet.of(originI, originA, originTestClass),
+          consumer.getDesugaringCompilationUnit());
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/desugar/graph/DesugarGraphTestConsumer.java b/src/test/java/com/android/tools/r8/desugar/graph/DesugarGraphTestConsumer.java
index cb0da77..d80d81e 100644
--- a/src/test/java/com/android/tools/r8/desugar/graph/DesugarGraphTestConsumer.java
+++ b/src/test/java/com/android/tools/r8/desugar/graph/DesugarGraphTestConsumer.java
@@ -22,6 +22,9 @@
 
   private boolean finished = false;
 
+  // Set of all origins for the desugaring candidates in the compilation unit.
+  private final Set<Origin> desugaringCompilationUnit = new HashSet<>();
+
   // Map from a dependency to its immediate dependents.
   private final Map<Origin, Set<Origin>> dependents = new HashMap<>();
 
@@ -81,6 +84,16 @@
     return count;
   }
 
+  public Set<Origin> getDesugaringCompilationUnit() {
+    assertTrue(finished);
+    return desugaringCompilationUnit;
+  }
+
+  @Override
+  public synchronized void acceptProgramNode(Origin node) {
+    desugaringCompilationUnit.add(node);
+  }
+
   @Override
   public synchronized void accept(Origin dependent, Origin dependency) {
     assertFalse(finished);