Do not reuse move-exception for different exception types.

Art on Android N through P will use the type of the first
guard targeting an exception handler as the type of the
exception. That will lead to incorrect optimizations if
we have the same move-exception instruction targeted with
different guard types.

R=ricow@google.com, zerny@google.com

Bug: 120164595
Change-Id: I53114ac093535454ddaf0d0ed06c81e42f338761
diff --git a/src/main/java/com/android/tools/r8/ir/code/BasicBlock.java b/src/main/java/com/android/tools/r8/ir/code/BasicBlock.java
index 99833f4..1dcd217 100644
--- a/src/main/java/com/android/tools/r8/ir/code/BasicBlock.java
+++ b/src/main/java/com/android/tools/r8/ir/code/BasicBlock.java
@@ -17,6 +17,7 @@
 import com.android.tools.r8.ir.conversion.DexBuilder;
 import com.android.tools.r8.ir.conversion.IRBuilder;
 import com.android.tools.r8.utils.CfgPrinter;
+import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.ListUtils;
 import com.android.tools.r8.utils.StringUtils;
 import com.android.tools.r8.utils.StringUtils.BraceType;
@@ -1297,10 +1298,13 @@
   }
 
   public static BasicBlock createRethrowBlock(
-      IRCode code, Position position, TypeLatticeElement guardTypeLattice) {
+      IRCode code, Position position, DexType guard, AppInfo appInfo, InternalOptions options) {
+    TypeLatticeElement guardTypeLattice = TypeLatticeElement.fromDexType(guard, false, appInfo);
     BasicBlock block = new BasicBlock();
     MoveException moveException = new MoveException(
-        new Value(code.valueNumberGenerator.next(), guardTypeLattice, null));
+        new Value(code.valueNumberGenerator.next(), guardTypeLattice, null),
+        guard,
+        options);
     moveException.setPosition(position);
     Throw throwInstruction = new Throw(moveException.outValue);
     throwInstruction.setPosition(position);
@@ -1518,7 +1522,10 @@
    * Clone catch successors from `fromBlock` into this block.
    */
   public void copyCatchHandlers(
-      IRCode code, ListIterator<BasicBlock> blockIterator, BasicBlock fromBlock) {
+      IRCode code,
+      ListIterator<BasicBlock> blockIterator,
+      BasicBlock fromBlock,
+      InternalOptions options) {
     if (catchHandlers != null && catchHandlers.hasCatchAll()) {
       return;
     }
@@ -1538,7 +1545,8 @@
       catchSuccessor.splitCriticalExceptionEdges(
           code.getHighestBlockNumber() + 1,
           code.valueNumberGenerator,
-          blockIterator::add);
+          blockIterator::add,
+          options);
     }
   }
 
@@ -1558,16 +1566,19 @@
   public int splitCriticalExceptionEdges(
       int nextBlockNumber,
       ValueNumberGenerator valueNumberGenerator,
-      Consumer<BasicBlock> onNewBlock) {
+      Consumer<BasicBlock> onNewBlock,
+      InternalOptions options) {
     List<BasicBlock> predecessors = getMutablePredecessors();
     boolean hasMoveException = entry().isMoveException();
     TypeLatticeElement exceptionTypeLattice = null;
+    DexType exceptionType = null;
     MoveException move = null;
     Position position = entry().getPosition();
     if (hasMoveException) {
       // Remove the move-exception instruction.
       move = entry().asMoveException();
       exceptionTypeLattice = move.outValue().getTypeLattice();
+      exceptionType = move.getExceptionType();
       assert move.getDebugValues().isEmpty();
       getInstructions().remove(0);
     }
@@ -1588,7 +1599,7 @@
             exceptionTypeLattice,
             move.getLocalInfo());
         values.add(value);
-        MoveException newMove = new MoveException(value);
+        MoveException newMove = new MoveException(value, exceptionType, options);
         newBlock.add(newMove);
         newMove.setPosition(position);
       }
