Improve the switch rewriting

For switch instructions with one or two outlier cases, but otherwise
consecutive cases, the switch is rewritten to a switch with only
consecutive cases preceeded by one or two if instructions checking the
outlier cases.

This change adds some simple support for estimating the size of the
final dex size of some instructions.

Change-Id: I3a1fea7bb249fa3cc8cd905acf4e4245d8c11b9b
diff --git a/src/main/java/com/android/tools/r8/code/Base1Format.java b/src/main/java/com/android/tools/r8/code/Base1Format.java
index b62853d..9d88f53 100644
--- a/src/main/java/com/android/tools/r8/code/Base1Format.java
+++ b/src/main/java/com/android/tools/r8/code/Base1Format.java
@@ -5,6 +5,8 @@
 
 public abstract class Base1Format extends Instruction {
 
+  public static final int SIZE = 1;
+
   public Base1Format(BytecodeStream stream) {
     super(stream);
   }
@@ -12,6 +14,6 @@
   protected Base1Format() {}
 
   public int getSize() {
-    return 1;
+    return SIZE;
   }
 }
diff --git a/src/main/java/com/android/tools/r8/code/Base2Format.java b/src/main/java/com/android/tools/r8/code/Base2Format.java
index f8a48ac..1241e6a 100644
--- a/src/main/java/com/android/tools/r8/code/Base2Format.java
+++ b/src/main/java/com/android/tools/r8/code/Base2Format.java
@@ -5,6 +5,8 @@
 
 public abstract class Base2Format extends Instruction {
 
+  public static final int SIZE = 2;
+
   protected Base2Format() {}
 
   public Base2Format(BytecodeStream stream) {
@@ -12,6 +14,6 @@
   }
 
   public int getSize() {
-    return 2;
+    return SIZE;
   }
 }
diff --git a/src/main/java/com/android/tools/r8/code/Base3Format.java b/src/main/java/com/android/tools/r8/code/Base3Format.java
index 34bda57..c1618f5 100644
--- a/src/main/java/com/android/tools/r8/code/Base3Format.java
+++ b/src/main/java/com/android/tools/r8/code/Base3Format.java
@@ -5,6 +5,8 @@
 
 public abstract class Base3Format extends Instruction {
 
+  public static final int SIZE = 3;
+
   protected Base3Format() {}
 
   public Base3Format(BytecodeStream stream) {
@@ -12,6 +14,6 @@
   }
 
   public int getSize() {
-    return 3;
+    return SIZE;
   }
 }
\ No newline at end of file
diff --git a/src/main/java/com/android/tools/r8/code/Base4Format.java b/src/main/java/com/android/tools/r8/code/Base4Format.java
index 7cdf1c5..f3448fa 100644
--- a/src/main/java/com/android/tools/r8/code/Base4Format.java
+++ b/src/main/java/com/android/tools/r8/code/Base4Format.java
@@ -5,6 +5,8 @@
 
 public abstract class Base4Format extends Instruction {
 
+  public static final int SIZE = 4;
+
   protected Base4Format() {}
 
   public Base4Format(BytecodeStream stream) {
@@ -12,6 +14,6 @@
   }
 
   public int getSize() {
-    return 4;
+    return SIZE;
   }
 }
\ No newline at end of file
diff --git a/src/main/java/com/android/tools/r8/code/Base5Format.java b/src/main/java/com/android/tools/r8/code/Base5Format.java
index cc67572..10ddc5e 100644
--- a/src/main/java/com/android/tools/r8/code/Base5Format.java
+++ b/src/main/java/com/android/tools/r8/code/Base5Format.java
@@ -5,6 +5,8 @@
 
 public abstract class Base5Format extends Instruction {
 
+  public static final int SIZE = 5;
+
   protected Base5Format() {}
 
   public Base5Format(BytecodeStream stream) {
@@ -12,6 +14,6 @@
   }
 
   public int getSize() {
-    return 5;
+    return SIZE;
   }
 }
\ No newline at end of file
diff --git a/src/main/java/com/android/tools/r8/ir/code/BasicBlock.java b/src/main/java/com/android/tools/r8/ir/code/BasicBlock.java
index 5f1b33f..2863df6 100644
--- a/src/main/java/com/android/tools/r8/ir/code/BasicBlock.java
+++ b/src/main/java/com/android/tools/r8/ir/code/BasicBlock.java
@@ -543,6 +543,13 @@
     return unlinkSingleSuccessor();
   }
 
