Fixup optimization info after lambda merging

Change-Id: I13b2acadf9bf22f23cb59a6197062d4735360aa4
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/type/ArrayTypeLatticeElement.java b/src/main/java/com/android/tools/r8/ir/analysis/type/ArrayTypeLatticeElement.java
index 5ebd0af..e2af564 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/type/ArrayTypeLatticeElement.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/type/ArrayTypeLatticeElement.java
@@ -9,8 +9,8 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexType;
-import com.android.tools.r8.graph.GraphLense;
 import java.util.Objects;
+import java.util.function.Function;
 
 public class ArrayTypeLatticeElement extends ReferenceTypeLatticeElement {
 
@@ -130,11 +130,11 @@
   }
 
   @Override
-  public TypeLatticeElement substitute(
-      GraphLense substituteMap, AppView<? extends AppInfoWithSubtyping> appView) {
+  public ArrayTypeLatticeElement fixupClassTypeReferences(
+      Function<DexType, DexType> mapping, AppView<? extends AppInfoWithSubtyping> appView) {
     if (memberTypeLattice.isReference()) {
       TypeLatticeElement substitutedMemberType =
-          memberTypeLattice.asReferenceTypeLatticeElement().substitute(substituteMap, appView);
+          memberTypeLattice.fixupClassTypeReferences(mapping, appView);
       if (substitutedMemberType != memberTypeLattice) {
         return ArrayTypeLatticeElement.create(substitutedMemberType, nullability);
       }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/type/ClassTypeLatticeElement.java b/src/main/java/com/android/tools/r8/ir/analysis/type/ClassTypeLatticeElement.java
index ebc3d89..8e327aa 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/type/ClassTypeLatticeElement.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/type/ClassTypeLatticeElement.java
@@ -8,7 +8,6 @@
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexClass;
 import com.android.tools.r8.graph.DexType;
-import com.android.tools.r8.graph.GraphLense;
 import com.google.common.collect.ImmutableSet;
 import java.util.ArrayDeque;
 import java.util.Collections;
@@ -18,6 +17,7 @@
 import java.util.Objects;
 import java.util.Queue;
 import java.util.Set;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
 public class ClassTypeLatticeElement extends ReferenceTypeLatticeElement {
@@ -132,11 +132,11 @@
   }
 
   @Override
-  public TypeLatticeElement substitute(
-      GraphLense substituteMap, AppView<? extends AppInfoWithSubtyping> appView) {
-    DexType mappedType = substituteMap.lookupType(type);
+  public ClassTypeLatticeElement fixupClassTypeReferences(
+      Function<DexType, DexType> mapping, AppView<? extends AppInfoWithSubtyping> appView) {
+    DexType mappedType = mapping.apply(type);
     if (mappedType != type) {
-      return fromDexType(mappedType, nullability, appView, false);
+      return create(mappedType, nullability, appView);
     }
     // If the mapped type is not object and no computation of interfaces, we can return early.
     if (mappedType != appView.dexItemFactory().objectType && lazyInterfaces == null) {
@@ -148,7 +148,7 @@
     boolean hasChangedInterfaces = false;
     DexClass interfaceToClassChange = null;
     for (DexType iface : getInterfaces()) {
-      DexType substitutedType = substituteMap.lookupType(iface);
+      DexType substitutedType = mapping.apply(iface);
       if (iface != substitutedType) {
         hasChangedInterfaces = true;
         DexClass mappedClass = appView.definitionFor(substitutedType);
@@ -172,7 +172,7 @@
       } else {
         Set<DexType> newInterfaces = new HashSet<>();
         for (DexType iface : lazyInterfaces) {
-          newInterfaces.add(substituteMap.lookupType(iface));
+          newInterfaces.add(mapping.apply(iface));
         }
         return create(mappedType, nullability, newInterfaces);
       }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/type/ReferenceTypeLatticeElement.java b/src/main/java/com/android/tools/r8/ir/analysis/type/ReferenceTypeLatticeElement.java
index 9825a83..8e42de0 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/type/ReferenceTypeLatticeElement.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/type/ReferenceTypeLatticeElement.java
@@ -3,12 +3,8 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.ir.analysis.type;
 
-import com.android.tools.r8.errors.CompilationError;
 import com.android.tools.r8.errors.Unreachable;
-import com.android.tools.r8.graph.AppInfoWithSubtyping;
-import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexItemFactory;
-import com.android.tools.r8.graph.GraphLense;
 
 public abstract class ReferenceTypeLatticeElement extends TypeLatticeElement {
 
@@ -47,12 +43,6 @@
     }
 
     @Override
-    public TypeLatticeElement substitute(
-        GraphLense substituteMap, AppView<? extends AppInfoWithSubtyping> appView) {
-      throw new CompilationError("Cannot substitute type on NULL reference");
-    }
-
-    @Override
     public boolean equals(Object o) {
       if (this == o) {
         return true;
@@ -108,7 +98,4 @@
   public int hashCode() {
     throw new Unreachable("Should be implemented on each sub type");
   }
-
-  public abstract TypeLatticeElement substitute(
-      GraphLense substituteMap, AppView<? extends AppInfoWithSubtyping> appView);
 }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/type/TypeLatticeElement.java b/src/main/java/com/android/tools/r8/ir/analysis/type/TypeLatticeElement.java
index a9cd745..09a9fad 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/type/TypeLatticeElement.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/type/TypeLatticeElement.java
@@ -10,6 +10,7 @@
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.ir.code.Value;
+import java.util.function.Function;
 
 /**
  * The base abstraction of lattice elements for local type analysis.
@@ -32,6 +33,11 @@
   public static final ReferenceTypeLatticeElement NULL =
       ReferenceTypeLatticeElement.getNullTypeLatticeElement();
 
+  public TypeLatticeElement fixupClassTypeReferences(
+      Function<DexType, DexType> mapping, AppView<? extends AppInfoWithSubtyping> appView) {
+    return this;
+  }
+
   public boolean isNullable() {
     return nullability().isNullable();
   }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
index 8c40d92..28840b2 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
@@ -88,22 +88,14 @@
 
   private Value makeOutValue(Instruction insn, IRCode code) {
     if (insn.outValue() != null) {
-      TypeLatticeElement typeLattice = substitute(insn.outValue().getTypeLattice(), appView);
-      return code.createValue(typeLattice, insn.getLocalInfo());
+      TypeLatticeElement oldType = insn.outValue().getTypeLattice();
+      TypeLatticeElement newType =
+          oldType.fixupClassTypeReferences(appView.graphLense()::lookupType, appView);
+      return code.createValue(newType, insn.getLocalInfo());
     }
     return null;
   }
 
-  private static TypeLatticeElement substitute(
-      TypeLatticeElement latticeElement, AppView<? extends AppInfoWithSubtyping> appView) {
-    if (latticeElement.isReference() && !latticeElement.isNullType()) {
-      return latticeElement
-          .asReferenceTypeLatticeElement()
-          .substitute(appView.graphLense(), appView);
-    }
-    return latticeElement;
-  }
-
   /** Replace type appearances, invoke targets and field accesses with actual definitions. */
   public void rewrite(IRCode code, DexEncodedMethod method) {
     GraphLense graphLense = appView.graphLense();
@@ -410,7 +402,8 @@
         } else if (current.outValue() != null) {
           // For all other instructions, substitute any changed type.
           TypeLatticeElement typeLattice = current.outValue().getTypeLattice();
-          TypeLatticeElement substituted = substitute(typeLattice, appView);
+          TypeLatticeElement substituted =
+              typeLattice.fixupClassTypeReferences(graphLense::lookupType, appView);
           if (substituted != typeLattice) {
             current.outValue().setTypeLattice(substituted);
             affectedPhis.addAll(current.outValue().uniquePhiUsers());
@@ -438,7 +431,7 @@
       assert verifyAllChangedPhisAreScheduled(code, affectedPhis);
       // Assuming all values have been rewritten correctly above, the non-phi operands to phi's are
       // replaced with correct types and all other phi operands are BOTTOM.
-      assert verifyAllPhiOperandsAreBottom(affectedPhis, graphLense);
+      assert verifyAllPhiOperandsAreBottom(affectedPhis);
       worklist.addAll(affectedPhis);
       while (!worklist.isEmpty()) {
         Phi phi = worklist.poll();
@@ -456,7 +449,7 @@
     assert code.hasNoVerticallyMergedClasses(appView);
   }
 
-  private boolean verifyAllPhiOperandsAreBottom(Set<Phi> affectedPhis, GraphLense graphLense) {
+  private boolean verifyAllPhiOperandsAreBottom(Set<Phi> affectedPhis) {
     for (Phi phi : affectedPhis) {
       for (Value operand : phi.getOperands()) {
         if (operand.isPhi()) {
@@ -467,7 +460,7 @@
               || operandType.isPrimitive()
               || operandType.isNullType()
               || (operandType.isReference()
-                  && operandType.asReferenceTypeLatticeElement().substitute(graphLense, appView)
+                  && operandType.fixupClassTypeReferences(appView.graphLense()::lookupType, appView)
                       == operandType);
         }
       }
@@ -481,7 +474,8 @@
       BasicBlock block = blocks.next();
       for (Phi phi : block.getPhis()) {
         TypeLatticeElement phiTypeLattice = phi.getTypeLattice();
-        TypeLatticeElement substituted = substitute(phiTypeLattice, appView);
+        TypeLatticeElement substituted =
+            phiTypeLattice.fixupClassTypeReferences(appView.graphLense()::lookupType, appView);
         assert substituted == phiTypeLattice || affectedPhis.contains(phi);
       }
     }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/info/DefaultMethodOptimizationInfo.java b/src/main/java/com/android/tools/r8/ir/optimize/info/DefaultMethodOptimizationInfo.java
index d119e9e..3279d51 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/info/DefaultMethodOptimizationInfo.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/info/DefaultMethodOptimizationInfo.java
@@ -43,6 +43,21 @@
   private DefaultMethodOptimizationInfo() {}
 
   @Override
+  public boolean isDefaultMethodOptimizationInfo() {
+    return true;
+  }
+
+  @Override
+  public boolean isUpdatableMethodOptimizationInfo() {
+    return false;
+  }
+
+  @Override
+  public UpdatableMethodOptimizationInfo asUpdatableMethodOptimizationInfo() {
+    return null;
+  }
+
+  @Override
   public boolean cannotBeKept() {
     return false;
   }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/info/MethodOptimizationInfo.java b/src/main/java/com/android/tools/r8/ir/optimize/info/MethodOptimizationInfo.java
index a60df56..4777d90 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/info/MethodOptimizationInfo.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/info/MethodOptimizationInfo.java
@@ -22,6 +22,12 @@
     Default
   }
 
+  boolean isDefaultMethodOptimizationInfo();
+
+  boolean isUpdatableMethodOptimizationInfo();
+
+  UpdatableMethodOptimizationInfo asUpdatableMethodOptimizationInfo();
+
   boolean cannotBeKept();
 
   boolean classInitializerMayBePostponed();
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/info/MutableFieldOptimizationInfo.java b/src/main/java/com/android/tools/r8/ir/optimize/info/MutableFieldOptimizationInfo.java
index fff9db4..55ce1a4 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/info/MutableFieldOptimizationInfo.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/info/MutableFieldOptimizationInfo.java
@@ -4,7 +4,11 @@
 
 package com.android.tools.r8.ir.optimize.info;
 
+import com.android.tools.r8.graph.AppInfoWithSubtyping;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
+import java.util.function.Function;
 
 /**
  * Optimization info for fields.
@@ -20,6 +24,13 @@
   private boolean valueHasBeenPropagated = false;
   private TypeLatticeElement dynamicType = null;
 
+  public void fixupClassTypeReferences(
+      Function<DexType, DexType> mapping, AppView<? extends AppInfoWithSubtyping> appView) {
+    if (dynamicType != null) {
+      dynamicType = dynamicType.fixupClassTypeReferences(mapping, appView);
+    }
+  }
+
   @Override
   public MutableFieldOptimizationInfo mutableCopy() {
     MutableFieldOptimizationInfo copy = new MutableFieldOptimizationInfo();
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/info/UpdatableMethodOptimizationInfo.java b/src/main/java/com/android/tools/r8/ir/optimize/info/UpdatableMethodOptimizationInfo.java
index a1c9298..4a34157 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/info/UpdatableMethodOptimizationInfo.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/info/UpdatableMethodOptimizationInfo.java
@@ -7,6 +7,7 @@
 import static com.android.tools.r8.ir.optimize.info.DefaultMethodOptimizationInfo.UNKNOWN_CLASS_TYPE;
 import static com.android.tools.r8.ir.optimize.info.DefaultMethodOptimizationInfo.UNKNOWN_TYPE;
 
+import com.android.tools.r8.graph.AppInfoWithSubtyping;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexEncodedMethod.ClassInlinerEligibility;
 import com.android.tools.r8.graph.DexEncodedMethod.TrivialInitializer;
@@ -17,6 +18,7 @@
 import com.android.tools.r8.ir.optimize.info.ParameterUsagesInfo.ParameterUsage;
 import java.util.BitSet;
 import java.util.Set;
+import java.util.function.Function;
 
 public class UpdatableMethodOptimizationInfo implements MethodOptimizationInfo {
 
@@ -104,6 +106,32 @@
     reachabilitySensitive = template.reachabilitySensitive;
   }
 
+  public void fixupClassTypeReferences(
+      Function<DexType, DexType> mapping, AppView<? extends AppInfoWithSubtyping> appView) {
+    if (returnsObjectOfType != null) {
+      returnsObjectOfType = returnsObjectOfType.fixupClassTypeReferences(mapping, appView);
+    }
+    if (returnsObjectWithLowerBoundType != null) {
+      returnsObjectWithLowerBoundType =
+          returnsObjectWithLowerBoundType.fixupClassTypeReferences(mapping, appView);
+    }
+  }
+
+  @Override
+  public boolean isDefaultMethodOptimizationInfo() {
+    return false;
+  }
+
+  @Override
+  public boolean isUpdatableMethodOptimizationInfo() {
+    return true;
+  }
+
+  @Override
+  public UpdatableMethodOptimizationInfo asUpdatableMethodOptimizationInfo() {
+    return this;
+  }
+
   @Override
   public boolean cannotBeKept() {
     return cannotBeKept;
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/lambda/LambdaMerger.java b/src/main/java/com/android/tools/r8/ir/optimize/lambda/LambdaMerger.java
index cd07285..fa19e3e 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/lambda/LambdaMerger.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/lambda/LambdaMerger.java
@@ -26,6 +26,8 @@
 import com.android.tools.r8.ir.conversion.IRConverter;
 import com.android.tools.r8.ir.optimize.Inliner.ConstraintWithTarget;
 import com.android.tools.r8.ir.optimize.Outliner;
+import com.android.tools.r8.ir.optimize.info.FieldOptimizationInfo;
+import com.android.tools.r8.ir.optimize.info.MethodOptimizationInfo;
 import com.android.tools.r8.ir.optimize.info.OptimizationFeedback;
 import com.android.tools.r8.ir.optimize.lambda.CodeProcessor.Strategy;
 import com.android.tools.r8.ir.optimize.lambda.LambdaGroup.LambdaStructureError;
@@ -48,6 +50,7 @@
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
 import java.util.function.BiFunction;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
 // Merging lambda classes into single lambda group classes. There are three flavors
@@ -217,6 +220,10 @@
     // sequential lambda ids, create group lambda classes.
     Map<LambdaGroup, DexProgramClass> lambdaGroupsClasses = finalizeLambdaGroups();
 
+    // Fixup optimization info to ensure that the optimization info does not refer to any merged
+    // lambdas.
+    new OptimizationInfoFixer(lambdaGroupsClasses).fixupOptimizationInfos(executorService);
+
     // Switch to APPLY strategy.
     this.strategyFactory = ApplyStrategy::new;
 
@@ -462,4 +469,59 @@
       strategy.patch(this, staticGet);
     }
   }
+
+  private final class OptimizationInfoFixer implements Function<DexType, DexType> {
+
+    private final Map<LambdaGroup, DexProgramClass> lambdaGroupsClasses;
+
+    OptimizationInfoFixer(Map<LambdaGroup, DexProgramClass> lambdaGroupsClasses) {
+      this.lambdaGroupsClasses = lambdaGroupsClasses;
+    }
+
+    void fixupOptimizationInfos(ExecutorService executorService) throws ExecutionException {
+      List<Future<?>> futures = new ArrayList<>();
+      for (DexProgramClass clazz : appView.appInfo().classes()) {
+        futures.add(
+            executorService.submit(
+                () -> {
+                  fixupOptimizationInfos(clazz);
+                  return null;
+                }));
+      }
+      ThreadUtils.awaitFutures(futures);
+    }
+
+    private void fixupOptimizationInfos(DexProgramClass clazz) {
+      for (DexEncodedMethod method : clazz.methods()) {
+        MethodOptimizationInfo optimizationInfo = method.getOptimizationInfo();
+        if (optimizationInfo.isUpdatableMethodOptimizationInfo()) {
+          optimizationInfo
+              .asUpdatableMethodOptimizationInfo()
+              .fixupClassTypeReferences(this, appView);
+        } else {
+          assert optimizationInfo.isDefaultMethodOptimizationInfo();
+        }
+      }
+      for (DexEncodedField field : clazz.fields()) {
+        FieldOptimizationInfo optimizationInfo = field.getOptimizationInfo();
+        if (optimizationInfo.isMutableFieldOptimizationInfo()) {
+          optimizationInfo.asMutableFieldOptimizationInfo().fixupClassTypeReferences(this, appView);
+        } else {
+          assert optimizationInfo.isDefaultFieldOptimizationInfo();
+        }
+      }
+    }
+
+    @Override
+    public DexType apply(DexType type) {
+      LambdaGroup group = lambdas.get(type);
+      if (group != null) {
+        DexProgramClass clazz = lambdaGroupsClasses.get(group);
+        if (clazz != null) {
+          return clazz.type;
+        }
+      }
+      return type;
+    }
+  }
 }