Synthesize lambdas and interface bridges in the enqueuer fixed point.

Bug: 151074209
Change-Id: If2253ee19c81f9572e2c354b0936d2670d8a21fc
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index c43b1f9..a1333fb 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -23,6 +23,7 @@
 import com.android.tools.r8.dex.IndexedItemCollection;
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.experimental.graphinfo.GraphConsumer;
+import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
 import com.android.tools.r8.graph.AppInfoWithSubtyping;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.CfCode;
@@ -95,6 +96,7 @@
 import com.android.tools.r8.utils.InternalOptions.DesugarState;
 import com.android.tools.r8.utils.MethodSignatureEquivalence;
 import com.android.tools.r8.utils.OptionalBool;
+import com.android.tools.r8.utils.Pair;
 import com.android.tools.r8.utils.SetUtils;
 import com.android.tools.r8.utils.StringDiagnostic;
 import com.android.tools.r8.utils.Timing;
@@ -170,8 +172,10 @@
 
   private Set<EnqueuerAnalysis> analyses = Sets.newIdentityHashSet();
   private Set<EnqueuerInvokeAnalysis> invokeAnalyses = Sets.newIdentityHashSet();
-  private final AppInfoWithSubtyping appInfo;
-  private final AppView<? extends AppInfoWithSubtyping> appView;
+
+  // Don't hold a direct pointer to app info (use appView).
+  private AppInfoWithSubtyping appInfo;
+  private final AppView<AppInfoWithSubtyping> appView;
   private final InternalOptions options;
   private RootSet rootSet;
   private ProguardClassFilter dontWarnPatterns;
@@ -320,7 +324,8 @@
 
   private final LambdaRewriter lambdaRewriter;
   private final DesugaredLibraryConversionWrapperAnalysis desugaredLibraryWrapperAnalysis;
-  private final Map<DexType, LambdaClass> lambdaClasses = new IdentityHashMap<>();
+  private final Map<DexType, Pair<LambdaClass, DexEncodedMethod>> lambdaClasses =
+      new IdentityHashMap<>();
   private final Map<DexEncodedMethod, Map<DexCallSite, LambdaClass>> lambdaCallSites =
       new IdentityHashMap<>();
   private final Set<DexProgramClass> classesWithSerializableLambdas = Sets.newIdentityHashSet();
@@ -332,7 +337,7 @@
     assert appView.appServices() != null;
     InternalOptions options = appView.options();
     this.appInfo = appView.appInfo();
-    this.appView = appView;
+    this.appView = appView.withSubtyping();
     this.forceProguardCompatibility = options.forceProguardCompatibility;
     this.graphReporter = new GraphReporter(appView, keptGraphConsumer);
     this.mode = mode;
@@ -368,6 +373,10 @@
     }
   }
 
+  private AppInfoWithClassHierarchy appInfo() {
+    return appView.appInfo();
+  }
+
   public Mode getMode() {
     return mode;
   }
@@ -649,7 +658,7 @@
       if (code != null) {
         LambdaClass lambdaClass =
             lambdaRewriter.getOrCreateLambdaClass(descriptor, contextMethod.method.holder);
-        lambdaClasses.put(lambdaClass.type, lambdaClass);
+        lambdaClasses.put(lambdaClass.type, new Pair<>(lambdaClass, contextMethod));
         lambdaCallSites
             .computeIfAbsent(contextMethod, k -> new IdentityHashMap<>())
             .put(callSite, lambdaClass);
@@ -2498,6 +2507,151 @@
     return appInfoWithLiveness;
   }
 
