Coalesce equivalent catch handlers prior to register allocation

Bug: b/385088500
Change-Id: I5ac40efc9c5ffd84f3c003d0b32c9d7289c06716
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/PeepholeOptimizer.java b/src/main/java/com/android/tools/r8/ir/optimize/PeepholeOptimizer.java
index 932bf0e..b562fbb 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/PeepholeOptimizer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/PeepholeOptimizer.java
@@ -3,6 +3,8 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.ir.optimize;
 
+import static com.android.tools.r8.ir.regalloc.LiveIntervals.NO_REGISTER;
+
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DebugLocalInfo;
 import com.android.tools.r8.ir.code.BasicBlock;
@@ -231,13 +233,45 @@
   public static void shareIdenticalBlockSuffix(
       IRCode code, RegisterAllocator allocator, int overhead) {
     Collection<BasicBlock> blocks = code.blocks;
-    BasicBlock normalExit = null;
     List<BasicBlock> normalExits = code.computeNormalExitBlocks();
+    Set<BasicBlock> syntheticNormalExits = Sets.newIdentityHashSet();
     if (normalExits.size() > 1) {
-      normalExit = new BasicBlock(code.metadata());
-      normalExit.getMutablePredecessors().addAll(normalExits);
-      blocks = new ArrayList<>(code.blocks);
-      blocks.add(normalExit);
+      if (code.context().getReturnType().isVoidType()
+          || code.getConversionOptions().isGeneratingClassFiles()) {
+        BasicBlock syntheticNormalExit = new BasicBlock(code.metadata());
+        syntheticNormalExit.getMutablePredecessors().addAll(normalExits);
+        syntheticNormalExits.add(syntheticNormalExit);
+      } else {
+        Int2ReferenceMap<List<BasicBlock>> normalExitPartitioning =
+            new Int2ReferenceOpenHashMap<>();
+        for (BasicBlock block : normalExits) {
+          int returnRegister =
+              block
+                  .exit()
+                  .asReturn()
+                  .returnValue()
+                  .getLiveIntervals()
+                  .getSplitCovering(block.exit().getNumber())
+                  .getRegister();
+          assert returnRegister != NO_REGISTER;
+          List<BasicBlock> blocksWithReturnRegister;
+          if (normalExitPartitioning.containsKey(returnRegister)) {
+            blocksWithReturnRegister = normalExitPartitioning.get(returnRegister);
+          } else {
+            blocksWithReturnRegister = new ArrayList<>();
+            normalExitPartitioning.put(returnRegister, blocksWithReturnRegister);
+          }
+          blocksWithReturnRegister.add(block);
+        }
+        for (List<BasicBlock> blocksWithSameReturnRegister : normalExitPartitioning.values()) {
+          BasicBlock syntheticNormalExit = new BasicBlock(code.metadata());
+          syntheticNormalExit.getMutablePredecessors().addAll(blocksWithSameReturnRegister);
+          syntheticNormalExits.add(syntheticNormalExit);
+        }
+      }
+      blocks = new ArrayList<>(code.getBlocks().size() + syntheticNormalExits.size());
+      blocks.addAll(code.getBlocks());
+      blocks.addAll(syntheticNormalExits);
     }
     do {
       Map<BasicBlock, BasicBlock> newBlocks = new IdentityHashMap<>();
@@ -295,7 +329,7 @@
                   code,
                   commonSuffixSize,
                   predsWithSameLastInstruction,
-                  block == normalExit ? null : block,
+                  syntheticNormalExits.contains(block) ? null : block,
                   allocator);
           newBlocks.put(predsWithSameLastInstruction.get(0), newBlock);
         }
diff --git a/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java b/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
index bc8b11b..ac612a7 100644
--- a/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
+++ b/src/main/java/com/android/tools/r8/ir/regalloc/LinearScanRegisterAllocator.java
@@ -367,6 +367,7 @@
     constrainArgumentIntervals();
     insertRangeInvokeMoves();
     ImmutableList<BasicBlock> blocks = computeLivenessInformation();
+    dedupCatchHandlerBlocks();
     timing.end();
     timing.begin("Allocate");
     performAllocation();