+  public void detachAllSuccessors() {
+    for (BasicBlock successor : successors) {
+      successor.predecessors.remove(this);
+    }
+    successors.clear();
+  }
+
   public List<BasicBlock> unlink(BasicBlock successor, DominatorTree dominator) {
     assert successors.contains(successor);
     assert successor.predecessors.contains(this);
@@ -935,7 +942,8 @@
    *
    * <p>The constructed basic block has no predecessors and no successors.
    *
-   * @param blockNumber the block number of the goto block
+   * @param blockNumber the block number of the block
+   * @param theIf the if instruction
    */
   public static BasicBlock createIfBlock(int blockNumber, If theIf) {
     BasicBlock block = new BasicBlock();
@@ -945,6 +953,32 @@
     return block;
   }
 
+  /**
+   * Create a new basic block with an instruction followed by an if instruction.
+   *
+   * <p>The constructed basic block has no predecessors and no successors.
+   *
+   * @param blockNumber the block number of the block
+   * @param theIf the if instruction
+   * @param instruction the instruction to place before the if instruction
+   */
+  public static BasicBlock createIfBlock(int blockNumber, If theIf, Instruction instruction) {
+    BasicBlock block = new BasicBlock();
+    block.add(instruction);
+    block.add(theIf);
+    block.close(null);
+    block.setNumber(blockNumber);
+    return block;
+  }
+
+  public static BasicBlock createSwitchBlock(int blockNumber, Switch theSwitch) {
+    BasicBlock block = new BasicBlock();
+    block.add(theSwitch);
+    block.close(null);
+    block.setNumber(blockNumber);
+    return block;
+  }
+
   public boolean isTrivialGoto() {
     return instructions.size() == 1 && exit().isGoto();
   }
diff --git a/src/main/java/com/android/tools/r8/ir/code/ConstNumber.java b/src/main/java/com/android/tools/r8/ir/code/ConstNumber.java
index 1880668..e2d41b7 100644
--- a/src/main/java/com/android/tools/r8/ir/code/ConstNumber.java
+++ b/src/main/java/com/android/tools/r8/ir/code/ConstNumber.java
@@ -121,6 +121,33 @@
     }
   }
 
+  // Estimated size of the resulting dex instruction in code units.
+  public static int estimatedDexSize(ConstType type, long value) {
+    if (MoveType.fromConstType(type) == MoveType.SINGLE) {
+      assert NumberUtils.is32Bit(value);
+      if (NumberUtils.is4Bit(value)) {
+        return Const4.SIZE;
+      } else if (NumberUtils.is16Bit(value)) {
+        return Const16.SIZE;
+      } else if ((value & 0x0000ffffL) == 0) {
+        return ConstHigh16.SIZE;
+      } else {
+        return Const.SIZE;
+      }
+    } else {
+      assert MoveType.fromConstType(type) == MoveType.WIDE;
+      if (NumberUtils.is16Bit(value)) {
+        return ConstWide16.SIZE;
+      } else if ((value & 0x0000ffffffffffffL) == 0) {
+        return ConstWideHigh16.SIZE;
+      } else if (NumberUtils.is32Bit(value)) {
+        return ConstWide32.SIZE;
+      } else {
+        return ConstWide.SIZE;
+      }
+    }
+  }
+
   @Override
   public int maxInValueRegister() {
     assert false : "Const has no register arguments.";
diff --git a/src/main/java/com/android/tools/r8/ir/code/If.java b/src/main/java/com/android/tools/r8/ir/code/If.java
index b533452..3a068f6 100644
--- a/src/main/java/com/android/tools/r8/ir/code/If.java
+++ b/src/main/java/com/android/tools/r8/ir/code/If.java
@@ -115,6 +115,11 @@
     builder.addIf(this);
   }
 
