Make lambda merger aware of InitClass instructions

Change-Id: I3b3a13b411eb092e850790c36eb697693a88a228
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/lambda/CodeProcessor.java b/src/main/java/com/android/tools/r8/ir/optimize/lambda/CodeProcessor.java
index f10c3fe..1648958 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/lambda/CodeProcessor.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/lambda/CodeProcessor.java
@@ -18,6 +18,7 @@
 import com.android.tools.r8.ir.code.ConstMethodType;
 import com.android.tools.r8.ir.code.DefaultInstructionVisitor;
 import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.InitClass;
 import com.android.tools.r8.ir.code.InstanceGet;
 import com.android.tools.r8.ir.code.InstancePut;
 import com.android.tools.r8.ir.code.InstructionListIterator;
@@ -62,6 +63,8 @@
 
     boolean isValidNewInstance(CodeProcessor context, NewInstance invoke);
 
+    boolean isValidInitClass(CodeProcessor context, DexType clazz);
+
     void patch(ApplyStrategy context, NewInstance newInstance);
 
     void patch(ApplyStrategy context, InvokeMethod invoke);
@@ -69,6 +72,8 @@
     void patch(ApplyStrategy context, InstanceGet instanceGet);
 
     void patch(ApplyStrategy context, StaticGet staticGet);
+
+    void patch(ApplyStrategy context, InitClass initClass);
   }
 
   // No-op strategy.
@@ -110,6 +115,11 @@
         }
 
         @Override
+        public boolean isValidInitClass(CodeProcessor context, DexType clazz) {
+          return false;
+        }
+
+        @Override
         public void patch(ApplyStrategy context, NewInstance newInstance) {
           throw new Unreachable();
         }
@@ -128,6 +138,11 @@
         public void patch(ApplyStrategy context, StaticGet staticGet) {
           throw new Unreachable();
         }
+
+        @Override
+        public void patch(ApplyStrategy context, InitClass initClass) {
+          throw new Unreachable();
+        }
       };
 
   public final AppView<AppInfoWithLiveness> appView;
@@ -353,6 +368,21 @@
     return null;
   }
 
+  @Override
+  public Void visit(InitClass initClass) {
+    DexType clazz = initClass.getClassValue();
+    Strategy strategy = strategyProvider.apply(clazz);
+    if (strategy.isValidInitClass(this, clazz)) {
+      if (shouldRewrite(clazz)) {
+        // Only rewrite references to lambda classes if we are outside the class.
+        process(strategy, initClass);
+      }
+    } else {
+      lambdaChecker.accept(clazz);
+    }
+    return null;
+  }
+
   abstract void process(Strategy strategy, InvokeMethod invokeMethod);
 
   abstract void process(Strategy strategy, NewInstance newInstance);
@@ -364,4 +394,6 @@
   abstract void process(Strategy strategy, StaticPut staticPut);
 
   abstract void process(Strategy strategy, StaticGet staticGet);
+
+  abstract void process(Strategy strategy, InitClass initClass);
 }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/lambda/LambdaMerger.java b/src/main/java/com/android/tools/r8/ir/optimize/lambda/LambdaMerger.java
index 176db10..510c122 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/lambda/LambdaMerger.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/lambda/LambdaMerger.java
@@ -20,6 +20,7 @@
 import com.android.tools.r8.graph.classmerging.HorizontallyMergedLambdaClasses;
 import com.android.tools.r8.ir.analysis.type.DestructivePhiTypeUpdater;
 import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.InitClass;
 import com.android.tools.r8.ir.code.InstanceGet;
 import com.android.tools.r8.ir.code.InstancePut;
 import com.android.tools.r8.ir.code.Instruction;
@@ -529,6 +530,11 @@
     void process(Strategy strategy, StaticGet staticGet) {
       queueForProcessing(method);
     }
+
+    @Override
+    void process(Strategy strategy, InitClass initClass) {
+      queueForProcessing(method);
+    }
   }
 
   public final class ApplyStrategy extends CodeProcessor {
@@ -644,6 +650,11 @@
     void process(Strategy strategy, StaticGet staticGet) {
       strategy.patch(this, staticGet);
     }
+
+    @Override
+    void process(Strategy strategy, InitClass initClass) {
+      strategy.patch(this, initClass);
+    }
   }
 
   private final class LambdaMergerOptimizationInfoFixer
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/lambda/kotlin/KotlinLambdaGroupCodeStrategy.java b/src/main/java/com/android/tools/r8/ir/optimize/lambda/kotlin/KotlinLambdaGroupCodeStrategy.java
index e54d241..81480d6 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/lambda/kotlin/KotlinLambdaGroupCodeStrategy.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/lambda/kotlin/KotlinLambdaGroupCodeStrategy.java
@@ -15,6 +15,7 @@
 import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
 import com.android.tools.r8.ir.code.CheckCast;
 import com.android.tools.r8.ir.code.ConstNumber;
+import com.android.tools.r8.ir.code.InitClass;
 import com.android.tools.r8.ir.code.InstanceGet;
 import com.android.tools.r8.ir.code.Instruction;
 import com.android.tools.r8.ir.code.InvokeDirect;
@@ -50,10 +51,10 @@
     assert group.containsLambda(lambda);
     // Only support writes to singleton static field named 'INSTANCE' from lambda
     // static class initializer.
-    return field.name == context.kotlin.functional.kotlinStyleLambdaInstanceName &&
-        lambda == field.type &&
-        context.factory.isClassConstructor(context.method.method) &&
-        context.method.method.holder == lambda;
+    return field.name == context.kotlin.functional.kotlinStyleLambdaInstanceName
+        && lambda == field.type
+        && context.factory.isClassConstructor(context.method.method)
+        && context.method.method.holder == lambda;
   }
 
   @Override
@@ -61,8 +62,8 @@
     DexType lambda = field.holder;
     assert group.containsLambda(lambda);
     // Support all reads of singleton static field named 'INSTANCE'.
-    return field.name == context.kotlin.functional.kotlinStyleLambdaInstanceName &&
-        lambda == field.type;
+    return field.name == context.kotlin.functional.kotlinStyleLambdaInstanceName
+        && lambda == field.type;
   }
 
   @Override
@@ -112,6 +113,13 @@
   }
 
   @Override
+  public boolean isValidInitClass(CodeProcessor context, DexType clazz) {
+    assert group.containsLambda(clazz);
+    // Support all init class instructions.
+    return true;
+  }
+
+  @Override
   public void patch(ApplyStrategy context, NewInstance newInstance) {
     DexType oldType = newInstance.clazz;
     DexType newType = group.getGroupClassType();
@@ -202,6 +210,14 @@
     context.recordTypeHasChanged(patchedStaticGet.outValue());
   }
 
+  @Override
+  public void patch(ApplyStrategy context, InitClass initClass) {
+    InitClass pachedInitClass =
+        new InitClass(
+            context.code.createValue(TypeLatticeElement.getInt()), group.getGroupClassType());
+    context.instructions().replaceCurrentInstruction(pachedInitClass);
+  }
+
   private void patchInitializer(CodeProcessor context, InvokeDirect invoke) {
     // Patching includes:
     //  - change of methods