@@ -3428,6 +3429,77 @@
     code.blocks.forEach(BasicBlock::clearUserInfo);
   }
 
+  private void dedupCatchHandlerBlocks() {
+    List<BasicBlock> candidateBlocks = new ArrayList<>();
+    for (BasicBlock block : code.getBlocks()) {
+      if (block.hasUniquePredecessor()
+          && block.getUniquePredecessor().hasCatchSuccessor(block)
+          && liveAtEntrySets.get(block).isEmpty()
+          && block.size() <= 2) {
+        candidateBlocks.add(block);
+      }
+    }
+    if (candidateBlocks.isEmpty()) {
+      return;
+    }
+    Set<BasicBlock> removedBlocks = Sets.newIdentityHashSet();
+    for (BasicBlock candidateBlock : candidateBlocks) {
+      assert !removedBlocks.contains(candidateBlock);
+      BasicBlock equivalentBlock = null;
+      for (BasicBlock block : candidateBlocks) {
+        if (block == candidateBlock || removedBlocks.contains(block)) {
+          continue;
+        }
+        if (isEquivalentCatchHandlers(candidateBlock, block)) {
+          equivalentBlock = block;
+          break;
+        }
+      }
+      if (equivalentBlock == null) {
+        continue;
+      }
+      assert !candidateBlock.hasCatchHandlers();
+      removedBlocks.add(candidateBlock);
+      for (BasicBlock tryBlock : candidateBlock.getPredecessors()) {
+        tryBlock.replaceSuccessor(candidateBlock, equivalentBlock);
+        if (!equivalentBlock.getPredecessors().contains(tryBlock)) {
+          equivalentBlock.getMutablePredecessors().add(tryBlock);
+        }
+      }
+      for (BasicBlock successor : candidateBlock.getSuccessors()) {
+        int index = successor.getPredecessors().indexOf(candidateBlock);
+        successor.getMutablePredecessors().remove(index);
+        for (Phi phi : successor.getPhis()) {
+          phi.removeOperand(index);
+        }
+      }
+    }
+    code.removeBlocks(removedBlocks);
+  }
+
+  // TODO(b/153139043): Generalize this. Maybe use BasicBlock subsumption.
+  private boolean isEquivalentCatchHandlers(BasicBlock block, BasicBlock other) {
+    assert liveAtEntrySets.get(block).isEmpty();
+    assert liveAtEntrySets.get(other).isEmpty();
+    if (block.size() != other.size() || block.size() > 2) {
+      return false;
+    }
+    if (block.size() == 2) {
+      if (!block.entry().isMoveException() || !other.entry().isMoveException()) {
+        return false;
+      }
+    }
+    if (block.exit().isGoto()
+        && other.exit().isGoto()
+        && block.getUniqueNormalSuccessor() == other.getUniqueNormalSuccessor()) {
+      return true;
+    }
+    if (block.exit().isReturn() && other.exit().isReturn()) {
+      return true;
+    }
+    return false;
+  }
+
   // Rewrites casts on the form "lhs = (T) rhs" into "(T) rhs" and replaces the uses of lhs by rhs.
   // This transformation helps to ensure that we do not insert unnecessary moves in bridge methods
   // with an invoke-range instruction, since all the arguments to the invoke-range instruction will
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/CatchHandlerCoalescingAfterSplitReturnRewriterTest.java b/src/test/java/com/android/tools/r8/ir/optimize/CatchHandlerCoalescingAfterSplitReturnRewriterTest.java
index 6ea9c73..fa241ae 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/CatchHandlerCoalescingAfterSplitReturnRewriterTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/CatchHandlerCoalescingAfterSplitReturnRewriterTest.java
@@ -51,8 +51,7 @@
               assertThat(testMethodSubject, isPresent());
 
               DexCode code = testMethodSubject.getMethod().getCode().asDexCode();
-              // TODO(b/384848525): Should be 1.
-              assertEquals(forceSplitReturnRewriter ? 2 : 1, code.getTries().length);
+              assertEquals(1, code.getTries().length);
             });
   }
 
@@ -63,6 +62,7 @@
         doStuff();
         doStuff();
       } catch (Exception e) {
+        System.out.println(e);
         return e;
       }
       return null;