+  private static class SyntheticAdditions {
+
+    Map<DexType, Pair<DexProgramClass, DexEncodedMethod>> syntheticInstantiations =
+        new IdentityHashMap<>();
+
+    Map<DexMethod, ProgramMethod> liveMethods = new IdentityHashMap<>();
+
+    // Subset of live methods that need to be pinned.
+    Set<DexMethod> pinnedMethods = Sets.newIdentityHashSet();
+
+    // Subset of synthesized classes that need to be added to the main-dex file.
+    Set<DexType> mainDexTypes = Sets.newIdentityHashSet();
+
+    boolean isEmpty() {
+      boolean empty = syntheticInstantiations.isEmpty() && liveMethods.isEmpty();
+      assert !empty || (pinnedMethods.isEmpty() && mainDexTypes.isEmpty());
+      return empty;
+    }
+
+    void addInstantiatedClass(
+        DexProgramClass clazz, DexEncodedMethod context, boolean isMainDexClass) {
+      assert !syntheticInstantiations.containsKey(clazz.type);
+      syntheticInstantiations.put(clazz.type, new Pair<>(clazz, context));
+      if (isMainDexClass) {
+        mainDexTypes.add(clazz.type);
+      }
+    }
+
+    void addLiveMethod(ProgramMethod method) {
+      DexMethod signature = method.getMethod().method;
+      assert !liveMethods.containsKey(signature);
+      liveMethods.put(signature, method);
+    }
+
+    void addLiveAndPinnedMethod(ProgramMethod method) {
+      addLiveMethod(method);
+      pinnedMethods.add(method.getMethod().method);
+    }
+
+    void amendApplication(Builder appBuilder) {
+      assert !isEmpty();
+      for (Pair<DexProgramClass, DexEncodedMethod> clazzAndContext :
+          syntheticInstantiations.values()) {
+        appBuilder.addProgramClass(clazzAndContext.getFirst());
+      }
+      appBuilder.addToMainDexList(mainDexTypes);
+    }
+
+    void enqueueWorkItems(Enqueuer enqueuer) {
+      assert !isEmpty();
+      assert enqueuer.mode.isInitialTreeShaking();
+      // All synthetic additions are initial tree shaking only. No need to track keep reasons.
+      KeepReasonWitness fakeReason = enqueuer.graphReporter.fakeReportShouldNotBeUsed();
+
+      enqueuer.pinnedItems.addAll(pinnedMethods);
+      for (Pair<DexProgramClass, DexEncodedMethod> clazzAndContext :
+          syntheticInstantiations.values()) {
+        enqueuer.workList.enqueueMarkInstantiatedAction(
+            clazzAndContext.getFirst(),
+            clazzAndContext.getSecond(),
+            InstantiationReason.SYNTHESIZED_CLASS,
+            fakeReason);
+      }
+      for (ProgramMethod liveMethod : liveMethods.values()) {
+        assert !enqueuer.targetedMethods.contains(liveMethod.getMethod());
+        DexProgramClass holder = liveMethod.getHolder();
+        DexEncodedMethod method = liveMethod.getMethod();
+        enqueuer.markMethodAsTargeted(holder, method, fakeReason);
+        enqueuer.enqueueMarkMethodLiveAction(holder, method, fakeReason);
+      }
+    }
+  }
+
+  private void synthesize() {
+    // First part of synthesis is to create and register all reachable synthetic additions.
+    // In particular these additions are order independent, i.e., it does not matter which are
+    // registered first and no dependencies may exist among them.
+    SyntheticAdditions additions = new SyntheticAdditions();
+    synthesizeInterfaceMethodBridges(additions);
+    synthesizeLambdas(additions);
+    if (additions.isEmpty()) {
+      return;
+    }
+
+    // Now all additions are computed, the application is atomically extended with those additions.
+    Builder appBuilder = appInfo.app().asDirect().builder();
+    additions.amendApplication(appBuilder);
+    appInfo = new AppInfoWithSubtyping(appBuilder.build());
+    appView.setAppInfo(appInfo);
+
+    // Finally once all synthesized items "exist" it is now safe to continue tracing. The new work
+    // items are enqueued and the fixed point will continue once this subroutine returns.
+    additions.enqueueWorkItems(this);
+  }
+
+  private void synthesizeInterfaceMethodBridges(SyntheticAdditions additions) {
+    for (ProgramMethod bridge : syntheticInterfaceMethodBridges.values()) {
+      DexProgramClass holder = bridge.getHolder();
+      DexEncodedMethod method = bridge.getMethod();
+      holder.addVirtualMethod(method);
+      additions.addLiveAndPinnedMethod(bridge);
+    }
+    syntheticInterfaceMethodBridges.clear();
+  }
+
+  private void synthesizeLambdas(SyntheticAdditions additions) {
+    if (lambdaRewriter == null || lambdaClasses.isEmpty()) {
+      assert lambdaCallSites.isEmpty();
+      assert classesWithSerializableLambdas.isEmpty();
+      return;
+    }
+    for (Pair<LambdaClass, DexEncodedMethod> lambdaClassAndContext : lambdaClasses.values()) {
+      // Add all desugared classes to the application, main-dex list, and mark them instantiated.
+      LambdaClass lambdaClass = lambdaClassAndContext.getFirst();
+      DexEncodedMethod context = lambdaClassAndContext.getSecond();
+      DexProgramClass programClass = lambdaClass.getOrCreateLambdaClass();
+      additions.addInstantiatedClass(programClass, context, lambdaClass.addToMainDexList.get());
+
+      // Mark all methods on the desugared lambda classes as live.
+      for (DexEncodedMethod method : programClass.methods()) {
+        additions.addLiveMethod(new ProgramMethod(programClass, method));
+      }
+
+      // Ensure accessors if needed and mark them live too.
+      DexEncodedMethod accessor = lambdaClass.target.ensureAccessibilityIfNeeded(false);
+      if (accessor != null) {
+        DexProgramClass clazz = getProgramClassOrNull(accessor.method.holder);
+        additions.addLiveMethod(new ProgramMethod(clazz, accessor));
+      }
+    }
+
+    // Rewrite all of the invoke-dynamic instructions to lambda class instantiations.
+    lambdaCallSites.forEach(this::rewriteLambdaCallSites);
+
+    // Remove all '$deserializeLambda$' methods which are not supported by desugaring.
+    for (DexProgramClass clazz : classesWithSerializableLambdas) {
+      clazz.removeDirectMethod(appView.dexItemFactory().deserializeLambdaMethod);
+    }
+
+    // Clear state before next fixed point iteration.
+    lambdaClasses.clear();
+    lambdaCallSites.clear();
+    classesWithSerializableLambdas.clear();
+  }
+
   private void finalizeLibraryMethodOverrideInformation() {
     for (DexProgramClass liveType : liveTypes.getItems()) {
       for (DexEncodedMethod method : liveType.virtualMethods()) {
@@ -2524,15 +2678,6 @@
         (field, info) -> field != info.getField() || info == MISSING_FIELD_ACCESS_INFO);
     assert fieldAccessInfoCollection.verifyMappingIsOneToOne();
 
-    for (ProgramMethod bridge : syntheticInterfaceMethodBridges.values()) {
-      DexProgramClass holder = bridge.getHolder();
-      DexEncodedMethod method = bridge.getMethod();
-      holder.addVirtualMethod(method);
-      targetedMethods.add(method, graphReporter.fakeReportShouldNotBeUsed());
-      liveMethods.add(holder, method, graphReporter.fakeReportShouldNotBeUsed());
-      pinnedItems.add(method.method);
-    }
-
     // Ensure references from various root set collections.
     rootSet
         .noSideEffects
@@ -2566,7 +2711,6 @@
     appBuilder.replaceClasspathClasses(classpathClasses);
     // Can't replace the program classes at this point as they are needed in tree pruning.
     // Post process the app to add synthetic content.
-    postProcessLambdaDesugaring(appBuilder);
     postProcessLibraryConversionWrappers(appBuilder);
     DirectMappedDexApplication app = appBuilder.build();
 
@@ -2709,60 +2853,6 @@
     appBuilder.addClasspathClasses(mockVivifiedClasses);
   }
 
-  private void postProcessLambdaDesugaring(DirectMappedDexApplication.Builder appBuilder) {
-    if (lambdaRewriter == null || lambdaClasses.isEmpty()) {
-      return;
-    }
-    for (LambdaClass lambdaClass : lambdaClasses.values()) {
-      // Add all desugared classes to the application, main-dex list, and mark them instantiated.
-      DexProgramClass programClass = lambdaClass.getOrCreateLambdaClass();
-      appBuilder.addProgramClass(programClass);
-      if (lambdaClass.addToMainDexList.get()) {
-        appBuilder.addToMainDexList(Collections.singletonList(programClass.type));
-      }
-      liveTypes.add(programClass, graphReporter.fakeReportShouldNotBeUsed());
-      objectAllocationInfoCollection.recordDirectAllocationSite(
-          programClass,
-          null,
-          InstantiationReason.SYNTHESIZED_CLASS,
-          graphReporter.fakeReportShouldNotBeUsed(),
-          appInfo);
-
-      // Register all of the field writes in the lambda constructors.
-      // This is needed to ensure that the initializers can be optimized.
-      Map<DexEncodedField, Set<DexEncodedMethod>> writes =
-          lambdaRewriter.getWritesWithContexts(programClass);
-      writes.forEach(
-          (field, contexts) -> {
-            for (DexEncodedMethod context : contexts) {
-              registerFieldWrite(field.field, context);
-            }
-          });
-
-      // Mark all methods on the desugared lambda classes as live.
-      for (DexEncodedMethod method : programClass.methods()) {
-        targetedMethods.add(method, graphReporter.fakeReportShouldNotBeUsed());
-        liveMethods.add(programClass, method, graphReporter.fakeReportShouldNotBeUsed());
-      }
-
-      // Ensure accessors if needed and mark them live too.
-      DexEncodedMethod accessor = lambdaClass.target.ensureAccessibilityIfNeeded(false);
-      if (accessor != null) {
-        DexProgramClass clazz = getProgramClassOrNull(accessor.method.holder);
-        targetedMethods.add(accessor, graphReporter.fakeReportShouldNotBeUsed());
-        liveMethods.add(clazz, accessor, graphReporter.fakeReportShouldNotBeUsed());
-      }
-    }
-
-    // Rewrite all of the invoke-dynamic instructions to lambda class instantiations.
-    lambdaCallSites.forEach(this::rewriteLambdaCallSites);
-
-    // Remove all '$deserializeLambda$' methods which are not supported by desugaring.
-    for (DexProgramClass clazz : classesWithSerializableLambdas) {
-      clazz.removeDirectMethod(appView.dexItemFactory().deserializeLambdaMethod);
-    }
-  }
-
   private void rewriteLambdaCallSites(
       DexEncodedMethod method, Map<DexCallSite, LambdaClass> callSites) {
     assert !callSites.isEmpty();
@@ -2898,6 +2988,11 @@
           continue;
         }
 
+        synthesize();
+        if (!workList.isEmpty()) {
+          continue;
+        }
+
         // Reached the fixpoint.
         break;
       }
