Framework for removing dead switch cases

Change-Id: I961257feeb2a9a3d752199665f2c308142263ffa
diff --git a/src/main/java/com/android/tools/r8/ir/code/Switch.java b/src/main/java/com/android/tools/r8/ir/code/Switch.java
index 213ff92..bb88f03 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Switch.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Switch.java
@@ -47,6 +47,7 @@
   }
 
   private boolean valid() {
+    assert keys.length >= 1;
     assert keys.length <= Constants.U16BIT_MAX;
     // Keys must be acceding, and cannot target the fallthrough.
     assert keys.length == targetBlockIndices.length;
@@ -210,6 +211,10 @@
     return keys[index];
   }
 
+  public int getTargetBlockIndex(int index) {
+    return targetBlockIndices[index];
+  }
+
   public int[] getKeys() {
     return keys;
   }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index 584fb50..8194593 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -784,6 +784,7 @@
   }
 
   public void rewriteSwitch(IRCode code) {
+    boolean needToRemoveUnreachableBlocks = false;
     ListIterator<BasicBlock> blocksIterator = code.listIterator();
     while (blocksIterator.hasNext()) {
       BasicBlock block = blocksIterator.next();
@@ -792,6 +793,24 @@
         Instruction instruction = iterator.next();
         if (instruction.isSwitch()) {
           Switch theSwitch = instruction.asSwitch();
+          if (options.testing.enableDeadSwitchCaseElimination) {
+            SwitchCaseEliminator eliminator =
+                removeUnnecessarySwitchCases(code, theSwitch, iterator);
+            if (eliminator != null) {
+              if (eliminator.mayHaveIntroducedUnreachableBlocks()) {
+                needToRemoveUnreachableBlocks = true;
+              }
+
+              iterator.previous();
+              instruction = iterator.next();
+              if (instruction.isGoto()) {
+                continue;
+              }
+
+              assert instruction.isSwitch();
+              theSwitch = instruction.asSwitch();
+            }
+          }
           if (theSwitch.numberOfKeys() == 1) {
             // Rewrite the switch to an if.
             int fallthroughBlockIndex = theSwitch.getFallthroughBlockIndex();
@@ -884,7 +903,10 @@
             // in newIntervals, potentially with a switch combining the remaining intervals.
             // Now we check to see if we can create any if's to reduce size.
             IntList outliers = new IntArrayList();
-            int outliersAsIfSize = findIfsForCandidates(newSwitches, theSwitch, outliers);
+            int outliersAsIfSize =
+                appView.options().testing.enableSwitchToIfRewriting
+                    ? findIfsForCandidates(newSwitches, theSwitch, outliers)
+                    : 0;
 
             long newSwitchesSize = 0;
             List<IntList> newSwitchSequences = new ArrayList<>(newSwitches.size());
@@ -902,6 +924,11 @@
         }
       }
     }
+
+    if (needToRemoveUnreachableBlocks) {
+      code.removeUnreachableBlocks();
+    }
+
     // Rewriting of switches introduces new branching structure. It relies on critical edges
     // being split on the way in but does not maintain this property. We therefore split
     // critical edges at exit.
@@ -909,6 +936,42 @@
     assert code.isConsistentSSA();
   }
 
