Extend lir rewriting to catch handlers and method handles

Bug: b/315445393
Change-Id: If0df72aeab54ec45a1fe7719faad938545e005d1
diff --git a/src/main/java/com/android/tools/r8/ir/code/CatchHandlers.java b/src/main/java/com/android/tools/r8/ir/code/CatchHandlers.java
index 1717cb3..7da900e 100644
--- a/src/main/java/com/android/tools/r8/ir/code/CatchHandlers.java
+++ b/src/main/java/com/android/tools/r8/ir/code/CatchHandlers.java
@@ -5,7 +5,9 @@
 
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.lens.GraphLens;
 import com.android.tools.r8.ir.code.CatchHandlers.CatchHandler;
+import com.android.tools.r8.utils.ListUtils;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 import java.util.ArrayList;
@@ -110,6 +112,12 @@
     return new CatchHandlers<>(newGuards, newTargets);
   }
 
+  public CatchHandlers<T> rewriteWithLens(GraphLens graphLens, GraphLens codeLens) {
+    List<DexType> newGuards =
+        ListUtils.mapOrElse(guards, guard -> graphLens.lookupType(guard, codeLens), null);
+    return newGuards != null ? new CatchHandlers<>(newGuards, targets) : this;
+  }
+
   public void forEach(BiConsumer<DexType, T> consumer) {
     for (int i = 0; i < size(); ++i) {
       consumer.accept(guards.get(i), targets.get(i));
diff --git a/src/main/java/com/android/tools/r8/lightir/LirCode.java b/src/main/java/com/android/tools/r8/lightir/LirCode.java
index 056b7fa..1819495 100644
--- a/src/main/java/com/android/tools/r8/lightir/LirCode.java
+++ b/src/main/java/com/android/tools/r8/lightir/LirCode.java
@@ -38,6 +38,7 @@
 import com.android.tools.r8.utils.ComparatorUtils;
 import com.android.tools.r8.utils.IntBox;
 import com.android.tools.r8.utils.InternalOptions;
+import com.android.tools.r8.utils.ObjectUtils;
 import com.android.tools.r8.utils.RetracerForCodePrinting;
 import com.android.tools.r8.utils.structural.CompareToVisitor;
 import com.android.tools.r8.utils.structural.HashingVisitor;
@@ -200,6 +201,32 @@
       return TryCatchTable::specify;
     }
 
+    public TryCatchTable rewriteWithLens(GraphLens graphLens, GraphLens codeLens) {
+      Int2ReferenceMap<CatchHandlers<Integer>> newTryCatchHandlers = null;
+      for (Int2ReferenceMap.Entry<CatchHandlers<Integer>> entry :
+          tryCatchHandlers.int2ReferenceEntrySet()) {
+        int block = entry.getIntKey();
+        CatchHandlers<Integer> blockHandlers = entry.getValue();
+        CatchHandlers<Integer> newBlockHandlers =
+            blockHandlers.rewriteWithLens(graphLens, codeLens);
+        if (newTryCatchHandlers == null) {
+          if (ObjectUtils.identical(newBlockHandlers, blockHandlers)) {
+            continue;
+          }
+          newTryCatchHandlers = new Int2ReferenceOpenHashMap<>(tryCatchHandlers.size());
+          for (Int2ReferenceMap.Entry<CatchHandlers<Integer>> previousEntry :
+              tryCatchHandlers.int2ReferenceEntrySet()) {
+            if (previousEntry == entry) {
+              break;
+            }
+            newTryCatchHandlers.put(previousEntry.getIntKey(), previousEntry.getValue());
+          }
+        }
+        newTryCatchHandlers.put(block, newBlockHandlers);
+      }
+      return newTryCatchHandlers != null ? new TryCatchTable(newTryCatchHandlers) : this;
+    }
+
     private static void specify(StructuralSpecification<TryCatchTable, ?> spec) {
       spec.withInt2CustomItemMap(
           s -> s.tryCatchHandlers,
@@ -774,6 +801,24 @@
         metadataMap);
   }
 
+  public LirCode<EV> newCodeWithRewrittenTryCatchTable(TryCatchTable rewrittenTryCatchTable) {
+    if (rewrittenTryCatchTable == tryCatchTable) {
+      return this;
+    }
+    return new LirCode<>(
+        irMetadata,
+        constants,
+        positionTable,
+        argumentCount,
+        instructions,
+        instructionCount,
+        rewrittenTryCatchTable,
+        debugLocalInfoTable,
+        strategyInfo,
+        useDexEstimationStrategy,
+        metadataMap);
+  }
+
   public LirCode<EV> rewriteWithSimpleLens(
       ProgramMethod context, AppView<?> appView, LensCodeRewriterUtils rewriterUtils) {
     GraphLens graphLens = appView.graphLens();
diff --git a/src/main/java/com/android/tools/r8/lightir/SimpleLensLirRewriter.java b/src/main/java/com/android/tools/r8/lightir/SimpleLensLirRewriter.java
index d093823..1a4d065 100644
--- a/src/main/java/com/android/tools/r8/lightir/SimpleLensLirRewriter.java
+++ b/src/main/java/com/android/tools/r8/lightir/SimpleLensLirRewriter.java
@@ -4,10 +4,13 @@
 
 package com.android.tools.r8.lightir;
 
+import static com.android.tools.r8.graph.UseRegistry.MethodHandleUse.NOT_ARGUMENT_TO_LAMBDA_METAFACTORY;
+
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.DexCallSite;
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexMethodHandle;
 import com.android.tools.r8.graph.DexProto;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ProgramMethod;
@@ -17,6 +20,8 @@
 import com.android.tools.r8.ir.code.InvokeType;
 import com.android.tools.r8.ir.code.Opcodes;
 import com.android.tools.r8.ir.conversion.LensCodeRewriterUtils;
+import com.android.tools.r8.lightir.LirBuilder.RecordFieldValuesPayload;
+import com.android.tools.r8.lightir.LirCode.TryCatchTable;
 import com.android.tools.r8.utils.ArrayUtils;
 import it.unimi.dsi.fastutil.objects.Reference2IntMap;
 import it.unimi.dsi.fastutil.objects.Reference2IntOpenHashMap;
@@ -68,6 +73,12 @@
     addRewrittenMapping(callSite, helper.rewriteCallSite(callSite, context));
   }
 
+  public void onMethodHandleReference(DexMethodHandle methodHandle) {
+    addRewrittenMapping(
+        methodHandle,
+        helper.rewriteDexMethodHandle(methodHandle, NOT_ARGUMENT_TO_LAMBDA_METAFACTORY, context));
+  }
+
   public void onProtoReference(DexProto proto) {
     addRewrittenMapping(proto, helper.rewriteProto(proto));
   }
@@ -143,7 +154,8 @@
 
   public LirCode<EV> rewrite() {
     LirCode<EV> rewritten = rewriteConstantPoolAndScanForTypeChanges(getCode());
-    return rewriteInstructionsWithInvokeTypeChanges(rewritten);
+    rewritten = rewriteInstructionsWithInvokeTypeChanges(rewritten);
+    return rewriteTryCatchTable(rewritten);
   }
 
   private LirCode<EV> rewriteConstantPoolAndScanForTypeChanges(LirCode<EV> code) {
@@ -152,12 +164,16 @@
     // fields/methods that need to be examined.
     boolean hasPotentialRewrittenMethod = false;
     for (LirConstant constant : code.getConstantPool()) {
+      // RecordFieldValuesPayload is lowered to NewArrayEmpty before lens code rewriting any LIR.
+      assert !(constant instanceof RecordFieldValuesPayload);
       if (constant instanceof DexType) {
         onTypeReference((DexType) constant);
       } else if (constant instanceof DexField) {
         onFieldReference((DexField) constant);
       } else if (constant instanceof DexCallSite) {
         onCallSiteReference((DexCallSite) constant);
+      } else if (constant instanceof DexMethodHandle) {
+        onMethodHandleReference((DexMethodHandle) constant);
       } else if (constant instanceof DexProto) {
         onProtoReference((DexProto) constant);
       } else if (!hasPotentialRewrittenMethod && constant instanceof DexMethod) {
@@ -265,4 +281,13 @@
             byteWriter.toByteArray());
     return newCode;
   }
+
+  private LirCode<EV> rewriteTryCatchTable(LirCode<EV> code) {
+    TryCatchTable tryCatchTable = code.getTryCatchTable();
+    if (tryCatchTable == null) {
+      return code;
+    }
+    TryCatchTable newTryCatchTable = tryCatchTable.rewriteWithLens(graphLens, codeLens);
+    return code.newCodeWithRewrittenTryCatchTable(newTryCatchTable);
+  }
 }