diff --git a/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java b/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java
index a642002..b8256fe 100644
--- a/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java
+++ b/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java
@@ -102,7 +102,7 @@
         .withOptionConsumer(opts -> opts.enableClassInlining = false)
         .withBuilderTransformation(
             b -> b.addProguardConfiguration(PROGUARD_OPTIONS, Origin.unknown()))
-        .withDexCheck(inspector -> checkLambdaCount(inspector, 115, "lambdadesugaring"))
+        .withDexCheck(inspector -> checkLambdaCount(inspector, 116, "lambdadesugaring"))
         .run();
 
     test("lambdadesugaring", "lambdadesugaring", "LambdaDesugaring")
@@ -142,7 +142,7 @@
         .withOptionConsumer(opts -> opts.enableClassInlining = false)
         .withBuilderTransformation(
             b -> b.addProguardConfiguration(PROGUARD_OPTIONS, Origin.unknown()))
-        .withDexCheck(inspector -> checkLambdaCount(inspector, 115, "lambdadesugaring"))
+        .withDexCheck(inspector -> checkLambdaCount(inspector, 116, "lambdadesugaring"))
         .run();
 
     test("lambdadesugaring", "lambdadesugaring", "LambdaDesugaring")
diff --git a/src/test/java/com/android/tools/r8/desugar/DesugarInstanceLambdaWithReadsTest.java b/src/test/java/com/android/tools/r8/desugar/DesugarInstanceLambdaWithReadsTest.java
index 07e4f62..bf507de 100644
--- a/src/test/java/com/android/tools/r8/desugar/DesugarInstanceLambdaWithReadsTest.java
+++ b/src/test/java/com/android/tools/r8/desugar/DesugarInstanceLambdaWithReadsTest.java
@@ -87,8 +87,8 @@
   }
 
   static class Main {
-    // Field that is read from the lambda$ method.
-    A filter;
+    // Field that is read from the lambda$ method (private ensures the method can't be inlined).
+    private A filter;
 
     public Main(A filter) {
       this.filter = filter;
diff --git a/src/test/java/com/android/tools/r8/shaking/LibraryMethodOverrideInLambdaMarkingTest.java b/src/test/java/com/android/tools/r8/shaking/LibraryMethodOverrideInLambdaMarkingTest.java
index e3925ca..9d16769 100644
--- a/src/test/java/com/android/tools/r8/shaking/LibraryMethodOverrideInLambdaMarkingTest.java
+++ b/src/test/java/com/android/tools/r8/shaking/LibraryMethodOverrideInLambdaMarkingTest.java
@@ -63,7 +63,11 @@
     DexEncodedMethod method =
         clazz.lookupVirtualMethod(m -> m.method.name.toString().equals("iterator"));
     // TODO(b/149976493): Mark library overrides from lambda instances.
-    assertTrue(method.isLibraryMethodOverride().isFalse());
+    if (parameters.isCfRuntime()) {
+      assertTrue(method.isLibraryMethodOverride().isFalse());
+    } else {
+      assertTrue(method.isLibraryMethodOverride().isTrue());
+    }
   }
 
   static class TestClass {