diff --git a/src/main/java/com/android/tools/r8/ir/code/BasicBlockInstructionIterator.java b/src/main/java/com/android/tools/r8/ir/code/BasicBlockInstructionIterator.java
index af2dae6..9102ec2 100644
--- a/src/main/java/com/android/tools/r8/ir/code/BasicBlockInstructionIterator.java
+++ b/src/main/java/com/android/tools/r8/ir/code/BasicBlockInstructionIterator.java
@@ -278,7 +278,7 @@
         } else {
           nextBlock = null;
         }
-        currentBlock.copyCatchHandlers(code, blocksIterator, invokeBlock);
+        currentBlock.copyCatchHandlers(code, blocksIterator, invokeBlock, code.options);
         if (nextBlock != null) {
           BasicBlock b = blocksIterator.next();
           assert b == nextBlock;
@@ -311,7 +311,7 @@
       if (inlinedBlock.hasCatchHandlers()) {
         // The block already has catch handlers, so it has only one throwing instruction, and no
         // splitting is required.
-        inlinedBlock.copyCatchHandlers(code, blocksIterator, invokeBlock);
+        inlinedBlock.copyCatchHandlers(code, blocksIterator, invokeBlock, code.options);
       } else {
         // The block does not have catch handlers, so it can have several throwing instructions.
         // Therefore the block must be split after each throwing instruction, and the catch
diff --git a/src/main/java/com/android/tools/r8/ir/code/MoveException.java b/src/main/java/com/android/tools/r8/ir/code/MoveException.java
index 7cdcd76..fae3557 100644
--- a/src/main/java/com/android/tools/r8/ir/code/MoveException.java
+++ b/src/main/java/com/android/tools/r8/ir/code/MoveException.java
@@ -7,21 +7,22 @@
 import com.android.tools.r8.cf.TypeVerificationHelper;
 import com.android.tools.r8.dex.Constants;
 import com.android.tools.r8.graph.AppInfo;
-import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
 import com.android.tools.r8.ir.conversion.CfBuilder;
 import com.android.tools.r8.ir.conversion.DexBuilder;
 import com.android.tools.r8.ir.optimize.Inliner.ConstraintWithTarget;
 import com.android.tools.r8.ir.optimize.InliningConstraints;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Set;
+import com.android.tools.r8.utils.InternalOptions;
 
 public class MoveException extends Instruction {
+  private final DexType exceptionType;
+  private final InternalOptions options;
 
-  public MoveException(Value dest) {
+  public MoveException(Value dest, DexType exceptionType, InternalOptions options) {
     super(dest);
+    this.exceptionType = exceptionType;
+    this.options = options;
     dest.markNeverNull();
   }
 
@@ -48,7 +49,13 @@
 
   @Override
   public boolean identicalNonValueNonPositionParts(Instruction other) {
-    return other.isMoveException();
+    if (!other.isMoveException()) {
+      return false;
+    }
+    if (options.canHaveExceptionTypeBug()) {
+      return other.asMoveException().exceptionType == exceptionType;
+    }
+    return true;
   }
 
   @Override
@@ -88,34 +95,17 @@
     return true;
   }
 
-  public static Set<DexType> collectExceptionTypes(
-      BasicBlock currentBlock, DexItemFactory dexItemFactory) {
-    Set<DexType> exceptionTypes = new HashSet<>(currentBlock.getPredecessors().size());
-    for (BasicBlock block : currentBlock.getPredecessors()) {
-      int size = block.getCatchHandlers().size();
-      List<BasicBlock> targets = block.getCatchHandlers().getAllTargets();
-      List<DexType> guards = block.getCatchHandlers().getGuards();
-      for (int i = 0; i < size; i++) {
-        if (targets.get(i) == currentBlock) {
-          DexType guard = guards.get(i);
-          exceptionTypes.add(
-              guard == DexItemFactory.catchAllType
-                  ? dexItemFactory.throwableType
-                  : guard);
-        }
-      }
-    }
-    return exceptionTypes;
-  }
-
   @Override
   public DexType computeVerificationType(TypeVerificationHelper helper) {
-    return helper.join(collectExceptionTypes(getBlock(), helper.getFactory()));
+    return exceptionType;
   }
 
   @Override
   public TypeLatticeElement evaluate(AppInfo appInfo) {
-    Set<DexType> exceptionTypes = collectExceptionTypes(getBlock(), appInfo.dexItemFactory);
-    return TypeLatticeElement.joinTypes(exceptionTypes, false, appInfo);
+    return TypeLatticeElement.fromDexType(exceptionType, false, appInfo);
+  }
+
+  public DexType getExceptionType() {
+    return exceptionType;
   }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRBuilder.java b/src/main/java/com/android/tools/r8/ir/conversion/IRBuilder.java
index 007cac1..835399c 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRBuilder.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRBuilder.java
@@ -101,6 +101,7 @@
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.IteratorUtils;
 import com.android.tools.r8.utils.Pair;
+import com.google.common.collect.Sets;
 import it.unimi.dsi.fastutil.ints.Int2ReferenceAVLTreeMap;
 import it.unimi.dsi.fastutil.ints.Int2ReferenceMap;
 import it.unimi.dsi.fastutil.ints.Int2ReferenceOpenHashMap;
@@ -116,7 +117,6 @@
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.IdentityHashMap;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.ListIterator;
@@ -193,11 +193,14 @@
   }
 
   private static class MoveExceptionWorklistItem extends WorklistItem {
+    private final DexType guard;
     private final int sourceOffset;
     private final int targetOffset;
 
-    private MoveExceptionWorklistItem(BasicBlock block, int sourceOffset, int targetOffset) {
+    private MoveExceptionWorklistItem(
+        BasicBlock block, DexType guard, int sourceOffset, int targetOffset) {
       super(block, -1);
+      this.guard = guard;
       this.sourceOffset = sourceOffset;
       this.targetOffset = targetOffset;
     }
@@ -707,12 +710,9 @@
     // construction and prior to building the IR.
     for (BlockInfo info : targets.values()) {
       if (info != null && info.block == block) {
-        assert info.predecessorCount() == block.getPredecessors().size();
+        assert info.predecessorCount() == nonSplitPredecessorCount(block);
         assert info.normalSuccessors.size() == block.getNormalSuccessors().size();
-        if (block.hasCatchHandlers()) {
-          assert info.exceptionalSuccessors.size()
-              == block.getCatchHandlers().getUniqueTargets().size();
-        } else {
+        if (!block.hasCatchHandlers()) {
           assert !block.canThrow()
               || info.exceptionalSuccessors.isEmpty()
               || (info.exceptionalSuccessors.size() == 1
@@ -726,6 +726,38 @@
     return true;
   }
 
+  private int nonSplitPredecessorCount(BasicBlock block) {
+    Set<BasicBlock> set = Sets.newIdentityHashSet();
+    for (BasicBlock predecessor : block.getPredecessors()) {
+      if (offsets.containsKey(predecessor)) {
+        set.add(predecessor);
+      } else {
+        assert predecessor.getSuccessors().size() == 1;
+        assert predecessor.getPredecessors().size() == 1;
+        assert trivialGotoBlockPotentiallyWithMoveException(predecessor);
+        // Combine the exceptional edges to just one, for normal edges that have been split
+        // record them separately. That means that we are checking that there are the expected
+        // number of normal edges and some number of exceptional edges (which we count as one edge).
+        if (predecessor.getPredecessors().get(0).hasCatchSuccessor(predecessor)) {
+          set.add(predecessor.getPredecessors().get(0));
+        } else {
+          set.add(predecessor);
+        }
+      }
+    }
+    return set.size();
+  }
+
+  // Check that all instructions are either move-exception, goto or debug instructions.
+  private boolean trivialGotoBlockPotentiallyWithMoveException(BasicBlock block) {
+    for (Instruction instruction : block.getInstructions()) {
+      assert instruction.isMoveException()
+          || instruction.isGoto()
+          || instruction.isDebugInstruction();
+    }
+    return true;
+  }
+
   private void processWorklist() {
     for (WorklistItem item = ssaWorklist.poll(); item != null; item = ssaWorklist.poll()) {
       if (item.block.isFilled()) {
@@ -782,10 +814,10 @@
     int moveExceptionDest = source.getMoveExceptionRegister(targetIndex);
     Position position = source.getCanonicalDebugPositionAtOffset(moveExceptionItem.targetOffset);
     if (moveExceptionDest >= 0) {
-      Set<DexType> exceptionTypes = MoveException.collectExceptionTypes(currentBlock, getFactory());
-      TypeLatticeElement typeLattice = TypeLatticeElement.joinTypes(exceptionTypes, false, appInfo);
+      TypeLatticeElement typeLattice =
+          TypeLatticeElement.fromDexType(moveExceptionItem.guard, false, appInfo);
       Value out = writeRegister(moveExceptionDest, typeLattice, ThrowingInfo.NO_THROW, null);
-      MoveException moveException = new MoveException(out);
+      MoveException moveException = new MoveException(out, moveExceptionItem.guard, options);
       moveException.setPosition(position);
       currentBlock.add(moveException);
     }
@@ -2196,21 +2228,22 @@
         assert !throwingInstructionInCurrentBlock;
         throwingInstructionInCurrentBlock = true;
         List<BasicBlock> targets = new ArrayList<>(catchHandlers.getAllTargets().size());
-        // Construct unique move-exception header blocks for each unique target.
-        Map<BasicBlock, BasicBlock> moveExceptionHeaders =
-            new IdentityHashMap<>(catchHandlers.getUniqueTargets().size());
-        for (int targetOffset : catchHandlers.getAllTargets()) {
-          BasicBlock target = getTarget(targetOffset);
-          BasicBlock header = moveExceptionHeaders.get(target);
-          if (header == null) {
-            header = new BasicBlock();
-            header.incrementUnfilledPredecessorCount();
-            moveExceptionHeaders.put(target, header);
-            ssaWorklist.add(
-                new MoveExceptionWorklistItem(header, currentInstructionOffset, targetOffset));
-          }
+        Set<BasicBlock> moveExceptionTargets = Sets.newIdentityHashSet();
+        catchHandlers.forEach((type, targetOffset) -> {
+          DexType exceptionType = type == options.itemFactory.catchAllType
+              ? options.itemFactory.throwableType
+              : type;
+          BasicBlock header = new BasicBlock();
+          header.incrementUnfilledPredecessorCount();
+          ssaWorklist.add(
+              new MoveExceptionWorklistItem(
+                  header, exceptionType, currentInstructionOffset, targetOffset));
           targets.add(header);
-        }
+          BasicBlock target = getTarget(targetOffset);
+          if (!moveExceptionTargets.add(target)) {
+            target.incrementUnfilledPredecessorCount();
+          }
+        });
         currentBlock.linkCatchSuccessors(catchHandlers.getGuards(), targets);
       }
     }
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java b/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java
index b92d772..5814486 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java
@@ -332,6 +332,6 @@
     BasicBlock currentBlock = newInstance.getBlock();
     BasicBlock nextBlock = instructions.split(code, blocks);
     assert !instructions.hasNext();
-    nextBlock.copyCatchHandlers(code, blocks, currentBlock);
+    nextBlock.copyCatchHandlers(code, blocks, currentBlock, code.options);
   }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/StringConcatRewriter.java b/src/main/java/com/android/tools/r8/ir/desugar/StringConcatRewriter.java
index 293065e..2be7320 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/StringConcatRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/StringConcatRewriter.java
@@ -394,7 +394,7 @@
       }
       // Copy catch handlers after all blocks are split.
       for (BasicBlock newBlock : newBlocks) {
-        newBlock.copyCatchHandlers(code, blocks, currentBlock);
+        newBlock.copyCatchHandlers(code, blocks, currentBlock, code.options);
       }
     }
 
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index 0349327..5c1fae1 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -321,7 +321,9 @@
       BasicBlock rethrowBlock = BasicBlock.createRethrowBlock(
           code,
           lastSelfRecursiveCall.getPosition(),
-          TypeLatticeElement.fromDexType(guard, true, appInfo));
+          guard,
+          appInfo,
+          options);
       code.blocks.add(rethrowBlock);
       // Add catch handler to the block containing the last recursive call.
       newBlock.addCatchHandler(rethrowBlock, guard);
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index c0c6e59..2de450f 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -875,4 +875,17 @@
     // Marshmallow and Nougat arm64 devices and they do not have the bug.
     return minApiLevel < AndroidApiLevel.M.getLevel();
   }
+
+  // The Art VM for Android N through P has a bug in the JIT that means that if the same
+  // exception block with a move-exception instruction is targeted with more than one type
+  // of exception the JIT will incorrectly assume that the exception object has one of these
+  // types and will optimize based on that one type instead of taking all the types into account.
+  //
+  // In order to workaround that, we always generate distinct move-exception instructions for
+  // distinct dex types.
+  //
+  // See b/120164595.
+  public boolean canHaveExceptionTypeBug() {
+    return minApiLevel < AndroidApiLevel.Q.getLevel();
+  }
 }
diff --git a/src/test/java/com/android/tools/r8/ir/SplitBlockTest.java b/src/test/java/com/android/tools/r8/ir/SplitBlockTest.java
index 90aad41..c416c6e 100644
--- a/src/test/java/com/android/tools/r8/ir/SplitBlockTest.java
+++ b/src/test/java/com/android/tools/r8/ir/SplitBlockTest.java
@@ -201,7 +201,7 @@
 
   public void runCatchHandlerTest(boolean codeThrows, boolean twoGuards) throws Exception {
     final int secondBlockInstructions = 4;
-    final int initialBlockCount = 6;
+    final int initialBlockCount = twoGuards ? 7 : 6;
     // Try split between all instructions in second block.
     for (int i = 1; i < secondBlockInstructions; i++) {
       TestApplication test = codeWithCatchHandlers(codeThrows, twoGuards);
@@ -240,7 +240,7 @@
   public void runCatchHandlerSplitThreeTest(boolean codeThrows, boolean twoGuards)
       throws Exception {
     final int secondBlockInstructions = 4;
-    final int initialBlockCount = 6;
+    final int initialBlockCount = twoGuards ? 7 : 6;
     // Try split out all instructions in second block.
     for (int i = 1; i < secondBlockInstructions - 1; i++) {
       TestApplication test = codeWithCatchHandlers(codeThrows, twoGuards);
diff --git a/src/test/java/com/android/tools/r8/regress/b120164595/B120164595.java b/src/test/java/com/android/tools/r8/regress/b120164595/B120164595.java
index 08e86e6..9d633f2 100644
--- a/src/test/java/com/android/tools/r8/regress/b120164595/B120164595.java
+++ b/src/test/java/com/android/tools/r8/regress/b120164595/B120164595.java
@@ -4,16 +4,14 @@
 
 package com.android.tools.r8.regress.b120164595;
 
-import static org.junit.Assert.assertNotEquals;
-import static org.junit.Assert.assertTrue;
-import static org.junit.matchers.JUnitMatchers.containsString;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 
 import com.android.tools.r8.CompilationFailedException;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestCompileResult;
 import com.android.tools.r8.ToolHelper.DexVm;
 import com.android.tools.r8.ToolHelper.ProcessResult;
-import com.google.common.collect.ImmutableList;
 import java.io.IOException;
 import org.junit.Test;
 
@@ -72,8 +70,7 @@
         },
         DexVm.ART_9_0_0_HOST
     );
-    // TODO(120164595): Remove when workaround lands.
-    assertNotEquals(artResult.exitCode, 0);
-    assertTrue(artResult.stderr.contains("Expected NullPointerException"));
+    assertEquals(0, artResult.exitCode);
+    assertFalse(artResult.stderr.contains("Expected NullPointerException"));
   }
 }