Synthesize lambda classes prior to each wave

Bug: 142203515
Change-Id: I4e0f2c38787dbcfb49e1885829cc04d2912751a4
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 556b7c4..57425ab 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -801,8 +801,12 @@
     return builder.build();
   }
 
-  private void waveStart() {
+  private void waveStart(Collection<DexEncodedMethod> wave) {
     onWaveDoneActions = Collections.synchronizedList(new ArrayList<>());
+
+    if (lambdaRewriter != null) {
+      wave.forEach(method -> lambdaRewriter.synthesizeLambdaClassesFor(method, lensCodeRewriter));
+    }
   }
 
   private void waveDone() {
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
index ef5da18..3092293 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
@@ -117,24 +117,8 @@
         if (current.isInvokeCustom()) {
           InvokeCustom invokeCustom = current.asInvokeCustom();
           DexCallSite callSite = invokeCustom.getCallSite();
-          DexProto newMethodProto =
-              factory.applyClassMappingToProto(
-                  callSite.methodProto, graphLense::lookupType, protoFixupCache);
-          DexMethodHandle newBootstrapMethod = rewriteDexMethodHandle(
-              callSite.bootstrapMethod, method, NOT_ARGUMENT_TO_LAMBDA_METAFACTORY);
-          boolean isLambdaMetaFactory =
-              factory.isLambdaMetafactoryMethod(callSite.bootstrapMethod.asMethod());
-          MethodHandleUse methodHandleUse = isLambdaMetaFactory
-              ? ARGUMENT_TO_LAMBDA_METAFACTORY
-              : NOT_ARGUMENT_TO_LAMBDA_METAFACTORY;
-          List<DexValue> newArgs =
-              rewriteBootstrapArgs(callSite.bootstrapArgs, method, methodHandleUse);
-          if (!newMethodProto.equals(callSite.methodProto)
-              || newBootstrapMethod != callSite.bootstrapMethod
-              || !newArgs.equals(callSite.bootstrapArgs)) {
-            DexCallSite newCallSite =
-                factory.createCallSite(
-                    callSite.methodName, newMethodProto, newBootstrapMethod, newArgs);
+          DexCallSite newCallSite = rewriteCallSite(callSite, method);
+          if (newCallSite != callSite) {
             Value newOutValue = makeOutValue(invokeCustom, code);
             InvokeCustom newInvokeCustom =
                 new InvokeCustom(newCallSite, newOutValue, invokeCustom.inValues());
@@ -421,6 +405,28 @@
     assert code.hasNoVerticallyMergedClasses(appView);
   }
 
+  public DexCallSite rewriteCallSite(DexCallSite callSite, DexEncodedMethod context) {
+    DexItemFactory dexItemFactory = appView.dexItemFactory();
+    DexProto newMethodProto =
+        dexItemFactory.applyClassMappingToProto(
+            callSite.methodProto, appView.graphLense()::lookupType, protoFixupCache);
+    DexMethodHandle newBootstrapMethod =
+        rewriteDexMethodHandle(
+            callSite.bootstrapMethod, context, NOT_ARGUMENT_TO_LAMBDA_METAFACTORY);
+    boolean isLambdaMetaFactory =
+        dexItemFactory.isLambdaMetafactoryMethod(callSite.bootstrapMethod.asMethod());
+    MethodHandleUse methodHandleUse =
+        isLambdaMetaFactory ? ARGUMENT_TO_LAMBDA_METAFACTORY : NOT_ARGUMENT_TO_LAMBDA_METAFACTORY;
+    List<DexValue> newArgs = rewriteBootstrapArgs(callSite.bootstrapArgs, context, methodHandleUse);
+    if (!newMethodProto.equals(callSite.methodProto)
+        || newBootstrapMethod != callSite.bootstrapMethod
+        || !newArgs.equals(callSite.bootstrapArgs)) {
+      return dexItemFactory.createCallSite(
+          callSite.methodName, newMethodProto, newBootstrapMethod, newArgs);
+    }
+    return callSite;
+  }
+
   // If the given invoke is on the form "invoke-direct A.<init>, v0, ..." and the definition of
   // value v0 is "new-instance v0, B", where B is a subtype of A (see the Art800 and B116282409
   // tests), then fail with a compilation error if A has previously been merged into B.
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java b/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java
index 195583d..9a55240 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java
@@ -92,7 +92,7 @@
    */
   public <E extends Exception> void forEachMethod(
       ThrowingBiConsumer<DexEncodedMethod, Predicate<DexEncodedMethod>, E> consumer,
-      Action waveStart,
+      Consumer<Collection<DexEncodedMethod>> waveStart,
       Action waveDone,
       ExecutorService executorService)
       throws ExecutionException {
@@ -100,7 +100,7 @@
       Collection<DexEncodedMethod> wave = waves.removeFirst();
       assert wave.size() > 0;
       List<Future<?>> futures = new ArrayList<>();
-      waveStart.execute();
+      waveStart.accept(wave);
       for (DexEncodedMethod method : wave) {
         futures.add(
             executorService.submit(
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java b/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java
index 454085b..9cc26c4 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java
@@ -7,6 +7,8 @@
 import com.android.tools.r8.dex.Constants;
 import com.android.tools.r8.graph.AppInfo;
 import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.Code;
+import com.android.tools.r8.graph.DefaultUseRegistry;
 import com.android.tools.r8.graph.DexApplication.Builder;
 import com.android.tools.r8.graph.DexCallSite;
 import com.android.tools.r8.graph.DexEncodedMethod;
@@ -30,6 +32,7 @@
 import com.android.tools.r8.ir.code.StaticGet;
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.conversion.IRConverter;
+import com.android.tools.r8.ir.conversion.LensCodeRewriter;
 import com.android.tools.r8.utils.DescriptorUtils;
 import com.google.common.collect.BiMap;
 import com.google.common.collect.HashBiMap;
@@ -102,6 +105,36 @@
     this.createInstanceMethodName = factory.createString(LAMBDA_CREATE_INSTANCE_METHOD_NAME);
   }
 
+  public void synthesizeLambdaClassesFor(
+      DexEncodedMethod method, LensCodeRewriter lensCodeRewriter) {
+    if (!method.hasCode() || method.isProcessed()) {
+      // Nothing to desugar.
+      return;
+    }
+
+    Code code = method.getCode();
+    if (!code.isCfCode()) {
+      // Nothing to desugar.
+      return;
+    }
+
+    // Introduce a lambda class in AppInfo for each call site such that we do not modify the
+    // application (and, in particular, the class hierarchy) during wave processing.
+    code.registerCodeReferences(
+        method,
+        new DefaultUseRegistry(appView.dexItemFactory()) {
+
+          @Override
+          public void registerCallSite(DexCallSite callSite) {
+            LambdaDescriptor descriptor =
+                inferLambdaDescriptor(lensCodeRewriter.rewriteCallSite(callSite, method));
+            if (descriptor != LambdaDescriptor.MATCH_FAILED) {
+              getOrCreateLambdaClass(descriptor, method.method.holder);
+            }
+          }
+        });
+  }
+
   /**
    * Detect and desugar lambdas and method references found in the code.
    *