Move desugaring into LensRewriting.

Desugaring during LensRewriting needs the
enableStatefulLambdaInstanceMethod testing flag to be enabled which
is enabled only for 2 tests
([R8]RunExamplesAndroidOTest.lambdaDesugaringCreateMethod).
Additionally, this CL has been tested locally for the entire test suit
in which case only the ClassInlinerTest fails due to desugared
methods not inlined.

Change-Id: I9341a62458621842f0a27ef46355f8de02803316
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 6755f46..9b1b30e 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -199,11 +199,11 @@
               : null;
       this.classStaticizer =
           options.enableClassStaticizer ? new ClassStaticizer(appViewWithLiveness, this) : null;
-      this.inliner = new Inliner(appViewWithLiveness, mainDexClasses);
+      this.lensCodeRewriter = new LensCodeRewriter(appViewWithLiveness, lambdaRewriter);
+      this.inliner = new Inliner(appViewWithLiveness, mainDexClasses, lensCodeRewriter);
       this.outliner = new Outliner(appViewWithLiveness, this);
       this.memberValuePropagation =
           options.enableValuePropagation ? new MemberValuePropagation(appViewWithLiveness) : null;
-      this.lensCodeRewriter = new LensCodeRewriter(appViewWithLiveness);
       if (!appInfoWithLiveness.identifierNameStrings.isEmpty() && options.isMinifying()) {
         this.identifierNameStringMarker = new IdentifierNameStringMarker(appViewWithLiveness);
       } else {
@@ -859,6 +859,10 @@
         lensCodeRewriter.rewrite(code, method);
       } else {
         assert appView.graphLense().isIdentityLense();
+        if (lambdaRewriter != null && options.testing.desugarLambdasThroughLensCodeRewriter()) {
+          lambdaRewriter.desugarLambdas(method, code);
+          assert code.isConsistentSSA();
+        }
       }
     }
 
@@ -1000,11 +1004,12 @@
 
     stringConcatRewriter.desugarStringConcats(method.method, code);
 
-    if (lambdaRewriter != null) {
+    if (options.testing.desugarLambdasThroughLensCodeRewriter()) {
+      assert !options.enableDesugaring || lambdaRewriter.verifyNoLambdasToDesugar(code);
+    } else if (lambdaRewriter != null) {
       lambdaRewriter.desugarLambdas(method, code);
       assert code.isConsistentSSA();
     }
-
     previous = printMethod(code, "IR after lambda desugaring (SSA)", previous);
 
     assert code.verifyTypes(appView);
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
index 4f05aae..7650fe7 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
@@ -56,6 +56,7 @@
 import com.android.tools.r8.ir.code.StaticGet;
 import com.android.tools.r8.ir.code.StaticPut;
 import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.ir.desugar.LambdaRewriter;
 import com.android.tools.r8.logging.Log;
 import com.android.tools.r8.shaking.VerticalClassMerger.VerticallyMergedClasses;
 import com.google.common.collect.Sets;
@@ -72,9 +73,12 @@
   private final AppView<? extends AppInfoWithSubtyping> appView;
 
   private final Map<DexProto, DexProto> protoFixupCache = new ConcurrentHashMap<>();
+  private final LambdaRewriter lambdaRewriter;
 
-  public LensCodeRewriter(AppView<? extends AppInfoWithSubtyping> appView) {
+  public LensCodeRewriter(
+      AppView<? extends AppInfoWithSubtyping> appView, LambdaRewriter lambdaRewriter) {
     this.appView = appView;
+    this.lambdaRewriter = lambdaRewriter;
   }
 
   private Value makeOutValue(Instruction insn, IRCode code, Set<Value> collector) {
@@ -87,9 +91,7 @@
     }
   }
 
-  /**
-   * Replace type appearances, invoke targets and field accesses with actual definitions.
-   */
+  /** Replace type appearances, invoke targets and field accesses with actual definitions. */
   public void rewrite(IRCode code, DexEncodedMethod method) {
     GraphLense graphLense = appView.graphLense();
 
@@ -138,6 +140,13 @@
                 invokeCustom.inValues());
             iterator.replaceCurrentInstruction(newInvokeCustom);
           }
