Make string-switch conversion robust towards canonicalization
Bug: 135721688
Change-Id: Ia5a15673e1e9615d809d9566265a05fed2eacbb8
diff --git a/src/main/java/com/android/tools/r8/ir/code/ConstNumber.java b/src/main/java/com/android/tools/r8/ir/code/ConstNumber.java
index 06aef0a..fcb0c1b 100644
--- a/src/main/java/com/android/tools/r8/ir/code/ConstNumber.java
+++ b/src/main/java/com/android/tools/r8/ir/code/ConstNumber.java
@@ -43,6 +43,10 @@
this.value = value;
}
+ public static ConstNumber asConstNumberOrNull(Instruction instruction) {
+ return (ConstNumber) instruction;
+ }
+
@Override
public int opcode() {
return Opcodes.CONST_NUMBER;
diff --git a/src/main/java/com/android/tools/r8/ir/code/IRCode.java b/src/main/java/com/android/tools/r8/ir/code/IRCode.java
index c9d0d84..064842a 100644
--- a/src/main/java/com/android/tools/r8/ir/code/IRCode.java
+++ b/src/main/java/com/android/tools/r8/ir/code/IRCode.java
@@ -33,6 +33,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
+import com.google.common.collect.Streams;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
@@ -986,6 +987,10 @@
return this::instructionIterator;
}
+ public Stream<Instruction> streamInstructions() {
+ return Streams.stream(instructions());
+ }
+
public <T extends Instruction> Iterable<T> instructions(Predicate<Instruction> predicate) {
return () -> IteratorUtils.filter(instructionIterator(), predicate);
}
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/StringSwitchConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/StringSwitchConverter.java
index 0365d43..57fd8e0 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/StringSwitchConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/StringSwitchConverter.java
@@ -4,6 +4,8 @@
package com.android.tools.r8.ir.conversion;
+import static com.android.tools.r8.ir.code.ConstNumber.asConstNumberOrNull;
+
import com.android.tools.r8.errors.Unreachable;
import com.android.tools.r8.graph.DexItemFactory;
import com.android.tools.r8.graph.DexString;
@@ -494,18 +496,36 @@
Set<BasicBlock> visited) {
InstructionIterator instructionIterator = block.iterator();
- // Verify that the first instruction is a non-throwing const-string instruction.
- // If the string throws, it can't be decoded, and then the string does not have a hash.
- ConstString theString = instructionIterator.next().asConstString();
- if (theString == null || theString.instructionInstanceCanThrow()) {
+ // The first instruction is expected to be a non-throwing const-string instruction, but this
+ // may change due to canonicalization. If the string throws, it can't be decoded, and then
+ // the string does not have a hash.
+ Instruction first = instructionIterator.next();
+ ConstString optionalString = first.asConstString();
+ if (optionalString != null && optionalString.instructionInstanceCanThrow()) {
return false;
}
- InvokeVirtual theInvoke = instructionIterator.next().asInvokeVirtual();
+ // The next instruction must be an invoke-virtual that calls stringValue.equals() with a
+ // constant string argument.
+ InvokeVirtual theInvoke =
+ first.isConstString()
+ ? instructionIterator.next().asInvokeVirtual()
+ : first.asInvokeVirtual();
if (theInvoke == null
|| theInvoke.getInvokedMethod() != dexItemFactory.stringMembers.equals
- || theInvoke.getReceiver() != stringValue
- || theInvoke.inValues().get(1) != theString.outValue()) {
+ || theInvoke.getReceiver() != stringValue) {
+ return false;
+ }
+
+ // If this block starts with a const-string instruction, then it should be passed as the
+ // second argument to equals().
+ if (optionalString != null && theInvoke.getArgument(1) != optionalString.outValue()) {
+ assert false; // This should generally not happen.
+ return false;
+ }
+
+ Value theString = theInvoke.getArgument(1).getAliasedValue();
+ if (!theString.isDefinedByInstructionSatisfying(Instruction::isConstString)) {
return false;
}
@@ -518,9 +538,10 @@
}
try {
- if (theString.getValue().decodedHashCode() == hash) {
- BasicBlock trueTarget = theIf.targetFromCondition(1).endOfGotoChain();
- if (!addMappingForString(trueTarget, theString.getValue(), extension)) {
+ DexString theStringValue = theString.definition.asConstString().getValue();
+ if (theStringValue.decodedHashCode() == hash) {
+ BasicBlock trueTarget = theIf.targetFromCondition(1);
+ if (!addMappingForString(trueTarget, theStringValue, extension)) {
return false;
}
}
@@ -545,7 +566,17 @@
private boolean addMappingForString(
BasicBlock block, DexString string, Reference2IntMap<DexString> extension) {
InstructionIterator instructionIterator = block.iterator();
- ConstNumber constNumberInstruction = instructionIterator.next().asConstNumber();
+ ConstNumber constNumberInstruction;
+ if (block.isTrivialGoto()) {
+ if (block.getUniqueNormalSuccessor() != idValue.getBlock()) {
+ return false;
+ }
+ int predecessorIndex = idValue.getBlock().getPredecessors().indexOf(block);
+ constNumberInstruction =
+ asConstNumberOrNull(idValue.getOperand(predecessorIndex).definition);
+ } else {
+ constNumberInstruction = instructionIterator.next().asConstNumber();
+ }
if (constNumberInstruction == null
|| !idValue.getOperands().contains(constNumberInstruction.outValue())) {
return false;
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/switches/ConvertRemovedStringSwitchTest.java b/src/test/java/com/android/tools/r8/ir/optimize/switches/ConvertRemovedStringSwitchTest.java
new file mode 100644
index 0000000..183a3b5
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/switches/ConvertRemovedStringSwitchTest.java
@@ -0,0 +1,128 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.switches;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.Instruction;
+import com.android.tools.r8.utils.InternalOptions;
+import com.android.tools.r8.utils.Reporter;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject.JumboStringMode;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import it.unimi.dsi.fastutil.objects.Reference2IntMap;
+import it.unimi.dsi.fastutil.objects.Reference2IntOpenHashMap;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class ConvertRemovedStringSwitchTest extends TestBase {
+
+ private final TestParameters parameters;
+
+ @Parameterized.Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withAllRuntimesAndApiLevels().build();
+ }
+
+ public ConvertRemovedStringSwitchTest(TestParameters parameters) {
+ this.parameters = parameters;
+ }
+
+ // TODO(b/135721688): We only introduce string-switches when there is a comparison to the hash
+ // code of a string. Thus, we won't be able to recognize the string-switch in the output until we
+ // have a hash-based string-switch elimination.
+ @Ignore("b/135721688")
+ @Test
+ public void test() throws Exception {
+ testForR8(parameters.getBackend())
+ .addInnerClasses(ConvertRemovedStringSwitchTest.class)
+ .addKeepMainRule(TestClass.class)
+ .addOptionsModification(
+ options -> {
+ assert !options.enableStringSwitchConversion;
+ options.enableStringSwitchConversion = true;
+ })
+ .setMinApi(parameters.getApiLevel())
+ .compile()
+ .inspect(this::inspect)
+ .run(parameters.getRuntime(), TestClass.class, "A", "B", "C", "D", "E")
+ .assertSuccessWithOutputLines("A", "B", "C", "D", "E!");
+ }
+
+ private void inspect(CodeInspector inspector) {
+ ClassSubject classSubject = inspector.clazz(TestClass.class);
+ assertThat(classSubject, isPresent());
+
+ MethodSubject mainMethodSubject = classSubject.mainMethod();
+ assertThat(mainMethodSubject, isPresent());
+
+ DexItemFactory dexItemFactory = new DexItemFactory();
+ InternalOptions options = new InternalOptions(dexItemFactory, new Reporter());
+ assert !options.enableStringSwitchConversion;
+ options.enableStringSwitchConversion = true;
+
+ // Verify that the keys were canonicalized.
+ Reference2IntMap<String> stringCounts = countStrings(mainMethodSubject);
+ assertEquals(1, stringCounts.getInt("A"));
+ assertEquals(1, stringCounts.getInt("B"));
+ assertEquals(1, stringCounts.getInt("B"));
+ assertEquals(1, stringCounts.getInt("D"));
+ assertEquals(1, stringCounts.getInt("E"));
+ assertEquals(1, stringCounts.getInt("E!"));
+
+ // Verify that we can rebuild the StringSwitch instruction.
+ IRCode code = mainMethodSubject.buildIR(options);
+ assertTrue(code.streamInstructions().anyMatch(Instruction::isStringSwitch));
+ }
+
+ private static Reference2IntOpenHashMap<String> countStrings(MethodSubject methodSubject) {
+ Reference2IntMap<String> result = new Reference2IntOpenHashMap<>();
+ methodSubject
+ .streamInstructions()
+ .filter(instruction -> instruction.isConstString(JumboStringMode.ALLOW))
+ .map(InstructionSubject::getConstString)
+ .forEach(string -> result.put(string, result.getInt(string) + 1));
+ return result;
+ }
+
+ static class TestClass {
+
+ public static void main(String[] args) {
+ for (String arg : args) {
+ switch (arg) {
+ case "A":
+ System.out.println("A");
+ break;
+ case "B":
+ System.out.println("B");
+ break;
+ case "C":
+ System.out.println("C");
+ break;
+ case "D":
+ System.out.println("D");
+ break;
+ case "E":
+ // Intentionally "E!" to prevent canonicalization of this key.
+ System.out.println("E!");
+ break;
+ }
+ }
+ }
+ }
+}