+  private SwitchCaseEliminator removeUnnecessarySwitchCases(
+      IRCode code, Switch theSwitch, InstructionListIterator iterator) {
+    BasicBlock defaultTarget = theSwitch.fallthroughBlock();
+    SwitchCaseEliminator eliminator = null;
+
+    // Compute the set of switch cases that can be removed.
+    for (int i = 0; i < theSwitch.numberOfKeys(); i++) {
+      BasicBlock targetBlock = theSwitch.targetBlock(i);
+
+      // This switch case can be removed if the behavior of the target block is equivalent to the
+      // behavior of the default block, or if the switch case is unreachable.
+      if (basicBlockCanBeReplacedBy(code, targetBlock, defaultTarget)
+          || switchCaseIsUnreachable(theSwitch, i)) {
+        if (eliminator == null) {
+          eliminator = new SwitchCaseEliminator(theSwitch, iterator);
+        }
+        eliminator.markSwitchCaseForRemoval(i);
+      }
+    }
+    if (eliminator != null) {
+      eliminator.optimize();
+    }
+    return eliminator;
+  }
+
+  private boolean basicBlockCanBeReplacedBy(IRCode code, BasicBlock block, BasicBlock replacement) {
+    // TODO(b/132420434): TBD.
+    return false;
+  }
+
+  private boolean switchCaseIsUnreachable(Switch theSwitch, int index) {
+    Value switchValue = theSwitch.value();
+    return switchValue.hasValueRange()
+        && !switchValue.getValueRange().containsValue(theSwitch.getKey(index));
+  }
+
   /**
    * Inline the indirection of switch maps into the switch statement.
    * <p>
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/SwitchCaseEliminator.java b/src/main/java/com/android/tools/r8/ir/optimize/SwitchCaseEliminator.java
new file mode 100644
index 0000000..dcefdf5
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/optimize/SwitchCaseEliminator.java
@@ -0,0 +1,147 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize;
+
+import com.android.tools.r8.ir.code.BasicBlock;
+import com.android.tools.r8.ir.code.Goto;
+import com.android.tools.r8.ir.code.InstructionListIterator;
+import com.android.tools.r8.ir.code.Switch;
+import it.unimi.dsi.fastutil.ints.IntArrayList;
+import it.unimi.dsi.fastutil.ints.IntList;
+import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
+import it.unimi.dsi.fastutil.ints.IntSet;
+import java.util.Comparator;
+import java.util.function.IntPredicate;
+
+/** Helper to remove dead switch cases from a switch instruction. */
+class SwitchCaseEliminator {
+
+  private final BasicBlock block;
+  private final BasicBlock defaultTarget;
+  private final InstructionListIterator iterator;
+  private final Switch theSwitch;
+
+  private boolean mayHaveIntroducedUnreachableBlocks = false;
+  private IntSet switchCasesToBeRemoved;
+
+  SwitchCaseEliminator(Switch theSwitch, InstructionListIterator iterator) {
+    this.block = theSwitch.getBlock();
+    this.defaultTarget = theSwitch.fallthroughBlock();
+    this.iterator = iterator;
+    this.theSwitch = theSwitch;
+  }
+
+  private boolean allSwitchCasesMarkedForRemoval() {
+    assert switchCasesToBeRemoved != null;
+    return switchCasesToBeRemoved.size() == theSwitch.numberOfKeys();
+  }
+
+  private boolean canBeOptimized() {
+    assert switchCasesToBeRemoved == null || !switchCasesToBeRemoved.isEmpty();
+    return switchCasesToBeRemoved != null;
+  }
+
+  boolean mayHaveIntroducedUnreachableBlocks() {
+    return mayHaveIntroducedUnreachableBlocks;
+  }
+
+  void markSwitchCaseForRemoval(int i) {
+    if (switchCasesToBeRemoved == null) {
+      switchCasesToBeRemoved = new IntOpenHashSet();
+    }
+    switchCasesToBeRemoved.add(i);
+  }
+
+  boolean optimize() {
+    if (canBeOptimized()) {
+      int originalNumberOfSuccessors = block.getSuccessors().size();
+      unlinkDeadSuccessors();
+      if (allSwitchCasesMarkedForRemoval()) {
+        // Replace switch with a simple goto since only the fall through is left.
+        replaceSwitchByGoto();
+      } else {
+        // Replace switch by a new switch where the dead switch cases have been removed.
+        replaceSwitchByOptimizedSwitch(originalNumberOfSuccessors);
+      }
+      return true;
+    }
+    return false;
+  }
+
+  private void unlinkDeadSuccessors() {
+    IntPredicate successorHasBecomeDeadPredicate = computeSuccessorHasBecomeDeadPredicate();
+    IntList successorIndicesToBeRemoved = new IntArrayList();
+    for (int i = 0; i < block.getSuccessors().size(); i++) {
+      if (successorHasBecomeDeadPredicate.test(i)) {
+        BasicBlock successor = block.getSuccessors().get(i);
+        successor.removePredecessor(block);
+        successorIndicesToBeRemoved.add(i);
+        if (successor.getPredecessors().isEmpty()) {
+          mayHaveIntroducedUnreachableBlocks = true;
+        }
+      }
+    }
+    successorIndicesToBeRemoved.sort(Comparator.naturalOrder());
+    block.removeSuccessorsByIndex(successorIndicesToBeRemoved);
+  }
+
+  private IntPredicate computeSuccessorHasBecomeDeadPredicate() {
+    int[] numberOfControlFlowEdgesToBlockWithIndex = new int[block.getSuccessors().size()];
+    for (int i = 0; i < theSwitch.numberOfKeys(); i++) {
+      if (!switchCasesToBeRemoved.contains(i)) {
+        int targetBlockIndex = theSwitch.getTargetBlockIndex(i);
+        numberOfControlFlowEdgesToBlockWithIndex[targetBlockIndex] += 1;
+      }
+    }
+    numberOfControlFlowEdgesToBlockWithIndex[theSwitch.getFallthroughBlockIndex()] += 1;
+    for (int i : block.getCatchHandlersWithSuccessorIndexes().getUniqueTargets()) {
+      numberOfControlFlowEdgesToBlockWithIndex[i] += 1;
+    }
+    return i -> numberOfControlFlowEdgesToBlockWithIndex[i] == 0;
+  }
+
+  private void replaceSwitchByGoto() {
+    iterator.replaceCurrentInstruction(new Goto(defaultTarget));
+  }
+
+  private void replaceSwitchByOptimizedSwitch(int originalNumberOfSuccessors) {
+    int[] targetBlockIndexOffset = new int[originalNumberOfSuccessors];
+    for (int i : switchCasesToBeRemoved) {
+      int targetBlockIndex = theSwitch.getTargetBlockIndex(i);
+      // Add 1 because we are interested in the number of targets removed before a given index.
+      if (targetBlockIndex + 1 < targetBlockIndexOffset.length) {
+        targetBlockIndexOffset[targetBlockIndex + 1] = 1;
+      }
+    }
+
+    for (int i = 1; i < targetBlockIndexOffset.length; i++) {
+      targetBlockIndexOffset[i] += targetBlockIndexOffset[i - 1];
+    }
+
+    int newNumberOfKeys = theSwitch.numberOfKeys() - switchCasesToBeRemoved.size();
+    int[] newKeys = new int[newNumberOfKeys];
+    int[] newTargetBlockIndices = new int[newNumberOfKeys];
+    for (int i = 0, j = 0; i < theSwitch.numberOfKeys(); i++) {
+      if (!switchCasesToBeRemoved.contains(i)) {
+        newKeys[j] = theSwitch.getKey(i);
+        newTargetBlockIndices[j] =
+            theSwitch.getTargetBlockIndex(i)
+                - targetBlockIndexOffset[theSwitch.getTargetBlockIndex(i)];
+        assert newTargetBlockIndices[j] < block.getSuccessors().size();
+        assert newTargetBlockIndices[j] != theSwitch.getFallthroughBlockIndex();
+        j++;
+      }
+    }
+
+    assert targetBlockIndexOffset[theSwitch.getFallthroughBlockIndex()] == 0;
+
+    iterator.replaceCurrentInstruction(
+        new Switch(
+            theSwitch.value(),
+            newKeys,
+            newTargetBlockIndices,
+            theSwitch.getFallthroughBlockIndex()));
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index e32efa4..0a8f820 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -804,6 +804,8 @@
     public boolean allowTypeErrors =
         !Version.isDev() || System.getProperty("com.android.tools.r8.allowTypeErrors") != null;
     public boolean alwaysUsePessimisticRegisterAllocation = false;
+    public boolean enableDeadSwitchCaseElimination = true;
+    public boolean enableSwitchToIfRewriting = true;
     public boolean invertConditionals = false;
     public boolean placeExceptionalBlocksLast = false;
     public boolean dontCreateMarkerInD8 = false;
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/SwitchCaseRemovalTest.java b/src/test/java/com/android/tools/r8/ir/optimize/SwitchCaseRemovalTest.java
new file mode 100644
index 0000000..1e430ad
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/SwitchCaseRemovalTest.java
@@ -0,0 +1,143 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.CoreMatchers.not;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NeverPropagateValue;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.StringUtils;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class SwitchCaseRemovalTest extends TestBase {
+
+  private final TestParameters parameters;
+
+  @Parameterized.Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimes().build();
+  }
+
+  public SwitchCaseRemovalTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForR8(parameters.getBackend())
+        .addProgramClasses(TestClass.class)
+        .addKeepMainRule(TestClass.class)
+        .addKeepRules(
+            "-assumevalues class " + TestClass.class.getTypeName() + " {",
+            "  public static final int x return 0..42;",
+            "}")
+        .addOptionsModification(options -> options.testing.enableSwitchToIfRewriting = false)
+        .enableInliningAnnotations()
+        .enableMemberValuePropagationAnnotations()
+        .setMinApi(parameters.getRuntime())
+        .compile()
+        .inspect(this::verifyOutput)
+        .run(parameters.getRuntime(), TestClass.class)
+        .assertSuccessWithOutput(StringUtils.times(StringUtils.lines("Hello world!"), 3));
+  }
+
+  private void verifyOutput(CodeInspector inspector) {
+    ClassSubject classSubject = inspector.clazz(TestClass.class);
+    assertThat(classSubject, isPresent());
+    assertThat(classSubject.uniqueMethodWithName("dead"), not(isPresent()));
+
+    {
+      MethodSubject methodSubject = classSubject.uniqueMethodWithName("testSwitchCaseRemoval");
+      assertThat(methodSubject, isPresent());
+      assertEquals(
+          parameters.isCfRuntime() ? 2 : 1,
+          methodSubject.streamInstructions().filter(InstructionSubject::isConstNull).count());
+      assertEquals(
+          parameters.isCfRuntime() ? 3 : 2,
+          methodSubject.streamInstructions().filter(InstructionSubject::isReturnObject).count());
+    }
+
+    {
+      MethodSubject methodSubject =
+          classSubject.uniqueMethodWithName("testSwitchReplacementWithExplicitDefaultCase");
+      assertThat(methodSubject, isPresent());
+      assertTrue(methodSubject.streamInstructions().noneMatch(InstructionSubject::isSwitch));
+    }
+
+    {
+      MethodSubject methodSubject =
+          classSubject.uniqueMethodWithName("testSwitchReplacementWithoutExplicitDefaultCase");
+      assertThat(methodSubject, isPresent());
+      assertTrue(methodSubject.streamInstructions().noneMatch(InstructionSubject::isSwitch));
+    }
+  }
+
+  static class TestClass {
+
+    public static final int x = System.currentTimeMillis() >= 0 ? 0 : 42;
+
+    public static void main(String[] args) {
+      System.out.println(testSwitchCaseRemoval());
+      System.out.println(testSwitchReplacementWithExplicitDefaultCase());
+      System.out.println(testSwitchReplacementWithoutExplicitDefaultCase());
+    }
+
+    @NeverInline
+    public static String testSwitchCaseRemoval() {
+      switch (x) {
+        case 0:
+          return "Hello world!";
+        case 1: // TODO(b/132420434): Verify that this is removed.
+          return null;
+        case 43:
+          return dead();
+        case 2: // TODO(b/132420434): Verify that this is removed.
+        default:
+          return null;
+      }
+    }
+
+    @NeverInline
+    @NeverPropagateValue
+    public static String testSwitchReplacementWithExplicitDefaultCase() {
+      switch (x) {
+        case 43:
+          return dead();
+        default:
+          return "Hello world!";
+      }
+    }
+
+    @NeverInline
+    @NeverPropagateValue
+    public static String testSwitchReplacementWithoutExplicitDefaultCase() {
+      switch (x) {
+        case 43:
+          return dead();
+      }
+      return "Hello world!";
+    }
+
+    @NeverInline
+    @NeverPropagateValue
+    public static String dead() {
+      return "WTF";
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java b/src/test/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java
index b9f72c4..ff48000 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java
@@ -238,6 +238,11 @@
   }
 
   @Override
+  public boolean isSwitch() {
+    return isPackedSwitch() || isSparseSwitch();
+  }
+
+  @Override
   public boolean isPackedSwitch() {
     return instruction instanceof CfSwitch
         && ((CfSwitch) instruction).getKind() == CfSwitch.Kind.TABLE;
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java b/src/test/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java
index 466f68f..32e3f00 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java
@@ -372,6 +372,11 @@
   }
 
   @Override
+  public boolean isSwitch() {
+    return isPackedSwitch() || isSparseSwitch();
+  }
+
+  @Override
   public boolean isPackedSwitch() {
     return instruction instanceof PackedSwitch;
   }
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java b/src/test/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java
index 0c5a4a6..722aa5f 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java
@@ -84,6 +84,8 @@
 
   boolean isIf(); // Also include CF/if_cmp* instructions.
 
+  boolean isSwitch();
+
   boolean isPackedSwitch();
 
   boolean isSparseSwitch();