Framework for removing dead switch cases
Change-Id: I961257feeb2a9a3d752199665f2c308142263ffa
diff --git a/src/main/java/com/android/tools/r8/ir/code/Switch.java b/src/main/java/com/android/tools/r8/ir/code/Switch.java
index 213ff92..bb88f03 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Switch.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Switch.java
@@ -47,6 +47,7 @@
}
private boolean valid() {
+ assert keys.length >= 1;
assert keys.length <= Constants.U16BIT_MAX;
// Keys must be acceding, and cannot target the fallthrough.
assert keys.length == targetBlockIndices.length;
@@ -210,6 +211,10 @@
return keys[index];
}
+ public int getTargetBlockIndex(int index) {
+ return targetBlockIndices[index];
+ }
+
public int[] getKeys() {
return keys;
}
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index 584fb50..8194593 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -784,6 +784,7 @@
}
public void rewriteSwitch(IRCode code) {
+ boolean needToRemoveUnreachableBlocks = false;
ListIterator<BasicBlock> blocksIterator = code.listIterator();
while (blocksIterator.hasNext()) {
BasicBlock block = blocksIterator.next();
@@ -792,6 +793,24 @@
Instruction instruction = iterator.next();
if (instruction.isSwitch()) {
Switch theSwitch = instruction.asSwitch();
+ if (options.testing.enableDeadSwitchCaseElimination) {
+ SwitchCaseEliminator eliminator =
+ removeUnnecessarySwitchCases(code, theSwitch, iterator);
+ if (eliminator != null) {
+ if (eliminator.mayHaveIntroducedUnreachableBlocks()) {
+ needToRemoveUnreachableBlocks = true;
+ }
+
+ iterator.previous();
+ instruction = iterator.next();
+ if (instruction.isGoto()) {
+ continue;
+ }
+
+ assert instruction.isSwitch();
+ theSwitch = instruction.asSwitch();
+ }
+ }
if (theSwitch.numberOfKeys() == 1) {
// Rewrite the switch to an if.
int fallthroughBlockIndex = theSwitch.getFallthroughBlockIndex();
@@ -884,7 +903,10 @@
// in newIntervals, potentially with a switch combining the remaining intervals.
// Now we check to see if we can create any if's to reduce size.
IntList outliers = new IntArrayList();
- int outliersAsIfSize = findIfsForCandidates(newSwitches, theSwitch, outliers);
+ int outliersAsIfSize =
+ appView.options().testing.enableSwitchToIfRewriting
+ ? findIfsForCandidates(newSwitches, theSwitch, outliers)
+ : 0;
long newSwitchesSize = 0;
List<IntList> newSwitchSequences = new ArrayList<>(newSwitches.size());
@@ -902,6 +924,11 @@
}
}
}
+
+ if (needToRemoveUnreachableBlocks) {
+ code.removeUnreachableBlocks();
+ }
+
// Rewriting of switches introduces new branching structure. It relies on critical edges
// being split on the way in but does not maintain this property. We therefore split
// critical edges at exit.
@@ -909,6 +936,42 @@
assert code.isConsistentSSA();
}
+ private SwitchCaseEliminator removeUnnecessarySwitchCases(
+ IRCode code, Switch theSwitch, InstructionListIterator iterator) {
+ BasicBlock defaultTarget = theSwitch.fallthroughBlock();
+ SwitchCaseEliminator eliminator = null;
+
+ // Compute the set of switch cases that can be removed.
+ for (int i = 0; i < theSwitch.numberOfKeys(); i++) {
+ BasicBlock targetBlock = theSwitch.targetBlock(i);
+
+ // This switch case can be removed if the behavior of the target block is equivalent to the
+ // behavior of the default block, or if the switch case is unreachable.
+ if (basicBlockCanBeReplacedBy(code, targetBlock, defaultTarget)
+ || switchCaseIsUnreachable(theSwitch, i)) {
+ if (eliminator == null) {
+ eliminator = new SwitchCaseEliminator(theSwitch, iterator);
+ }
+ eliminator.markSwitchCaseForRemoval(i);
+ }
+ }
+ if (eliminator != null) {
+ eliminator.optimize();
+ }
+ return eliminator;
+ }
+
+ private boolean basicBlockCanBeReplacedBy(IRCode code, BasicBlock block, BasicBlock replacement) {
+ // TODO(b/132420434): TBD.
+ return false;
+ }
+
+ private boolean switchCaseIsUnreachable(Switch theSwitch, int index) {
+ Value switchValue = theSwitch.value();
+ return switchValue.hasValueRange()
+ && !switchValue.getValueRange().containsValue(theSwitch.getKey(index));
+ }
+
/**
* Inline the indirection of switch maps into the switch statement.
* <p>
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/SwitchCaseEliminator.java b/src/main/java/com/android/tools/r8/ir/optimize/SwitchCaseEliminator.java
new file mode 100644
index 0000000..dcefdf5
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/optimize/SwitchCaseEliminator.java
@@ -0,0 +1,147 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize;
+
+import com.android.tools.r8.ir.code.BasicBlock;
+import com.android.tools.r8.ir.code.Goto;
+import com.android.tools.r8.ir.code.InstructionListIterator;
+import com.android.tools.r8.ir.code.Switch;
+import it.unimi.dsi.fastutil.ints.IntArrayList;
+import it.unimi.dsi.fastutil.ints.IntList;
+import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
+import it.unimi.dsi.fastutil.ints.IntSet;
+import java.util.Comparator;
+import java.util.function.IntPredicate;
+
+/** Helper to remove dead switch cases from a switch instruction. */
+class SwitchCaseEliminator {
+
+ private final BasicBlock block;
+ private final BasicBlock defaultTarget;
+ private final InstructionListIterator iterator;
+ private final Switch theSwitch;
+
+ private boolean mayHaveIntroducedUnreachableBlocks = false;
+ private IntSet switchCasesToBeRemoved;
+
+ SwitchCaseEliminator(Switch theSwitch, InstructionListIterator iterator) {
+ this.block = theSwitch.getBlock();
+ this.defaultTarget = theSwitch.fallthroughBlock();
+ this.iterator = iterator;
+ this.theSwitch = theSwitch;
+ }
+
+ private boolean allSwitchCasesMarkedForRemoval() {
+ assert switchCasesToBeRemoved != null;
+ return switchCasesToBeRemoved.size() == theSwitch.numberOfKeys();
+ }
+
+ private boolean canBeOptimized() {
+ assert switchCasesToBeRemoved == null || !switchCasesToBeRemoved.isEmpty();
+ return switchCasesToBeRemoved != null;
+ }
+
+ boolean mayHaveIntroducedUnreachableBlocks() {
+ return mayHaveIntroducedUnreachableBlocks;
+ }
+
+ void markSwitchCaseForRemoval(int i) {
+ if (switchCasesToBeRemoved == null) {
+ switchCasesToBeRemoved = new IntOpenHashSet();
+ }
+ switchCasesToBeRemoved.add(i);
+ }
+
+ boolean optimize() {
+ if (canBeOptimized()) {
+ int originalNumberOfSuccessors = block.getSuccessors().size();
+ unlinkDeadSuccessors();
+ if (allSwitchCasesMarkedForRemoval()) {
+ // Replace switch with a simple goto since only the fall through is left.
+ replaceSwitchByGoto();
+ } else {
+ // Replace switch by a new switch where the dead switch cases have been removed.
+ replaceSwitchByOptimizedSwitch(originalNumberOfSuccessors);
+ }
+ return true;
+ }
+ return false;
+ }
+
+ private void unlinkDeadSuccessors() {
+ IntPredicate successorHasBecomeDeadPredicate = computeSuccessorHasBecomeDeadPredicate();
+ IntList successorIndicesToBeRemoved = new IntArrayList();
+ for (int i = 0; i < block.getSuccessors().size(); i++) {
+ if (successorHasBecomeDeadPredicate.test(i)) {
+ BasicBlock successor = block.getSuccessors().get(i);
+ successor.removePredecessor(block);
+ successorIndicesToBeRemoved.add(i);
+ if (successor.getPredecessors().isEmpty()) {
+ mayHaveIntroducedUnreachableBlocks = true;
+ }
+ }
+ }
+ successorIndicesToBeRemoved.sort(Comparator.naturalOrder());
+ block.removeSuccessorsByIndex(successorIndicesToBeRemoved);
+ }
+
+ private IntPredicate computeSuccessorHasBecomeDeadPredicate() {
+ int[] numberOfControlFlowEdgesToBlockWithIndex = new int[block.getSuccessors().size()];
+ for (int i = 0; i < theSwitch.numberOfKeys(); i++) {
+ if (!switchCasesToBeRemoved.contains(i)) {
+ int targetBlockIndex = theSwitch.getTargetBlockIndex(i);
+ numberOfControlFlowEdgesToBlockWithIndex[targetBlockIndex] += 1;
+ }
+ }
+ numberOfControlFlowEdgesToBlockWithIndex[theSwitch.getFallthroughBlockIndex()] += 1;
+ for (int i : block.getCatchHandlersWithSuccessorIndexes().getUniqueTargets()) {
+ numberOfControlFlowEdgesToBlockWithIndex[i] += 1;
+ }
+ return i -> numberOfControlFlowEdgesToBlockWithIndex[i] == 0;
+ }
+
+ private void replaceSwitchByGoto() {
+ iterator.replaceCurrentInstruction(new Goto(defaultTarget));
+ }
+
+ private void replaceSwitchByOptimizedSwitch(int originalNumberOfSuccessors) {
+ int[] targetBlockIndexOffset = new int[originalNumberOfSuccessors];
+ for (int i : switchCasesToBeRemoved) {
+ int targetBlockIndex = theSwitch.getTargetBlockIndex(i);
+ // Add 1 because we are interested in the number of targets removed before a given index.
+ if (targetBlockIndex + 1 < targetBlockIndexOffset.length) {
+ targetBlockIndexOffset[targetBlockIndex + 1] = 1;
+ }
+ }
+
+ for (int i = 1; i < targetBlockIndexOffset.length; i++) {
+ targetBlockIndexOffset[i] += targetBlockIndexOffset[i - 1];
+ }
+
+ int newNumberOfKeys = theSwitch.numberOfKeys() - switchCasesToBeRemoved.size();
+ int[] newKeys = new int[newNumberOfKeys];
+ int[] newTargetBlockIndices = new int[newNumberOfKeys];
+ for (int i = 0, j = 0; i < theSwitch.numberOfKeys(); i++) {
+ if (!switchCasesToBeRemoved.contains(i)) {
+ newKeys[j] = theSwitch.getKey(i);
+ newTargetBlockIndices[j] =
+ theSwitch.getTargetBlockIndex(i)
+ - targetBlockIndexOffset[theSwitch.getTargetBlockIndex(i)];
+ assert newTargetBlockIndices[j] < block.getSuccessors().size();
+ assert newTargetBlockIndices[j] != theSwitch.getFallthroughBlockIndex();
+ j++;
+ }
+ }
+
+ assert targetBlockIndexOffset[theSwitch.getFallthroughBlockIndex()] == 0;
+
+ iterator.replaceCurrentInstruction(
+ new Switch(
+ theSwitch.value(),
+ newKeys,
+ newTargetBlockIndices,
+ theSwitch.getFallthroughBlockIndex()));
+ }
+}
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index e32efa4..0a8f820 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -804,6 +804,8 @@
public boolean allowTypeErrors =
!Version.isDev() || System.getProperty("com.android.tools.r8.allowTypeErrors") != null;
public boolean alwaysUsePessimisticRegisterAllocation = false;
+ public boolean enableDeadSwitchCaseElimination = true;
+ public boolean enableSwitchToIfRewriting = true;
public boolean invertConditionals = false;
public boolean placeExceptionalBlocksLast = false;
public boolean dontCreateMarkerInD8 = false;
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/SwitchCaseRemovalTest.java b/src/test/java/com/android/tools/r8/ir/optimize/SwitchCaseRemovalTest.java
new file mode 100644
index 0000000..1e430ad
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/SwitchCaseRemovalTest.java
@@ -0,0 +1,143 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.CoreMatchers.not;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NeverPropagateValue;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.StringUtils;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class SwitchCaseRemovalTest extends TestBase {
+
+ private final TestParameters parameters;
+
+ @Parameterized.Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withAllRuntimes().build();
+ }
+
+ public SwitchCaseRemovalTest(TestParameters parameters) {
+ this.parameters = parameters;
+ }
+
+ @Test
+ public void test() throws Exception {
+ testForR8(parameters.getBackend())
+ .addProgramClasses(TestClass.class)
+ .addKeepMainRule(TestClass.class)
+ .addKeepRules(
+ "-assumevalues class " + TestClass.class.getTypeName() + " {",
+ " public static final int x return 0..42;",
+ "}")
+ .addOptionsModification(options -> options.testing.enableSwitchToIfRewriting = false)
+ .enableInliningAnnotations()
+ .enableMemberValuePropagationAnnotations()
+ .setMinApi(parameters.getRuntime())
+ .compile()
+ .inspect(this::verifyOutput)
+ .run(parameters.getRuntime(), TestClass.class)
+ .assertSuccessWithOutput(StringUtils.times(StringUtils.lines("Hello world!"), 3));
+ }
+
+ private void verifyOutput(CodeInspector inspector) {
+ ClassSubject classSubject = inspector.clazz(TestClass.class);
+ assertThat(classSubject, isPresent());
+ assertThat(classSubject.uniqueMethodWithName("dead"), not(isPresent()));
+
+ {
+ MethodSubject methodSubject = classSubject.uniqueMethodWithName("testSwitchCaseRemoval");
+ assertThat(methodSubject, isPresent());
+ assertEquals(
+ parameters.isCfRuntime() ? 2 : 1,
+ methodSubject.streamInstructions().filter(InstructionSubject::isConstNull).count());
+ assertEquals(
+ parameters.isCfRuntime() ? 3 : 2,
+ methodSubject.streamInstructions().filter(InstructionSubject::isReturnObject).count());
+ }
+
+ {
+ MethodSubject methodSubject =
+ classSubject.uniqueMethodWithName("testSwitchReplacementWithExplicitDefaultCase");
+ assertThat(methodSubject, isPresent());
+ assertTrue(methodSubject.streamInstructions().noneMatch(InstructionSubject::isSwitch));
+ }
+
+ {
+ MethodSubject methodSubject =
+ classSubject.uniqueMethodWithName("testSwitchReplacementWithoutExplicitDefaultCase");
+ assertThat(methodSubject, isPresent());
+ assertTrue(methodSubject.streamInstructions().noneMatch(InstructionSubject::isSwitch));
+ }
+ }
+
+ static class TestClass {
+
+ public static final int x = System.currentTimeMillis() >= 0 ? 0 : 42;
+
+ public static void main(String[] args) {
+ System.out.println(testSwitchCaseRemoval());
+ System.out.println(testSwitchReplacementWithExplicitDefaultCase());
+ System.out.println(testSwitchReplacementWithoutExplicitDefaultCase());
+ }
+
+ @NeverInline
+ public static String testSwitchCaseRemoval() {
+ switch (x) {
+ case 0:
+ return "Hello world!";
+ case 1: // TODO(b/132420434): Verify that this is removed.
+ return null;
+ case 43:
+ return dead();
+ case 2: // TODO(b/132420434): Verify that this is removed.
+ default:
+ return null;
+ }
+ }
+
+ @NeverInline
+ @NeverPropagateValue
+ public static String testSwitchReplacementWithExplicitDefaultCase() {
+ switch (x) {
+ case 43:
+ return dead();
+ default:
+ return "Hello world!";
+ }
+ }
+
+ @NeverInline
+ @NeverPropagateValue
+ public static String testSwitchReplacementWithoutExplicitDefaultCase() {
+ switch (x) {
+ case 43:
+ return dead();
+ }
+ return "Hello world!";
+ }
+
+ @NeverInline
+ @NeverPropagateValue
+ public static String dead() {
+ return "WTF";
+ }
+ }
+}
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java b/src/test/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java
index b9f72c4..ff48000 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/CfInstructionSubject.java
@@ -238,6 +238,11 @@
}
@Override
+ public boolean isSwitch() {
+ return isPackedSwitch() || isSparseSwitch();
+ }
+
+ @Override
public boolean isPackedSwitch() {
return instruction instanceof CfSwitch
&& ((CfSwitch) instruction).getKind() == CfSwitch.Kind.TABLE;
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java b/src/test/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java
index 466f68f..32e3f00 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/DexInstructionSubject.java
@@ -372,6 +372,11 @@
}
@Override
+ public boolean isSwitch() {
+ return isPackedSwitch() || isSparseSwitch();
+ }
+
+ @Override
public boolean isPackedSwitch() {
return instruction instanceof PackedSwitch;
}
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java b/src/test/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java
index 0c5a4a6..722aa5f 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/InstructionSubject.java
@@ -84,6 +84,8 @@
boolean isIf(); // Also include CF/if_cmp* instructions.
+ boolean isSwitch();
+
boolean isPackedSwitch();
boolean isSparseSwitch();