Introduce NestedGraphLensWithCustomCodeRewriting

- Avoid specific code for enum unboxing

Bug: b/307872552
Change-Id: I03959213e4f789bab795e22ce6e075b77fd722bd
diff --git a/src/main/java/com/android/tools/r8/graph/lens/GraphLens.java b/src/main/java/com/android/tools/r8/graph/lens/GraphLens.java
index b34bf05..8b5b5d3 100644
--- a/src/main/java/com/android/tools/r8/graph/lens/GraphLens.java
+++ b/src/main/java/com/android/tools/r8/graph/lens/GraphLens.java
@@ -23,6 +23,7 @@
 import com.android.tools.r8.graph.proto.RewrittenPrototypeDescription;
 import com.android.tools.r8.ir.code.InvokeType;
 import com.android.tools.r8.ir.conversion.LensCodeRewriterUtils;
+import com.android.tools.r8.ir.optimize.CustomLensCodeRewriter;
 import com.android.tools.r8.ir.optimize.enums.EnumUnboxingLens;
 import com.android.tools.r8.optimize.MemberRebindingIdentityLens;
 import com.android.tools.r8.optimize.MemberRebindingLens;
@@ -388,6 +389,11 @@
     return false;
   }
 
+  public CustomLensCodeRewriter getCustomCodeRewriting() {
+    assert hasCustomCodeRewritings();
+    return null;
+  }
+
   public boolean isAppliedLens() {
     return false;
   }
