Reset LensCodeRewriterUtils when graphLens changes

Fixes a bug where the LensCodeRewriterUtils of the
LensCodeRewriter was never refreshed, leading to invalid
rewritting.

Bug: 204174455
Change-Id: I4b1fd564e7f5b8aba9cce03fb757fba0cab6e763
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
index cd298e0..4695e9c 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
@@ -109,16 +109,24 @@
 
   private final AppView<? extends AppInfoWithClassHierarchy> appView;
   private final EnumUnboxer enumUnboxer;
-  private final LensCodeRewriterUtils helper;
+  private LensCodeRewriterUtils helper = null;
   private final InternalOptions options;
 
   LensCodeRewriter(AppView<? extends AppInfoWithClassHierarchy> appView, EnumUnboxer enumUnboxer) {
     this.appView = appView;
     this.enumUnboxer = enumUnboxer;
-    this.helper = new LensCodeRewriterUtils(appView);
     this.options = appView.options();
   }
 
+  private LensCodeRewriterUtils getHelper() {
+    // The LensCodeRewriterUtils uses internal caches that are not valid if the graphLens changes.
+    if (helper != null && helper.hasGraphLens(appView.graphLens())) {
+      return helper;
+    }
+    helper = new LensCodeRewriterUtils(appView);
+    return helper;
+  }
+
   private Value makeOutValue(Instruction insn, IRCode code) {
     if (insn.outValue() != null) {
       TypeElement oldType = insn.getOutType();
@@ -163,7 +171,7 @@
             {
               InvokeCustom invokeCustom = current.asInvokeCustom();
               DexCallSite callSite = invokeCustom.getCallSite();
-              DexCallSite newCallSite = helper.rewriteCallSite(callSite, method);
+              DexCallSite newCallSite = getHelper().rewriteCallSite(callSite, method);
               if (newCallSite != callSite) {
                 Value newOutValue = makeOutValue(invokeCustom, code);
                 InvokeCustom newInvokeCustom =
@@ -180,7 +188,8 @@
             {
               DexMethodHandle handle = current.asConstMethodHandle().getValue();
               DexMethodHandle newHandle =
-                  helper.rewriteDexMethodHandle(handle, NOT_ARGUMENT_TO_LAMBDA_METAFACTORY, method);
+                  getHelper()
+                      .rewriteDexMethodHandle(handle, NOT_ARGUMENT_TO_LAMBDA_METAFACTORY, method);
               if (newHandle != handle) {
                 Value newOutValue = makeOutValue(current, code);
                 iterator.replaceCurrentInstruction(new ConstMethodHandle(newOutValue, newHandle));
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriterUtils.java b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriterUtils.java
index b8a70f9..0353705 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriterUtils.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriterUtils.java
@@ -33,7 +33,6 @@
 
 public class LensCodeRewriterUtils {
 
-  private final AppView<?> appView;
   private final DexDefinitionSupplier definitions;
   private final GraphLens graphLens;
 
@@ -44,27 +43,21 @@
   private final Map<DexCallSite, DexCallSite> rewrittenCallSiteCache;
 
   public LensCodeRewriterUtils(AppView<?> appView) {
-    this(appView, false);
+    this(appView, appView.graphLens());
   }
 
   public LensCodeRewriterUtils(AppView<?> appView, boolean enableCallSiteCaching) {
-    this.appView = appView;
     this.definitions = appView;
-    this.graphLens = null;
+    this.graphLens = appView.graphLens();
     this.rewrittenCallSiteCache = enableCallSiteCaching ? new ConcurrentHashMap<>() : null;
   }
 
   public LensCodeRewriterUtils(DexDefinitionSupplier definitions, GraphLens graphLens) {
-    this.appView = null;
     this.definitions = definitions;
     this.graphLens = graphLens;
     this.rewrittenCallSiteCache = null;
   }
 
-  private GraphLens graphLens() {
-    return appView != null ? appView.graphLens() : graphLens;
-  }
-
   public DexCallSite rewriteCallSite(DexCallSite callSite, ProgramMethod context) {
     if (rewrittenCallSiteCache == null) {
       return rewriteCallSiteInternal(callSite, context);
@@ -100,7 +93,7 @@
       DexMethod invokedMethod = methodHandle.asMethod();
       MethodHandleType oldType = methodHandle.type;
       MethodLookupResult lensLookup =
-          graphLens().lookupMethod(invokedMethod, context.getReference(), oldType.toInvokeType());
+          graphLens.lookupMethod(invokedMethod, context.getReference(), oldType.toInvokeType());
       DexMethod rewrittenTarget = lensLookup.getReference();
       DexMethod actualTarget;
       MethodHandleType newType;
@@ -120,7 +113,7 @@
             definitions
                 .dexItemFactory()
                 .createMethod(
-                    graphLens().lookupType(invokedMethod.holder),
+                    graphLens.lookupType(invokedMethod.holder),
                     rewrittenTarget.proto,
                     rewrittenTarget.name);
         newType = oldType;
@@ -147,7 +140,7 @@
       }
     } else {
       DexField field = methodHandle.asField();
-      DexField actualField = graphLens().lookupField(field);
+      DexField actualField = graphLens.lookupField(field);
       if (actualField != field) {
         return definitions
             .dexItemFactory()
@@ -192,7 +185,7 @@
         return rewriteDexMethodType(value.asDexValueMethodType());
       case TYPE:
         DexType oldType = value.asDexValueType().value;
-        DexType newType = graphLens().lookupType(oldType);
+        DexType newType = graphLens.lookupType(oldType);
         return newType != oldType ? new DexValueType(newType) : value;
       default:
         return value;
@@ -202,7 +195,7 @@
   public DexProto rewriteProto(DexProto proto) {
     return definitions
         .dexItemFactory()
-        .applyClassMappingToProto(proto, graphLens()::lookupType, protoFixupCache);
+        .applyClassMappingToProto(proto, graphLens::lookupType, protoFixupCache);
   }
 
   private DexValueMethodHandle rewriteDexValueMethodHandle(
@@ -211,4 +204,8 @@
     DexMethodHandle newHandle = rewriteDexMethodHandle(oldHandle, use, context);
     return newHandle != oldHandle ? new DexValueMethodHandle(newHandle) : methodHandle;
   }
+
+  public boolean hasGraphLens(GraphLens graphLens) {
+    return this.graphLens == graphLens;
+  }
 }