+  // Estimated size of the resulting dex instruction in code units.
+  public static int estimatedDexSize() {
+    return 2;
+  }
+
   @Override
   public String toString() {
     return super.toString() + " " + type + " block " + getTrueTarget().getNumber()
diff --git a/src/main/java/com/android/tools/r8/ir/code/Switch.java b/src/main/java/com/android/tools/r8/ir/code/Switch.java
index 6b47ba2..6def63a 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Switch.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Switch.java
@@ -12,6 +12,8 @@
 import com.android.tools.r8.dex.Constants;
 import com.android.tools.r8.ir.conversion.DexBuilder;
 import com.android.tools.r8.utils.CfgPrinter;
+import com.google.common.primitives.Ints;
+import java.util.List;
 
 public class Switch extends JumpInstruction {
 
@@ -48,29 +50,73 @@
   }
 
   // Number of targets if this switch is emitted as a packed switch.
-  private long numberOfTargetsIfPacked() {
+  private static long numberOfTargetsIfPacked(int keys[]) {
     return ((long) keys[keys.length - 1]) - ((long) keys[0]) + 1;
   }
 
-  private boolean canBePacked() {
+  public static boolean canBePacked(int keys[]) {
     // The size of a switch payload is stored in an ushort in the Dex file.
-    return numberOfTargetsIfPacked() <= Constants.U16BIT_MAX;
+    return numberOfTargetsIfPacked(keys) <= Constants.U16BIT_MAX;
+  }
+
+  // Number of targets if this switch is emitted as a packed switch.
+  public static int numberOfTargetsForPacked(int keys[]) {
+    assert canBePacked(keys);
+    return (int) numberOfTargetsIfPacked(keys);
+  }
+
+  // Size of the switch payload if emitted as packed (in code units).
+  static public long packedPayloadSize(int keys[]) {
+    return (numberOfTargetsForPacked(keys) * 2) + 4;
+  }
+
+  // Size of the switch payload if emitted as sparse (in code units).
+  public static long sparsePayloadSize(int keys[]) {
+    return (keys.length * 4) + 2;
+  }
+
+  /**
+   * Size of the switch payload instruction for the given keys. This will be the payload
+   * size for the smallest encoding of the provided keys.
+   *
+   * @param keys the switch keys
+   * @return Size of the switch payload instruction in code units
+   */
+  public static long payloadSize(List<Integer> keys) {
+    return payloadSize(Ints.toArray(keys));
+  }
+
+  /**
+   * Size of the switch payload instruction for the given keys.
+   *
+   * @see #payloadSize(List)
+   */
+  public static long payloadSize(int keys[]) {
+    long sparse = sparsePayloadSize(keys);
+    if (canBePacked(keys)) {
+      return Math.min(sparse, packedPayloadSize(keys));
+    } else {
+      return sparse;
+    }
+  }
+
+  private boolean canBePacked() {
+    return canBePacked(keys);
   }
 
   // Number of targets if this switch is emitted as a packed switch.
   private int numberOfTargetsForPacked() {
-    assert canBePacked();
-    return (int) numberOfTargetsIfPacked();
+    return numberOfTargetsForPacked(keys);
   }
 
   // Size of the switch payload if emitted as packed (in code units).
   private long packedPayloadSize() {
-    return (numberOfTargetsForPacked() * 2) + 4;
+    return packedPayloadSize(keys);
   }
 
   // Size of the switch payload if emitted as sparse (in code units).
   private long sparsePayloadSize() {
-    return (keys.length * 4) + 2;
+    return sparsePayloadSize(keys);
   }
 
   private boolean emitPacked() {
@@ -121,6 +167,10 @@
     return keys[index];
   }
 
+  public int[] getKeys() {
+    return keys;
+  }
+
   public int[] targetBlockIndices() {
     return targetBlockIndices;
   }
@@ -192,6 +242,7 @@
     StringBuilder builder = new StringBuilder(super.toString()+ "\n");
     for (int i = 0; i < numberOfKeys(); i++) {
       builder.append("          ");
+      builder.append(getKey(i));
       builder.append(" -> ");
       builder.append(targetBlock(i).getNumber());
       builder.append("\n");
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index 815efe4..e30e1d0 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -37,6 +37,7 @@
 import com.android.tools.r8.ir.code.ConstInstruction;
 import com.android.tools.r8.ir.code.ConstNumber;
 import com.android.tools.r8.ir.code.ConstString;
+import com.android.tools.r8.ir.code.ConstType;
 import com.android.tools.r8.ir.code.DominatorTree;
 import com.android.tools.r8.ir.code.Goto;
 import com.android.tools.r8.ir.code.IRCode;
@@ -63,6 +64,7 @@
 import com.android.tools.r8.ir.conversion.OptimizationFeedback;
 import com.android.tools.r8.ir.optimize.SwitchUtils.EnumSwitchInfo;
 import com.android.tools.r8.utils.InternalOptions;
+import com.android.tools.r8.utils.IteratorUtils;
 import com.android.tools.r8.utils.LongInterval;
 import com.google.common.base.Equivalence;
 import com.google.common.base.Equivalence.Wrapper;
@@ -72,8 +74,13 @@
 import com.google.common.collect.Maps;
 import it.unimi.dsi.fastutil.ints.Int2IntArrayMap;
 import it.unimi.dsi.fastutil.ints.Int2IntMap;
+import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
+import it.unimi.dsi.fastutil.ints.Int2ObjectSortedMap;
 import it.unimi.dsi.fastutil.ints.IntArrayList;
+import it.unimi.dsi.fastutil.ints.IntIterator;
 import it.unimi.dsi.fastutil.ints.IntList;
+import it.unimi.dsi.fastutil.objects.Object2IntLinkedOpenHashMap;
+import it.unimi.dsi.fastutil.objects.Object2IntMap;
 import java.util.ArrayList;
 import java.util.Comparator;
 import java.util.HashMap;
@@ -336,8 +343,130 @@
     }
   }
 
+  // TODO(sgjesse); Move this somewhere else, and reuse it for some of the other switch rewritings.
+  public static class SwitchBuilder {
+    private Value value;
+    private Int2ObjectSortedMap<BasicBlock> keyToTarget = new Int2ObjectAVLTreeMap<>();
+    private BasicBlock fallthrough;
+    private int blockNumber;
+
+    public SwitchBuilder setValue(Value value) {
+      this.value = value;
+      return  this;
+    }
+
+    public SwitchBuilder addKeyAndTarget(int key, BasicBlock target) {
+      keyToTarget.put(key, target);
+      return this;
+    }
+
+    public SwitchBuilder setFallthrough(BasicBlock fallthrough) {
+      this.fallthrough = fallthrough;
+      return this;
+    }
+
+    public SwitchBuilder setBlockNumber(int blockNumber) {
+      this.blockNumber = blockNumber;
+      return  this;
+    }
+
+    public BasicBlock build() {
+      final int NOT_FOUND = -1;
+      Object2IntMap<BasicBlock> targetToSuccessorIndex = new Object2IntLinkedOpenHashMap<>();
+      targetToSuccessorIndex.defaultReturnValue(NOT_FOUND);
+
+      int[] keys = new int[keyToTarget.size()];
+      int[] targetBlockIndices = new int[keyToTarget.size()];
+      // Sort keys descending.
+      int count = 0;
+      IntIterator iter = keyToTarget.keySet().iterator();
+      while (iter.hasNext()) {
+        int key = iter.nextInt();
+        BasicBlock target = keyToTarget.get(key);
+        Integer targetIndex =
+            targetToSuccessorIndex.computeIfAbsent(target, b -> targetToSuccessorIndex.size());
+        keys[count] = key;
+        targetBlockIndices[count] = targetIndex;
+        count++;
+      }
+      Integer fallthroughIndex =
+          targetToSuccessorIndex.computeIfAbsent(fallthrough, b -> targetToSuccessorIndex.size());
+      Switch newSwitch = new Switch(value, keys, targetBlockIndices, fallthroughIndex);
+      BasicBlock newSwitchBlock = BasicBlock.createSwitchBlock(blockNumber, newSwitch);
+      for (BasicBlock successor : targetToSuccessorIndex.keySet()) {
+        newSwitchBlock.link(successor);
+      }
+      return newSwitchBlock;
+    }
+  }
+
+  /**
+   * Covert the switch instruction to a sequence of if instructions checking for a specified
+   * set of keys, followed by a new switch with the remaining keys.
+   */
+  private void convertSwitchToSwitchAndIfs(
+      IRCode code, ListIterator<BasicBlock> blocksIterator, BasicBlock originalBlock,
+      InstructionListIterator iterator, Switch theSwitch, List<Integer> keysToRemove) {
+    // Split the switch instruction into its own block and remove it.
+    iterator.previous();
+    BasicBlock originalSwitchBlock = iterator.split(code, blocksIterator);
+    assert !originalSwitchBlock.hasCatchHandlers();
+    assert originalSwitchBlock.getInstructions().size() == 1;
+    BasicBlock block = blocksIterator.previous();
+    assert block == originalSwitchBlock;
+    blocksIterator.remove();
+
+    int nextBlockNumber = code.getHighestBlockNumber() + 1;
+
+    // Collect targets for the keys to peel off, and create a new switch instruction without
+    // these keys.
+    SwitchBuilder switchBuilder = new SwitchBuilder();
+    List<BasicBlock> peeledOffTargets = new ArrayList<>();
+    for (int i = 0; i < theSwitch.numberOfKeys(); i++) {
+      BasicBlock target = theSwitch.targetBlock(i);
+      if (!keysToRemove.contains(theSwitch.getKey(i))) {
+        switchBuilder.addKeyAndTarget(theSwitch.getKey(i), theSwitch.targetBlock(i));
+      } else {
+        peeledOffTargets.add(target);
+      }
+    }
+    assert peeledOffTargets.size() == keysToRemove.size();
+    switchBuilder.setValue(theSwitch.value());
+    switchBuilder.setFallthrough(theSwitch.fallthroughBlock());
+    switchBuilder.setBlockNumber(nextBlockNumber++);
+    theSwitch.getBlock().detachAllSuccessors();
+    block = theSwitch.getBlock().unlinkSinglePredecessor();
+    assert theSwitch.getBlock().getPredecessors().size() == 0;
+    assert theSwitch.getBlock().getSuccessors().size() == 0;
+    assert block == originalBlock;
+
+    BasicBlock newSwitchBlock = switchBuilder.build();
+
+    // Create if blocks for each of the peeled off keys, and link them into the graph.
+    BasicBlock predecessor = originalBlock;
+    for (int i = 0; i < keysToRemove.size(); i++) {
+      int key = keysToRemove.get(i);
+      BasicBlock peeledOffTarget = peeledOffTargets.get(i);
+      ConstNumber keyConst = code.createIntConstant(key);
+      If theIf = new If(Type.EQ, ImmutableList.of(keyConst.dest(), theSwitch.value()));
+      BasicBlock ifBlock = BasicBlock.createIfBlock(nextBlockNumber++, theIf, keyConst);
+      predecessor.link(ifBlock);
+      ifBlock.link(peeledOffTarget);
+      predecessor = ifBlock;
+      blocksIterator.add(ifBlock);
+      assert !peeledOffTarget.getPredecessors().contains(theSwitch.getBlock());
+    }
+    predecessor.link(newSwitchBlock);
+    blocksIterator.add(newSwitchBlock);
+
+    // The switch fallthrough block is still the same, and it is right after the new switch block.
+    IteratorUtils.peekNext(blocksIterator, newSwitchBlock.exit().fallthroughBlock());
+  }
+
   public void rewriteSwitch(IRCode code) {
-    for (BasicBlock block : code.blocks) {
+    ListIterator<BasicBlock> blocksIterator = code.listIterator();
+    while (blocksIterator.hasNext()) {
+      BasicBlock block = blocksIterator.next();
       InstructionListIterator iterator = block.listIterator();
       while (iterator.hasNext()) {
         Instruction instruction = iterator.next();
@@ -361,6 +490,54 @@
               If theIf = new If(Type.EQ, ImmutableList.of(theSwitch.value(), labelConst.dest()));
               iterator.replaceCurrentInstruction(theIf);
             }
+          } else {
+            // Split keys into outliers and sequences.
+            List<List<Integer>> sequences = new ArrayList<>();
+            List<Integer> outliers = new ArrayList<>();
+
+            List<Integer> current = new ArrayList<>();
+            int[] keys = theSwitch.getKeys();
+            int previousKey = keys[0];
+            current.add(previousKey);
+            for (int i = 1; i < keys.length; i++) {
+              assert current.size() > 0;
+              assert current.get(current.size() - 1) == previousKey;
+              int key = keys[i];
+              if (((long) key - (long) previousKey) > 1) {
+                if (current.size() == 1) {
+                  outliers.add(previousKey);
+                } else {
+                  sequences.add(current);;
+                }
+                current = new ArrayList<>();
+              }
+              current.add(key);
+              previousKey = key;
+            }
+            if (current.size() == 1) {
+              outliers.add(previousKey);
+            } else {
+              sequences.add(current);
+            }
+
+            // Only check for rewrite if there is one sequence and one or two outliers.
+            if (sequences.size() == 1 && outliers.size() <= 2) {
+              // Get the existing dex size for the payload (excluding the switch itself).
+              long currentSize = Switch.payloadSize(keys);
+              // Estimate the dex size of the rewritten payload and the additional if instructions.
+              long rewrittenSize = Switch.payloadSize(sequences.get(0));
+              for (Integer outlier : outliers) {
+                rewrittenSize += ConstNumber.estimatedDexSize(
+                    ConstType.fromMoveType(theSwitch.value().outType()), outlier);
+                rewrittenSize += If.estimatedDexSize();
+              }
+              // Rewrite if smaller.
+              if (rewrittenSize < currentSize) {
+                convertSwitchToSwitchAndIfs(
+                    code, blocksIterator, block, iterator, theSwitch, outliers);
+                assert code.isConsistentSSA();
+              }
+            }
           }
         }
       }
diff --git a/src/main/java/com/android/tools/r8/utils/IteratorUtils.java b/src/main/java/com/android/tools/r8/utils/IteratorUtils.java
new file mode 100644
index 0000000..2bc52a3
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/utils/IteratorUtils.java
@@ -0,0 +1,23 @@
+// Copyright (c) 2017, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.utils;
+
+import java.util.ListIterator;
+
+public class IteratorUtils {
+  public static <T> T peekPrevious(ListIterator<T> iterator, T element) {
+    T previous = iterator.previous();
+    T next = iterator.next();
+    assert previous == next;
+    return previous;
+  }
+
+  public static <T> T peekNext(ListIterator<T> iterator, T element) {
+    T next = iterator.next();
+    T previous = iterator.previous();
+    assert previous == next;
+    return next;
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/TestBase.java b/src/test/java/com/android/tools/r8/TestBase.java
index 3ca9b1b..4d707a8 100644
--- a/src/test/java/com/android/tools/r8/TestBase.java
+++ b/src/test/java/com/android/tools/r8/TestBase.java
@@ -19,6 +19,7 @@
 import java.io.FileOutputStream;
 import java.io.IOException;
 import java.nio.file.Path;
+import java.util.Arrays;
 import java.util.List;
 import java.util.concurrent.ExecutionException;
 import java.util.function.Consumer;
@@ -30,7 +31,7 @@
 public class TestBase {
 
   @Rule
-  public TemporaryFolder temp = new TemporaryFolder();
+  public TemporaryFolder temp = ToolHelper.getTemporaryFolderForTest();
 
   /**
    * Write lines of text to a temporary file.
@@ -114,6 +115,15 @@
   /**
    * Compile an application with R8.
    */
+  protected AndroidApp compileWithR8(AndroidApp app)
+      throws CompilationException, ProguardRuleParserException, ExecutionException, IOException {
+    R8Command command = ToolHelper.prepareR8CommandBuilder(app).build();
+    return ToolHelper.runR8(command);
+  }
+
+  /**
+   * Compile an application with R8.
+   */
   protected AndroidApp compileWithR8(AndroidApp app, Consumer<InternalOptions> optionsConsumer)
       throws CompilationException, ProguardRuleParserException, ExecutionException, IOException {
     R8Command command = ToolHelper.prepareR8CommandBuilder(app).build();
@@ -215,6 +225,13 @@
    * Run application on Art with the specified main class and provided arguments.
    */
   protected String runOnArt(AndroidApp app, Class mainClass, String... args) throws IOException {
+    return runOnArt(app, mainClass, Arrays.asList(args));
+  }
+
+  /**
+   * Run application on Art with the specified main class and provided arguments.
+   */
+  protected String runOnArt(AndroidApp app, Class mainClass, List<String> args) throws IOException {
     Path out = File.createTempFile("junit", ".zip", temp.getRoot()).toPath();
     app.writeToZip(out, OutputMode.Indexed);
     return ToolHelper.runArtNoVerificationErrors(
diff --git a/src/test/java/com/android/tools/r8/jasmin/JasminBuilder.java b/src/test/java/com/android/tools/r8/jasmin/JasminBuilder.java
index 7f65c60..ba5e370 100644
--- a/src/test/java/com/android/tools/r8/jasmin/JasminBuilder.java
+++ b/src/test/java/com/android/tools/r8/jasmin/JasminBuilder.java
@@ -5,7 +5,6 @@
 
 import com.android.tools.r8.dex.ApplicationReader;
 import com.android.tools.r8.graph.DexApplication;
-import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.naming.MemberNaming.FieldSignature;
 import com.android.tools.r8.naming.MemberNaming.MethodSignature;
 import com.android.tools.r8.utils.AndroidApp;
@@ -141,11 +140,17 @@
     return out.toByteArray();
   }
 
+  public List<byte[]> buildClasses() throws Exception {
+    List<byte[]> result = new ArrayList<>();
+    for (ClassBuilder clazz : classes) {
+      result.add(compile(clazz));
+    }
+    return result;
+  }
+
   public AndroidApp build() throws Exception {
     AndroidApp.Builder builder = AndroidApp.builder();
-    for (ClassBuilder clazz : classes) {
-      builder.addClassProgramData(compile(clazz));
-    }
+    builder.addClassProgramData(buildClasses());
     return builder.build();
   }
 
@@ -154,7 +159,6 @@
   }
 
   public DexApplication read(InternalOptions options) throws Exception {
-    DexItemFactory factory = new DexItemFactory();
     Timing timing = new Timing("JasminTest");
     return new ApplicationReader(build(), options, timing).read();
   }
diff --git a/src/test/java/com/android/tools/r8/smali/CheckSwitchInTestClass.java b/src/test/java/com/android/tools/r8/smali/CheckSwitchInTestClass.java
new file mode 100644
index 0000000..53ee6dd
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/smali/CheckSwitchInTestClass.java
@@ -0,0 +1,44 @@
+// Copyright (c) 2017, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.smali;
+
+import java.lang.reflect.Method;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+
+public class CheckSwitchInTestClass {
+  public static void main(String[] args) throws Exception {
+    // Load the generated Jasmin class, and get the test method.
+    Class<?> test = CheckSwitchInTestClass.class.getClassLoader().loadClass("Test");
+    Method method = test.getMethod("test", int.class);
+
+    // Get keys and default value from arguments.
+    List<Integer> keys = Arrays.stream(Arrays.copyOfRange(args, 0, args.length - 1))
+        .map(Integer::parseInt)
+        .sorted()
+        .collect(Collectors.toList());
+    int defaultValue = Integer.parseInt(args[args.length - 1]);
+
+    // Run over all keys and test a small interval around each.
+    long delta = 2;
+    for (Integer key : keys) {
+      for (long potential = key - delta; potential < key + delta; potential++) {
+        if (Integer.MIN_VALUE <= potential && potential <= Integer.MAX_VALUE) {
+          int testKey = (int) potential;
+          int result = ((Integer) method.invoke(null, testKey));
+          int expect = defaultValue;
+          if (keys.contains(testKey)) {
+            expect = testKey;
+          }
+          if (result != expect) {
+            System.out.println("Expected " + expect + " but got " + result);
+            System.exit(1);
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/smali/SmaliTestBase.java b/src/test/java/com/android/tools/r8/smali/SmaliTestBase.java
index 4253395..0322407 100644
--- a/src/test/java/com/android/tools/r8/smali/SmaliTestBase.java
+++ b/src/test/java/com/android/tools/r8/smali/SmaliTestBase.java
@@ -7,6 +7,7 @@
 import static org.junit.Assert.assertTrue;
 
 import com.android.tools.r8.R8;
+import com.android.tools.r8.TestBase;
 import com.android.tools.r8.ToolHelper;
 import com.android.tools.r8.dex.ApplicationReader;
 import com.android.tools.r8.graph.AppInfo;
@@ -45,18 +46,13 @@
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Executors;
 import org.antlr.runtime.RecognitionException;
-import org.junit.Rule;
-import org.junit.rules.TemporaryFolder;
 
-public class SmaliTestBase {
+public class SmaliTestBase extends TestBase {
 
   public static final String DEFAULT_CLASS_NAME = "Test";
   public static final String DEFAULT_MAIN_CLASS_NAME = DEFAULT_CLASS_NAME;
   public static final String DEFAULT_METHOD_NAME = "method";
 
-  @Rule
-  public TemporaryFolder temp = ToolHelper.getTemporaryFolderForTest();
-
   public static class MethodSignature {
 
     public final String clazz;
diff --git a/src/test/java/com/android/tools/r8/smali/SwitchRewritingTest.java b/src/test/java/com/android/tools/r8/smali/SwitchRewritingTest.java
index 7f89da7..e4ae46c 100644
--- a/src/test/java/com/android/tools/r8/smali/SwitchRewritingTest.java
+++ b/src/test/java/com/android/tools/r8/smali/SwitchRewritingTest.java
@@ -21,9 +21,14 @@
 import com.android.tools.r8.graph.DexCode;
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.jasmin.JasminBuilder;
+import com.android.tools.r8.utils.AndroidApp;
+import com.android.tools.r8.utils.DexInspector;
+import com.android.tools.r8.utils.DexInspector.MethodSubject;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.StringUtils;
 import com.google.common.collect.ImmutableList;
+import java.util.List;
+import java.util.stream.Collectors;
 import org.junit.Test;
 
 public class SwitchRewritingTest extends SmaliTestBase {
@@ -456,4 +461,77 @@
     // class file max.
     runLargerSwitchJarTest(0, 1, 5503, null);
   }
+
+  private void runConvertCasesToIf(List<Integer> keys, int defaultValue, int expectedIfs)
+      throws Exception {
+    JasminBuilder builder = new JasminBuilder();
+    JasminBuilder.ClassBuilder clazz = builder.addClass("Test");
+
+    StringBuilder x = new StringBuilder();
+    StringBuilder y = new StringBuilder();
+    for (Integer key : keys) {
+      x.append(key).append(" : case_").append(key).append("\n");
+      y.append("case_").append(key).append(":\n");
+      y.append("    ldc ").append(key).append("\n");
+      y.append("    goto return_\n");
+    }
+
+    clazz.addStaticMethod("test", ImmutableList.of("I"), "I",
+        "    .limit stack 1",
+        "    .limit locals 1",
+        "    iload_0",
+        "    lookupswitch",
+        x.toString(),
+        "      default : case_default",
+        y.toString(),
+        "  case_default:",
+        "    ldc " + defaultValue,
+        "  return_:",
+        "    ireturn");
+
+    // Add the Jasmin class and a class from Java source with the main method.
+    AndroidApp.Builder appBuilder = AndroidApp.builder();
+    appBuilder.addClassProgramData(builder.buildClasses());
+    appBuilder.addProgramFiles(ToolHelper.getClassFileForTestClass(CheckSwitchInTestClass.class));
+    AndroidApp app = compileWithR8(appBuilder.build());
+
+    DexInspector inspector = new DexInspector(app);
+    MethodSubject method = inspector.clazz("Test").method("int", "test", ImmutableList.of("int"));
+    DexCode code = method.getMethod().getCode().asDexCode();
+
+    int packedSwitches = 0;
+    int sparseSwitches = 0;
+    int ifs = 0;
+    for (Instruction instruction : code.instructions) {
+      if (instruction instanceof PackedSwitch) {
+        packedSwitches++;
+      }
+      if (instruction instanceof SparseSwitch) {
+        sparseSwitches++;
+      }
+      if (instruction instanceof IfEq || instruction instanceof IfEqz) {
+        ifs++;
+      }
+    }
+
+    assertEquals(1, packedSwitches);
+    assertEquals(0, sparseSwitches);
+    assertEquals(expectedIfs, ifs);
+
+    // Run the code
+    List<String> args = keys.stream().map(Object::toString).collect(Collectors.toList());
+    args.add(Integer.toString(defaultValue));
+    runOnArt(app, CheckSwitchInTestClass.class, args);
+  }
+
+  @Test
+  public void convertCasesToIf() throws Exception {
+    runConvertCasesToIf(ImmutableList.of(0, 1000, 1001, 1002, 1003, 1004), -100, 1);
+    runConvertCasesToIf(ImmutableList.of(1000, 1001, 1002, 1003, 1004, 2000), -100, 1);
+    runConvertCasesToIf(ImmutableList.of(Integer.MIN_VALUE, 1000, 1001, 1002, 1003, 1004), -100, 1);
+    runConvertCasesToIf(ImmutableList.of(1000, 1001, 1002, 1003, 1004, Integer.MAX_VALUE), -100, 1);
+    runConvertCasesToIf(ImmutableList.of(0, 1000, 1001, 1002, 1003, 1004, 2000), -100, 2);
+    runConvertCasesToIf(ImmutableList.of(
+        Integer.MIN_VALUE, 1000, 1001, 1002, 1003, 1004, Integer.MAX_VALUE), -100, 2);
+  }
 }