diff --git a/src/main/java/com/android/tools/r8/graph/lens/NestedGraphLensWithCustomCodeRewriting.java b/src/main/java/com/android/tools/r8/graph/lens/NestedGraphLensWithCustomCodeRewriting.java
new file mode 100644
index 0000000..93cc321
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/graph/lens/NestedGraphLensWithCustomCodeRewriting.java
@@ -0,0 +1,50 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.graph.lens;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.ir.optimize.CustomLensCodeRewriter;
+import com.android.tools.r8.utils.collections.BidirectionalManyToManyRepresentativeMap;
+import com.android.tools.r8.utils.collections.BidirectionalManyToOneRepresentativeMap;
+import java.util.Map;
+
+public class NestedGraphLensWithCustomCodeRewriting extends NestedGraphLens {
+
+  private final CustomLensCodeRewriter customLensCodeRewriter;
+
+  public NestedGraphLensWithCustomCodeRewriting(
+      AppView<?> appView,
+      BidirectionalManyToOneRepresentativeMap<DexField, DexField> fieldMap,
+      BidirectionalManyToOneRepresentativeMap<DexMethod, DexMethod> methodMap,
+      BidirectionalManyToManyRepresentativeMap<DexType, DexType> typeMap,
+      CustomLensCodeRewriter customLensCodeRewriter) {
+    super(appView, fieldMap, methodMap, typeMap);
+    this.customLensCodeRewriter = customLensCodeRewriter;
+  }
+
+  public NestedGraphLensWithCustomCodeRewriting(
+      AppView<?> appView,
+      BidirectionalManyToOneRepresentativeMap<DexField, DexField> fieldMap,
+      Map<DexMethod, DexMethod> methodMap,
+      BidirectionalManyToManyRepresentativeMap<DexType, DexType> typeMap,
+      BidirectionalManyToManyRepresentativeMap<DexMethod, DexMethod> newMethodSignatures,
+      CustomLensCodeRewriter customLensCodeRewriter) {
+    super(appView, fieldMap, methodMap, typeMap, newMethodSignatures);
+    this.customLensCodeRewriter = customLensCodeRewriter;
+  }
+
+  @Override
+  public boolean hasCustomCodeRewritings() {
+    return true;
+  }
+
+  @Override
+  public CustomLensCodeRewriter getCustomCodeRewriting() {
+    return customLensCodeRewriter;
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 0e13017..2a10e4c 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -259,7 +259,7 @@
               : null;
       this.enumUnboxer = EnumUnboxer.create(appViewWithLiveness);
       this.numberUnboxer = NumberUnboxer.create(appViewWithLiveness);
-      this.lensCodeRewriter = new LensCodeRewriter(appViewWithLiveness, enumUnboxer);
+      this.lensCodeRewriter = new LensCodeRewriter(appViewWithLiveness);
       this.inliner = new Inliner(appViewWithLiveness, this, lensCodeRewriter);
       this.outliner = Outliner.create(appViewWithLiveness);
       this.memberValuePropagation = new R8MemberValuePropagation(appViewWithLiveness);
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
index 9cb88c3..bd23bfd 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
@@ -107,7 +107,7 @@
 import com.android.tools.r8.ir.code.ValueType;
 import com.android.tools.r8.ir.conversion.passes.TrivialPhiSimplifier;
 import com.android.tools.r8.ir.optimize.AffectedValues;
-import com.android.tools.r8.ir.optimize.enums.EnumUnboxer;
+import com.android.tools.r8.ir.optimize.CustomLensCodeRewriter;
 import com.android.tools.r8.optimize.MemberRebindingAnalysis;
 import com.android.tools.r8.optimize.argumentpropagation.lenscoderewriter.NullCheckInserter;
 import com.android.tools.r8.utils.ArrayUtils;
@@ -159,13 +159,11 @@
 
   private final AppView<? extends AppInfoWithClassHierarchy> appView;
   private final DexItemFactory factory;
-  private final EnumUnboxer enumUnboxer;
   private final InternalOptions options;
 
-  LensCodeRewriter(AppView<? extends AppInfoWithClassHierarchy> appView, EnumUnboxer enumUnboxer) {
+  LensCodeRewriter(AppView<? extends AppInfoWithClassHierarchy> appView) {
     this.appView = appView;
     this.factory = appView.dexItemFactory();
-    this.enumUnboxer = enumUnboxer;
     this.options = appView.options();
   }
 
@@ -228,9 +226,10 @@
     rewriteArguments(
         code, originalMethodReference, prototypeChanges, affectedPhis, unusedArguments);
     if (graphLens.hasCustomCodeRewritings()) {
-      assert graphLens.isEnumUnboxerLens();
       assert graphLens.getPrevious() == codeLens;
-      affectedPhis.addAll(enumUnboxer.rewriteCode(code, methodProcessor, prototypeChanges));
+      CustomLensCodeRewriter customLensCodeRewriter = graphLens.getCustomCodeRewriting();
+      affectedPhis.addAll(
+          customLensCodeRewriter.rewriteCode(code, methodProcessor, prototypeChanges, graphLens));
     }
     if (!unusedArguments.isEmpty()) {
       for (UnusedArgument unusedArgument : unusedArguments) {
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java
index 6b6f2ad..8b4cbce 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java
@@ -212,8 +212,6 @@
 
     appView.clearMethodResolutionOptimizationInfoCollection();
 
-    enumUnboxer.unsetRewriter();
-
     // All the code that should be impacted by the lenses inserted between phase 1 and phase 2
     // have now been processed and rewritten, we clear code lens rewriting so that the class
     // staticizer and phase 3 does not perform again the rewriting.
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CustomLensCodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CustomLensCodeRewriter.java
new file mode 100644
index 0000000..58827ef
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CustomLensCodeRewriter.java
@@ -0,0 +1,21 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize;
+
+import com.android.tools.r8.graph.lens.NonIdentityGraphLens;
+import com.android.tools.r8.graph.proto.RewrittenPrototypeDescription;
+import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.Phi;
+import com.android.tools.r8.ir.conversion.MethodProcessor;
+import java.util.Set;
+
+public interface CustomLensCodeRewriter {
+
+  Set<Phi> rewriteCode(
+      IRCode code,
+      MethodProcessor methodProcessor,
+      RewrittenPrototypeDescription prototypeChanges,
+      NonIdentityGraphLens lens);
+}
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EmptyEnumUnboxer.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EmptyEnumUnboxer.java
index cc90f9b..4e07fba 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EmptyEnumUnboxer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EmptyEnumUnboxer.java
@@ -8,18 +8,14 @@
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.graph.lens.GraphLens;
-import com.android.tools.r8.graph.proto.RewrittenPrototypeDescription;
 import com.android.tools.r8.ir.analysis.fieldvalueanalysis.StaticFieldValues;
 import com.android.tools.r8.ir.code.IRCode;
-import com.android.tools.r8.ir.code.Phi;
 import com.android.tools.r8.ir.conversion.IRConverter;
 import com.android.tools.r8.ir.conversion.MethodProcessor;
 import com.android.tools.r8.ir.conversion.PostMethodProcessor.Builder;
 import com.android.tools.r8.ir.optimize.info.OptimizationFeedbackDelayed;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.Timing;
-import java.util.Collections;
-import java.util.Set;
 import java.util.concurrent.ExecutorService;
 
 public class EmptyEnumUnboxer extends EnumUnboxer {
@@ -58,14 +54,6 @@
   }
 
   @Override
-  public Set<Phi> rewriteCode(
-      IRCode code,
-      MethodProcessor methodProcessor,
-      RewrittenPrototypeDescription prototypeChanges) {
-    return Collections.emptySet();
-  }
-
-  @Override
   @SuppressWarnings("BadImport")
   public void rewriteWithLens() {
     // Intentionally empty.
@@ -84,11 +72,6 @@
   }
 
   @Override
-  public void unsetRewriter() {
-    // Intentionally empty.
-  }
-
-  @Override
   public void updateEnumUnboxingCandidatesInfo() {
     // Intentionally empty.
   }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
index 48ae3b2..081d608 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
@@ -8,17 +8,14 @@
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.graph.lens.GraphLens;
-import com.android.tools.r8.graph.proto.RewrittenPrototypeDescription;
 import com.android.tools.r8.ir.analysis.fieldvalueanalysis.StaticFieldValues;
 import com.android.tools.r8.ir.code.IRCode;
-import com.android.tools.r8.ir.code.Phi;
 import com.android.tools.r8.ir.conversion.IRConverter;
 import com.android.tools.r8.ir.conversion.MethodProcessor;
 import com.android.tools.r8.ir.conversion.PostMethodProcessor.Builder;
 import com.android.tools.r8.ir.optimize.info.OptimizationFeedbackDelayed;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.Timing;
-import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 
@@ -44,9 +41,6 @@
 
   public abstract void recordEnumState(DexProgramClass clazz, StaticFieldValues staticFieldValues);
 
-  public abstract Set<Phi> rewriteCode(
-      IRCode code, MethodProcessor methodProcessor, RewrittenPrototypeDescription prototypeChanges);
-
   public abstract void rewriteWithLens();
 
   @SuppressWarnings("BadImport")
@@ -59,7 +53,5 @@
       Timing timing)
       throws ExecutionException;
 
-  public abstract void unsetRewriter();
-
   public abstract void updateEnumUnboxingCandidatesInfo();
 }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxerImpl.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxerImpl.java
index 5821b3a..7d66f95 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxerImpl.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxerImpl.java
@@ -166,9 +166,6 @@
       checkNotNullMethodsBuilder;
 
   private final DexClassAndField ordinalField;
-
-  private EnumUnboxingRewriter enumUnboxerRewriter;
-
   private final boolean debugLogEnabled;
   private final Map<DexType, List<Reason>> debugLogs;
 
@@ -741,7 +738,6 @@
         new EnumUnboxingTreeFixer(
                 appView, checkNotNullMethods, enumDataMap, enumClassesToUnbox, utilityClasses)
             .fixupTypeReferences(converter, executorService, timing);
-    EnumUnboxingLens enumUnboxingLens = treeFixerResult.getLens();
 
     // Enqueue the (lens rewritten) methods that require reprocessing.
     //
@@ -764,14 +760,6 @@
 
     updateOptimizationInfos(executorService, feedback, treeFixerResult, previousLens);
 
-    enumUnboxerRewriter =
-        new EnumUnboxingRewriter(
-            appView,
-            treeFixerResult.getCheckNotNullToCheckNotZeroMapping(),
-            enumUnboxingLens,
-            enumDataMap,
-            utilityClasses);
-
     // Ensure determinism of method-to-reprocess set.
     appView.testing().checkDeterminism(postMethodProcessorBuilder::dump);
 
@@ -1785,22 +1773,4 @@
     enumUnboxingCandidatesInfo.addPrunedMethod(method);
     methodsDependingOnLibraryModelisation.remove(method.getReference(), appView.graphLens());
   }
-
-  @Override
-  public Set<Phi> rewriteCode(
-      IRCode code,
-      MethodProcessor methodProcessor,
-      RewrittenPrototypeDescription prototypeChanges) {
-    // This has no effect during primary processing since the enumUnboxerRewriter is set
-    // in between primary and post processing.
-    if (enumUnboxerRewriter != null) {
-      return enumUnboxerRewriter.rewriteCode(code, methodProcessor, prototypeChanges);
-    }
-    return Sets.newIdentityHashSet();
-  }
-
-  @Override
-  public void unsetRewriter() {
-    enumUnboxerRewriter = null;
-  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingLens.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingLens.java
index 95e4ea6..0e2f7a8 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingLens.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingLens.java
@@ -14,7 +14,7 @@
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.graph.lens.GraphLens;
 import com.android.tools.r8.graph.lens.MethodLookupResult;
-import com.android.tools.r8.graph.lens.NestedGraphLens;
+import com.android.tools.r8.graph.lens.NestedGraphLensWithCustomCodeRewriting;
 import com.android.tools.r8.graph.proto.ArgumentInfoCollection;
 import com.android.tools.r8.graph.proto.RewrittenPrototypeDescription;
 import com.android.tools.r8.graph.proto.RewrittenTypeInfo;
@@ -44,7 +44,7 @@
 import java.util.Map;
 import java.util.Set;
 
-public class EnumUnboxingLens extends NestedGraphLens {
+public class EnumUnboxingLens extends NestedGraphLensWithCustomCodeRewriting {
 
   private final AbstractValueFactory abstractValueFactory;
   private final Map<DexMethod, RewrittenPrototypeDescription> prototypeChangesPerMethod;
@@ -58,8 +58,9 @@
       BidirectionalManyToOneRepresentativeMap<DexType, DexType> typeMap,
       Map<DexMethod, DexMethod> methodMap,
       Map<DexMethod, RewrittenPrototypeDescription> prototypeChangesPerMethod,
-      Set<DexMethod> dispatchMethods) {
-    super(appView, fieldMap, methodMap, typeMap, renamedSignatures);
+      Set<DexMethod> dispatchMethods,
+      EnumUnboxingRewriter enumUnboxingRewriter) {
+    super(appView, fieldMap, methodMap, typeMap, renamedSignatures, enumUnboxingRewriter);
     assert !appView.unboxedEnums().isEmpty();
     this.abstractValueFactory = appView.abstractValueFactory();
     this.prototypeChangesPerMethod = prototypeChangesPerMethod;
@@ -68,11 +69,6 @@
   }
 
   @Override
-  public boolean hasCustomCodeRewritings() {
-    return true;
-  }
-
-  @Override
   public boolean isEnumUnboxerLens() {
     return true;
   }
@@ -82,6 +78,10 @@
     return this;
   }
 
+  public EnumDataMap getUnboxedEnums() {
+    return unboxedEnums;
+  }
+
   @Override
   public boolean isContextFreeForMethods(GraphLens codeLens) {
     if (codeLens == this) {
@@ -383,7 +383,10 @@
           originalCheckNotNullMethodSignature, checkNotNullMethod.getReference());
     }
 
-    public EnumUnboxingLens build(AppView<?> appView, Set<DexMethod> dispatchMethods) {
+    public EnumUnboxingLens build(
+        AppView<AppInfoWithLiveness> appView,
+        Set<DexMethod> dispatchMethods,
+        EnumUnboxingRewriter enumUnboxingRewriter) {
       assert !typeMap.isEmpty();
       return new EnumUnboxingLens(
           appView,
@@ -392,7 +395,8 @@
           typeMap,
           methodMap,
           ImmutableMap.copyOf(prototypeChangesPerMethod),
-          dispatchMethods);
+          dispatchMethods,
+          enumUnboxingRewriter);
     }
   }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
index 46184c4..eb38f33 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
@@ -14,6 +14,7 @@
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.graph.lens.NonIdentityGraphLens;
 import com.android.tools.r8.graph.proto.ArgumentInfo;
 import com.android.tools.r8.graph.proto.RewrittenPrototypeDescription;
 import com.android.tools.r8.graph.proto.RewrittenTypeInfo;
@@ -42,6 +43,7 @@
 import com.android.tools.r8.ir.code.StaticGet;
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.conversion.MethodProcessor;
+import com.android.tools.r8.ir.optimize.CustomLensCodeRewriter;
 import com.android.tools.r8.ir.optimize.enums.EnumInstanceFieldData.EnumInstanceFieldKnownData;
 import com.android.tools.r8.ir.optimize.enums.classification.CheckNotNullEnumUnboxerMethodClassification;
 import com.android.tools.r8.ir.optimize.enums.classification.EnumUnboxerMethodClassification;
@@ -58,27 +60,24 @@
 import java.util.Map;
 import java.util.Set;
 
-public class EnumUnboxingRewriter {
+public class EnumUnboxingRewriter implements CustomLensCodeRewriter {
 
   private final AppView<AppInfoWithLiveness> appView;
   private final Map<DexMethod, DexMethod> checkNotNullToCheckNotZeroMapping;
   private final DexItemFactory factory;
   private final InternalOptions options;
   private final EnumDataMap unboxedEnumsData;
-  private final EnumUnboxingLens enumUnboxingLens;
   private final EnumUnboxingUtilityClasses utilityClasses;
 
   EnumUnboxingRewriter(
       AppView<AppInfoWithLiveness> appView,
       Map<DexMethod, DexMethod> checkNotNullToCheckNotZeroMapping,
-      EnumUnboxingLens enumUnboxingLens,
       EnumDataMap unboxedEnumsInstanceFieldData,
       EnumUnboxingUtilityClasses utilityClasses) {
     this.appView = appView;
     this.checkNotNullToCheckNotZeroMapping = checkNotNullToCheckNotZeroMapping;
     this.factory = appView.dexItemFactory();
     this.options = appView.options();
-    this.enumUnboxingLens = enumUnboxingLens;
     this.unboxedEnumsData = unboxedEnumsInstanceFieldData;
     this.utilityClasses = utilityClasses;
   }
@@ -146,15 +145,19 @@
     return convertedEnums;
   }
 
-  Set<Phi> rewriteCode(
+  @Override
+  public Set<Phi> rewriteCode(
       IRCode code,
       MethodProcessor methodProcessor,
-      RewrittenPrototypeDescription prototypeChanges) {
+      RewrittenPrototypeDescription prototypeChanges,
+      NonIdentityGraphLens graphLens) {
     // We should not process the enum methods, they will be removed and they may contain invalid
     // rewriting rules.
     if (unboxedEnumsData.isEmpty()) {
       return Sets.newIdentityHashSet();
     }
+    assert graphLens.isEnumUnboxerLens();
+    EnumUnboxingLens enumUnboxingLens = graphLens.asEnumUnboxerLens();
     assert code.isConsistentSSABeforeTypesAreCorrect(appView);
     EnumUnboxerMethodProcessorEventConsumer eventConsumer = methodProcessor.getEventConsumer();
     Set<Phi> affectedPhis = Sets.newIdentityHashSet();
@@ -192,7 +195,8 @@
               blocks,
               block,
               iterator,
-              instruction.asInvokeMethodWithReceiver());
+              instruction.asInvokeMethodWithReceiver(),
+              enumUnboxingLens);
         } else if (instruction.isNewArrayFilled()) {
           rewriteNewArrayFilled(instruction.asNewArrayFilled(), code, convertedEnums, iterator);
         } else if (instruction.isInvokeStatic()) {
@@ -379,7 +383,8 @@
       BasicBlockIterator blocks,
       BasicBlock block,
       InstructionListIterator iterator,
-      InvokeMethodWithReceiver invoke) {
+      InvokeMethodWithReceiver invoke,
+      EnumUnboxingLens enumUnboxingLens) {
     ProgramMethod context = code.context();
     // If the receiver is null, then the invoke is not rewritten even if the receiver is an
     // unboxed enum, but we end up with null.ordinal() or similar which has the correct behavior.
@@ -699,7 +704,7 @@
     }
   }
 
-  public void rewriteNullCheck(
+  private void rewriteNullCheck(
       InstructionListIterator iterator,
       InvokeMethod invoke,
       ProgramMethod context,
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingTreeFixer.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingTreeFixer.java
index 59be01e..8910375 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingTreeFixer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingTreeFixer.java
@@ -154,26 +154,30 @@
     new ConcurrentMethodFixup(appView, this)
         .fixupClassesConcurrentlyByConnectedProgramComponents(Timing.empty(), executorService);
 
+    // Create mapping from checkNotNull() to checkNotZero() methods.
+    BiMap<DexMethod, DexMethod> checkNotNullToCheckNotZeroMapping =
+        duplicateCheckNotNullMethods(converter, executorService);
+
     // Install the new graph lens before processing any checkNotZero() methods.
     Set<DexMethod> dispatchMethodReferences = Sets.newIdentityHashSet();
     dispatchMethods.forEach((method, code) -> dispatchMethodReferences.add(method.getReference()));
-    EnumUnboxingLens lens = lensBuilder.build(appView, dispatchMethodReferences);
+    EnumUnboxingRewriter enumUnboxingRewriter =
+        new EnumUnboxingRewriter(
+            appView, checkNotNullToCheckNotZeroMapping, enumDataMap, utilityClasses);
+    EnumUnboxingLens lens =
+        lensBuilder.build(appView, dispatchMethodReferences, enumUnboxingRewriter);
     appView.rewriteWithLens(lens, executorService, timing);
 
     // Rewrite outliner with lens.
     converter.outliner.rewriteWithLens();
 
-    // Create mapping from checkNotNull() to checkNotZero() methods.
-    BiMap<DexMethod, DexMethod> checkNotNullToCheckNotZeroMapping =
-        duplicateCheckNotNullMethods(converter, executorService);
-
     dispatchMethods.forEach((method, code) -> code.setCodeLens(lens));
     profileCollectionAdditions
         .setArtProfileCollection(appView.getArtProfileCollection())
         .commit(appView);
 
     return new Result(
-        checkNotNullToCheckNotZeroMapping, methodsToProcess, lens, prunedItemsBuilder.build());
+        checkNotNullToCheckNotZeroMapping, methodsToProcess, prunedItemsBuilder.build());
   }
 
   private void cleanUpOldClass(DexProgramClass clazz) {
@@ -1098,17 +1102,14 @@
 
     private final BiMap<DexMethod, DexMethod> checkNotNullToCheckNotZeroMapping;
     private final ProgramMethodSet methodsToProcess;
-    private final EnumUnboxingLens lens;
     private final PrunedItems prunedItems;
 
     Result(
         BiMap<DexMethod, DexMethod> checkNotNullToCheckNotZeroMapping,
         ProgramMethodSet methodsToProcess,
-        EnumUnboxingLens lens,
         PrunedItems prunedItems) {
       this.checkNotNullToCheckNotZeroMapping = checkNotNullToCheckNotZeroMapping;
       this.methodsToProcess = methodsToProcess;
-      this.lens = lens;
       this.prunedItems = prunedItems;
     }
 
@@ -1120,10 +1121,6 @@
       return methodsToProcess;
     }
 
-    EnumUnboxingLens getLens() {
-      return lens;
-    }
-
     PrunedItems getPrunedItems() {
       return prunedItems;
     }