+          if (lambdaRewriter != null
+              && appView.options().testing.desugarLambdasThroughLensCodeRewriter()) {
+            Instruction previous = iterator.peekPrevious();
+            assert previous.isInvokeCustom();
+            lambdaRewriter.desugarLambda(
+                method.method.holder, iterator, previous.asInvokeCustom(), code);
+          }
         } else if (current.isConstMethodHandle()) {
           DexMethodHandle handle = current.asConstMethodHandle().getValue();
           DexMethodHandle newHandle = rewriteDexMethodHandle(
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/LambdaClass.java b/src/main/java/com/android/tools/r8/ir/desugar/LambdaClass.java
index b5a5f0c..5be5a81 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/LambdaClass.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/LambdaClass.java
@@ -64,7 +64,7 @@
   final DexMethod constructor;
   final DexMethod classConstructor;
   final DexMethod createInstanceMethod;
-  final DexField instanceField;
+  final DexField lambdaField;
   final Target target;
   final AtomicBoolean addToMainDexList = new AtomicBoolean(false);
   private final Collection<DexProgramClass> synthesizedFrom = new ArrayList<>(1);
@@ -95,8 +95,10 @@
     boolean stateless = isStateless();
     this.classConstructor = !stateless ? null
         : factory.createMethod(lambdaClassType, constructorProto, rewriter.classConstructorName);
-    this.instanceField = !stateless ? null
-        : factory.createField(lambdaClassType, lambdaClassType, rewriter.instanceFieldName);
+    this.lambdaField =
+        !stateless
+            ? null
+            : factory.createField(lambdaClassType, lambdaClassType, rewriter.instanceFieldName);
     this.createInstanceMethod =
         stateless
             ? null
@@ -309,11 +311,11 @@
     }
 
     // Create instance field for stateless lambda.
-    assert this.instanceField != null;
+    assert this.lambdaField != null;
     DexEncodedField[] fields = new DexEncodedField[1];
     fields[0] =
         new DexEncodedField(
-            this.instanceField,
+            this.lambdaField,
             FieldAccessFlags.fromSharedAccessFlags(
                 Constants.ACC_PUBLIC
                     | Constants.ACC_FINAL
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/LambdaClassConstructorSourceCode.java b/src/main/java/com/android/tools/r8/ir/desugar/LambdaClassConstructorSourceCode.java
index cec70bb..6b85039 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/LambdaClassConstructorSourceCode.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/LambdaClassConstructorSourceCode.java
@@ -16,7 +16,7 @@
 
   LambdaClassConstructorSourceCode(LambdaClass lambda, Position callerPosition) {
     super(lambda, lambda.classConstructor, callerPosition, null /* Class initializer is static */);
-    assert lambda.instanceField != null;
+    assert lambda.lambdaField != null;
   }
 
   @Override
@@ -35,7 +35,7 @@
                 false /* isInterface */));
 
     // Assign to a field.
-    add(builder -> builder.addStaticPut(instance, lambda.instanceField));
+    add(builder -> builder.addStaticPut(instance, lambda.lambdaField));
 
     // Final return.
     add(IRBuilder::addReturn);
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java b/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java
index 3356d97..c265016 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java
@@ -102,7 +102,7 @@
   /**
    * Detect and desugar lambdas and method references found in the code.
    *
-   * NOTE: this method can be called concurrently for several different methods.
+   * <p>NOTE: this method can be called concurrently for several different methods.
    */
   public void desugarLambdas(DexEncodedMethod encodedMethod, IRCode code) {
     DexType currentType = encodedMethod.method.holder;
@@ -132,6 +132,34 @@
     }
   }
 
