Refactor lazy map rewrite in LirCode

Change-Id: Ib96ddab728ff0b9ac1059724b017222f2e69e8d2
diff --git a/src/main/java/com/android/tools/r8/lightir/LirCode.java b/src/main/java/com/android/tools/r8/lightir/LirCode.java
index 1819495..29c7ca6 100644
--- a/src/main/java/com/android/tools/r8/lightir/LirCode.java
+++ b/src/main/java/com/android/tools/r8/lightir/LirCode.java
@@ -36,9 +36,9 @@
 import com.android.tools.r8.origin.Origin;
 import com.android.tools.r8.utils.ArrayUtils;
 import com.android.tools.r8.utils.ComparatorUtils;
+import com.android.tools.r8.utils.FastMapUtils;
 import com.android.tools.r8.utils.IntBox;
 import com.android.tools.r8.utils.InternalOptions;
-import com.android.tools.r8.utils.ObjectUtils;
 import com.android.tools.r8.utils.RetracerForCodePrinting;
 import com.android.tools.r8.utils.structural.CompareToVisitor;
 import com.android.tools.r8.utils.structural.HashingVisitor;
@@ -202,28 +202,11 @@
     }
 
     public TryCatchTable rewriteWithLens(GraphLens graphLens, GraphLens codeLens) {
-      Int2ReferenceMap<CatchHandlers<Integer>> newTryCatchHandlers = null;
-      for (Int2ReferenceMap.Entry<CatchHandlers<Integer>> entry :
-          tryCatchHandlers.int2ReferenceEntrySet()) {
-        int block = entry.getIntKey();
-        CatchHandlers<Integer> blockHandlers = entry.getValue();
-        CatchHandlers<Integer> newBlockHandlers =
-            blockHandlers.rewriteWithLens(graphLens, codeLens);
-        if (newTryCatchHandlers == null) {
-          if (ObjectUtils.identical(newBlockHandlers, blockHandlers)) {
-            continue;
-          }
-          newTryCatchHandlers = new Int2ReferenceOpenHashMap<>(tryCatchHandlers.size());
-          for (Int2ReferenceMap.Entry<CatchHandlers<Integer>> previousEntry :
-              tryCatchHandlers.int2ReferenceEntrySet()) {
-            if (previousEntry == entry) {
-              break;
-            }
-            newTryCatchHandlers.put(previousEntry.getIntKey(), previousEntry.getValue());
-          }
-        }
-        newTryCatchHandlers.put(block, newBlockHandlers);
-      }
+      Int2ReferenceMap<CatchHandlers<Integer>> newTryCatchHandlers =
+          FastMapUtils.mapInt2ReferenceOpenHashMapOrElse(
+              tryCatchHandlers,
+              (block, blockHandlers) -> blockHandlers.rewriteWithLens(graphLens, codeLens),
+              null);
       return newTryCatchHandlers != null ? new TryCatchTable(newTryCatchHandlers) : this;
     }
 
diff --git a/src/main/java/com/android/tools/r8/utils/FastMapUtils.java b/src/main/java/com/android/tools/r8/utils/FastMapUtils.java
new file mode 100644
index 0000000..e64f202
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/utils/FastMapUtils.java
@@ -0,0 +1,60 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.utils;
+
+import it.unimi.dsi.fastutil.ints.Int2ReferenceMap;
+import it.unimi.dsi.fastutil.ints.Int2ReferenceOpenHashMap;
+import java.util.Iterator;
+import java.util.function.Function;
+
+public class FastMapUtils {
+
+  public static <V> Int2ReferenceMap<V> destructiveMapValues(
+      Int2ReferenceMap<V> map, Function<V, V> valueMapper) {
+    Iterator<Int2ReferenceMap.Entry<V>> iterator = map.int2ReferenceEntrySet().iterator();
+    while (iterator.hasNext()) {
+      Int2ReferenceMap.Entry<V> entry = iterator.next();
+      V newValue = valueMapper.apply(entry.getValue());
+      if (newValue != null) {
+        entry.setValue(newValue);
+      } else {
+        iterator.remove();
+      }
+    }
+    return map;
+  }
+
+  public static <V> Int2ReferenceMap<V> mapInt2ReferenceOpenHashMapOrElse(
+      Int2ReferenceMap<V> map,
+      IntObjToObjFunction<V, V> valueMapper,
+      Int2ReferenceMap<V> defaultValue) {
+    Int2ReferenceMap<V> newMap = null;
+    Iterator<Int2ReferenceMap.Entry<V>> iterator = map.int2ReferenceEntrySet().iterator();
+    while (iterator.hasNext()) {
+      Int2ReferenceMap.Entry<V> entry = iterator.next();
+      int key = entry.getIntKey();
+      V value = entry.getValue();
+      V newValue = valueMapper.apply(key, value);
+      if (newMap == null) {
+        if (newValue == value) {
+          continue;
+        }
+        // This is the first entry where the value has been changed. Create the new map and copy
+        // over previous entries that did not change.
+        Int2ReferenceMap<V> newFinalMap = new Int2ReferenceOpenHashMap<>(map.size());
+        CollectionUtils.forEachUntilExclusive(
+            map.int2ReferenceEntrySet(),
+            previousEntry -> newFinalMap.put(previousEntry.getIntKey(), previousEntry.getValue()),
+            entry);
+        newMap = newFinalMap;
+      }
+      if (newValue != null) {
+        newMap.put(key, newValue);
+      } else {
+        iterator.remove();
+      }
+    }
+    return newMap != null ? newMap : defaultValue;
+  }
+}