Recognize string-switches with intermediate phi id value
Fixes: 188788708
Change-Id: I8bf579d9a38fca9603f7db56fb8558d0889fd5dd
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/StringSwitchConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/StringSwitchConverter.java
index 50bec23..a377238 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/StringSwitchConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/StringSwitchConverter.java
@@ -311,7 +311,7 @@
private final BasicBlock continuationBlock;
private final DexItemFactory dexItemFactory;
- private final Phi idValue;
+ private final Phi intermediateIdValue;
private final Value stringValue;
Builder(
@@ -321,10 +321,38 @@
Value stringValue) {
this.continuationBlock = continuationBlock;
this.dexItemFactory = dexItemFactory;
- this.idValue = idValue;
+ this.intermediateIdValue = getIntermediateIdValueOrElse(idValue, idValue);
this.stringValue = stringValue;
}
+ // Finds the intermediate id value from the given (non-intermediate) id value. If the non-
+ // intermediate id value v_id has the following structure, then v_id_intermediate is returned.
+ //
+ // v_id_intermediate = phi(v1(0), v2(1), v_n(n-1))
+ // v_id = phi(v0(-1), v_id_intermediate)
+ //
+ // Normally, this intermediate value is not present, and the code will have the following
+ // structure:
+ //
+ // v_id = phi(v0(-1), v1(0), v2(1), v_n(n-1))
+ private Phi getIntermediateIdValueOrElse(Phi idValue, Phi defaultValue) {
+ if (idValue.getOperands().size() != 2) {
+ return defaultValue;
+ }
+ Phi intermediateIdValue = null;
+ for (Value operand : idValue.getOperands()) {
+ if (operand.isPhi()) {
+ if (intermediateIdValue != null) {
+ return defaultValue;
+ }
+ intermediateIdValue = operand.asPhi();
+ }
+ }
+ assert intermediateIdValue == null
+ || intermediateIdValue.getOperands().stream().noneMatch(Value::isPhi);
+ return intermediateIdValue != null ? intermediateIdValue : defaultValue;
+ }
+
// Attempts to build a mapping from strings to their ids starting from the given block. The
// mapping is built by traversing the control flow graph upwards, so the given block is
// expected to be the last block in the sequence of blocks that compare the hash code of the
@@ -569,17 +597,17 @@
InstructionIterator instructionIterator = block.iterator();
ConstNumber constNumberInstruction;
if (block.isTrivialGoto()) {
- if (block.getUniqueNormalSuccessor() != idValue.getBlock()) {
+ if (block.getUniqueNormalSuccessor() != intermediateIdValue.getBlock()) {
return false;
}
- int predecessorIndex = idValue.getBlock().getPredecessors().indexOf(block);
+ int predecessorIndex = intermediateIdValue.getBlock().getPredecessors().indexOf(block);
constNumberInstruction =
- asConstNumberOrNull(idValue.getOperand(predecessorIndex).definition);
+ asConstNumberOrNull(intermediateIdValue.getOperand(predecessorIndex).definition);
} else {
constNumberInstruction = instructionIterator.next().asConstNumber();
}
if (constNumberInstruction == null
- || !idValue.getOperands().contains(constNumberInstruction.outValue())) {
+ || !intermediateIdValue.getOperands().contains(constNumberInstruction.outValue())) {
return false;
}
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/switches/StringSwitchWitNonIntermediateIdValueTest.java b/src/test/java/com/android/tools/r8/ir/optimize/switches/StringSwitchWitNonIntermediateIdValueTest.java
new file mode 100644
index 0000000..213f02b
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/switches/StringSwitchWitNonIntermediateIdValueTest.java
@@ -0,0 +1,130 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.switches;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assume.assumeTrue;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class StringSwitchWitNonIntermediateIdValueTest extends TestBase {
+
+ private final TestParameters parameters;
+
+ @Parameterized.Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withAllRuntimesAndApiLevels().build();
+ }
+
+ public StringSwitchWitNonIntermediateIdValueTest(TestParameters parameters) {
+ this.parameters = parameters;
+ }
+
+ @Test
+ public void testD8() throws Exception {
+ assumeTrue(parameters.isDexRuntime());
+ testForD8()
+ .addProgramClasses(Main.class)
+ .addOptionsModification(options -> options.minimumStringSwitchSize = 4)
+ .release()
+ .setMinApi(parameters.getApiLevel())
+ .compile()
+ .inspect(this::verifyRewrittenToIfs)
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutputLines("Foo", "Bar", "Baz", "Qux");
+ }
+
+ @Test
+ public void testR8() throws Exception {
+ testForR8(parameters.getBackend())
+ .addProgramClasses(Main.class)
+ .addKeepMainRule(Main.class)
+ .addOptionsModification(options -> options.minimumStringSwitchSize = 4)
+ .enableInliningAnnotations()
+ .setMinApi(parameters.getApiLevel())
+ .compile()
+ .inspect(this::verifyRewrittenToIfs)
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutputLines("Foo", "Bar", "Baz", "Qux");
+ }
+
+ private void verifyRewrittenToIfs(CodeInspector inspector) {
+ MethodSubject testMethodSubject = inspector.clazz(Main.class).uniqueMethodWithName("test");
+ assertThat(testMethodSubject, isPresent());
+ assertTrue(testMethodSubject.streamInstructions().noneMatch(InstructionSubject::isSwitch));
+ assertEquals(
+ 3, testMethodSubject.streamInstructions().filter(InstructionSubject::isIf).count());
+ }
+
+ static class Main {
+
+ public static void main(String[] args) {
+ test("Foo");
+ test("Bar");
+ test("Baz");
+ test("Qux");
+ }
+
+ @NeverInline
+ static void test(String str) {
+ int hashCode = str.hashCode();
+ int id = 0;
+ outer:
+ {
+ int nonZeroId;
+ switch (hashCode) {
+ case 70822: // "Foo".hashCode()
+ if (str.equals("Foo")) {
+ nonZeroId = 1;
+ break;
+ }
+ break outer;
+ case 66547: // "Bar".hashCode()
+ if (str.equals("Bar")) {
+ nonZeroId = 2;
+ break;
+ }
+ break outer;
+ case 66555: // "Baz".hashCode()
+ if (str.equals("Baz")) {
+ nonZeroId = 3;
+ break;
+ }
+ break outer;
+ default:
+ break outer;
+ }
+ id = nonZeroId;
+ }
+ switch (id) {
+ case 1:
+ System.out.println("Foo");
+ break;
+ case 2:
+ System.out.println("Bar");
+ break;
+ case 3:
+ System.out.println("Baz");
+ break;
+ default:
+ System.out.println("Qux");
+ break;
+ }
+ }
+ }
+}