+  public void desugarLambda(
+      DexType currentType,
+      InstructionListIterator instructions,
+      InvokeCustom lenseRewrittenInvokeCustom,
+      IRCode code) {
+    LambdaDescriptor descriptor = inferLambdaDescriptor(lenseRewrittenInvokeCustom.getCallSite());
+    if (descriptor == LambdaDescriptor.MATCH_FAILED) {
+      return;
+    }
+
+    // We have a descriptor, get or create lambda class.
+    LambdaClass lambdaClass = getOrCreateLambdaClass(descriptor, currentType);
+    assert lambdaClass != null;
+
+    // We rely on patch performing its work in a way which
+    // keeps `instructions` iterator in valid state so that we can continue iteration.
+    patchInstructionSimple(lambdaClass, code, instructions, lenseRewrittenInvokeCustom);
+  }
+
+  public boolean verifyNoLambdasToDesugar(IRCode code) {
+    for (Instruction instruction : code.instructions()) {
+      assert !instruction.isInvokeCustom()
+          || inferLambdaDescriptor(instruction.asInvokeCustom().getCallSite())
+              == LambdaDescriptor.MATCH_FAILED;
+    }
+    return true;
+  }
+
   /** Remove lambda deserialization methods. */
   public boolean removeLambdaDeserializationMethods(Iterable<DexProgramClass> classes) {
     for (DexProgramClass clazz : classes) {
@@ -288,7 +316,7 @@
     // reading the value of INSTANCE field created for singleton lambda class.
     if (lambdaClass.isStateless()) {
       instructions.replaceCurrentInstruction(
-          new StaticGet(lambdaInstanceValue, lambdaClass.instanceField));
+          new StaticGet(lambdaInstanceValue, lambdaClass.lambdaField));
       // Note that since we replace one throwing operation with another we don't need
       // to have any special handling for catch handlers.
       return;
@@ -347,4 +375,50 @@
       instructions.replaceCurrentInstruction(invokeStatic);
     }
   }
+
+  // Patches invoke-custom instruction to create or get an instance
+  // of the generated lambda class. Assumes that for stateful lambdas the createInstance method
+  // is enabled so invokeCustom is always replaced by a single instruction.
+  private void patchInstructionSimple(
+      LambdaClass lambdaClass,
+      IRCode code,
+      InstructionListIterator instructions,
+      InvokeCustom invoke) {
+    assert lambdaClass != null;
+    assert instructions != null;
+
+    // The value representing new lambda instance: we reuse the value from the original
+    // invoke-custom instruction, and thus all its usages.
+    Value lambdaInstanceValue = invoke.outValue();
+    if (lambdaInstanceValue == null) {
+      // The out value might be empty in case it was optimized out.
+      lambdaInstanceValue =
+          code.createValue(
+              TypeLatticeElement.fromDexType(lambdaClass.type, Nullability.maybeNull(), appView));
+    }
+
+    // For stateless lambdas we replace InvokeCustom instruction with StaticGet reading the value of
+    // INSTANCE field created for singleton lambda class.
+    if (lambdaClass.isStateless()) {
+      instructions.replaceCurrentInstruction(
+          new StaticGet(lambdaInstanceValue, lambdaClass.lambdaField));
+      // Note that since we replace one throwing operation with another we don't need
+      // to have any special handling for catch handlers.
+      return;
+    }
+
+    assert appView.options().testing.enableStatefulLambdaCreateInstanceMethod;
+    // For stateful lambdas we call the createInstance method.
+    //
+    //    original:
+    //      Invoke-Custom rResult <- { rArg0, rArg1, ... }; call site: ...
+    //
+    //    result:
+    //      Invoke-Static rResult <- { rArg0, rArg1, ... }; method void
+    // LambdaClass.createInstance(...)
+    InvokeStatic invokeStatic =
+        new InvokeStatic(
+            lambdaClass.getCreateInstanceMethod(), lambdaInstanceValue, invoke.arguments());
+    instructions.replaceCurrentInstruction(invokeStatic);
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java b/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
index f1bce2d..3bb5c95 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
@@ -62,10 +62,15 @@
   private final Map<DexEncodedMethod, DexEncodedMethod> doubleInlineeCandidates = new HashMap<>();
 
   private final Set<DexMethod> blackList = Sets.newIdentityHashSet();
+  private final LensCodeRewriter lensCodeRewriter;
 
-  public Inliner(AppView<AppInfoWithLiveness> appView, MainDexClasses mainDexClasses) {
+  public Inliner(
+      AppView<AppInfoWithLiveness> appView,
+      MainDexClasses mainDexClasses,
+      LensCodeRewriter lensCodeRewriter) {
     this.appView = appView;
     this.mainDexClasses = mainDexClasses;
+    this.lensCodeRewriter = lensCodeRewriter;
     fillInBlackList();
   }
 
@@ -436,7 +441,8 @@
         DexEncodedMethod context,
         ValueNumberGenerator generator,
         AppView<? extends AppInfoWithSubtyping> appView,
-        Position callerPosition) {
+        Position callerPosition,
+        LensCodeRewriter lensCodeRewriter) {
       Origin origin = appView.appInfo().originFor(target.method.holder);
 
       IRCode code;
@@ -483,7 +489,7 @@
           }
         }
         if (!target.isProcessed()) {
-          new LensCodeRewriter(appView).rewrite(code, target);
+          lensCodeRewriter.rewrite(code, target);
         }
       }
       return new InlineeWithReason(code, reason);
@@ -702,7 +708,8 @@
                     == appView.graphLense().getOriginalMethodSignature(context.method);
 
             InlineeWithReason inlinee =
-                result.buildInliningIR(context, code.valueNumberGenerator, appView, invokePosition);
+                result.buildInliningIR(
+                    context, code.valueNumberGenerator, appView, invokePosition, lensCodeRewriter);
             if (inlinee != null) {
               if (strategy.willExceedBudget(inlinee, block)) {
                 continue;
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index b451490..ded6688 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -631,6 +631,10 @@
     // TODO(b/129458850) When fixed, remove this and change all usages to "true".
     public boolean enableStatefulLambdaCreateInstanceMethod = false;
 
+    public boolean desugarLambdasThroughLensCodeRewriter() {
+      return enableStatefulLambdaCreateInstanceMethod;
+    }
+
     public MinifierTestingOptions minifier = new MinifierTestingOptions();
 
     public static class MinifierTestingOptions {
diff --git a/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java b/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java
index fb69657..21baee3 100644
--- a/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java
+++ b/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java
@@ -118,6 +118,36 @@
         .run();
   }
 
+  @Override
+  @Test
+  public void lambdaDesugaringCreateMethod() throws Throwable {
+    test("lambdadesugaring", "lambdadesugaring", "LambdaDesugaring")
+        .withMinApiLevel(ToolHelper.getMinApiLevelForDexVmNoHigherThan(AndroidApiLevel.K))
+        .withOptionConsumer(
+            opts -> {
+              opts.enableClassInlining = false;
+              opts.testing.enableStatefulLambdaCreateInstanceMethod = true;
+            })
+        .withBuilderTransformation(
+            b -> b.addProguardConfiguration(PROGUARD_OPTIONS, Origin.unknown()))
+        .withDexCheck(inspector -> checkLambdaCount(inspector, 180, "lambdadesugaring"))
+        .run();
+
+    test("lambdadesugaring", "lambdadesugaring", "LambdaDesugaring")
+        .withMinApiLevel(ToolHelper.getMinApiLevelForDexVmNoHigherThan(AndroidApiLevel.K))
+        .withOptionConsumer(
+            opts -> {
+              opts.enableClassInlining = true;
+              opts.testing.enableStatefulLambdaCreateInstanceMethod = true;
+            })
+        .withBuilderTransformation(
+            b -> b.addProguardConfiguration(PROGUARD_OPTIONS, Origin.unknown()))
+        // TODO(b/120814598): Should be 24. Some lambdas are not class inlined because parameter
+        // usages for lambda methods are not present for the class inliner.
+        .withDexCheck(inspector -> checkLambdaCount(inspector, 46, "lambdadesugaring"))
+        .run();
+  }
+
   @Test
   @IgnoreIfVmOlderThan(Version.V7_0_0)
   public void lambdaDesugaringWithDefaultMethods() throws Throwable {