Make string-switch conversion robust towards canonicalization

Bug: 135721688
Change-Id: Ia5a15673e1e9615d809d9566265a05fed2eacbb8
diff --git a/src/main/java/com/android/tools/r8/ir/code/ConstNumber.java b/src/main/java/com/android/tools/r8/ir/code/ConstNumber.java
index 06aef0a..fcb0c1b 100644
--- a/src/main/java/com/android/tools/r8/ir/code/ConstNumber.java
+++ b/src/main/java/com/android/tools/r8/ir/code/ConstNumber.java
@@ -43,6 +43,10 @@
     this.value = value;
   }
 
+  public static ConstNumber asConstNumberOrNull(Instruction instruction) {
+    return (ConstNumber) instruction;
+  }
+
   @Override
   public int opcode() {
     return Opcodes.CONST_NUMBER;
diff --git a/src/main/java/com/android/tools/r8/ir/code/IRCode.java b/src/main/java/com/android/tools/r8/ir/code/IRCode.java
index c9d0d84..064842a 100644
--- a/src/main/java/com/android/tools/r8/ir/code/IRCode.java
+++ b/src/main/java/com/android/tools/r8/ir/code/IRCode.java
@@ -33,6 +33,7 @@
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Sets;
+import com.google.common.collect.Streams;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -986,6 +987,10 @@
     return this::instructionIterator;
   }
 
+  public Stream<Instruction> streamInstructions() {
+    return Streams.stream(instructions());
+  }
+
   public <T extends Instruction> Iterable<T> instructions(Predicate<Instruction> predicate) {
     return () -> IteratorUtils.filter(instructionIterator(), predicate);
   }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/StringSwitchConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/StringSwitchConverter.java
index 0365d43..57fd8e0 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/StringSwitchConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/StringSwitchConverter.java
@@ -4,6 +4,8 @@
 
 package com.android.tools.r8.ir.conversion;
 
+import static com.android.tools.r8.ir.code.ConstNumber.asConstNumberOrNull;
+
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexString;
@@ -494,18 +496,36 @@
           Set<BasicBlock> visited) {
         InstructionIterator instructionIterator = block.iterator();
 
-        // Verify that the first instruction is a non-throwing const-string instruction.
-        // If the string throws, it can't be decoded, and then the string does not have a hash.
-        ConstString theString = instructionIterator.next().asConstString();
-        if (theString == null || theString.instructionInstanceCanThrow()) {
+        // The first instruction is expected to be a non-throwing const-string instruction, but this
+        // may change due to canonicalization. If the string throws, it can't be decoded, and then
+        // the string does not have a hash.
+        Instruction first = instructionIterator.next();
+        ConstString optionalString = first.asConstString();
+        if (optionalString != null && optionalString.instructionInstanceCanThrow()) {
           return false;
         }
 
-        InvokeVirtual theInvoke = instructionIterator.next().asInvokeVirtual();
+        // The next instruction must be an invoke-virtual that calls stringValue.equals() with a
+        // constant string argument.
+        InvokeVirtual theInvoke =
+            first.isConstString()
+                ? instructionIterator.next().asInvokeVirtual()
+                : first.asInvokeVirtual();
         if (theInvoke == null
             || theInvoke.getInvokedMethod() != dexItemFactory.stringMembers.equals
-            || theInvoke.getReceiver() != stringValue
-            || theInvoke.inValues().get(1) != theString.outValue()) {
+            || theInvoke.getReceiver() != stringValue) {
+          return false;
+        }
+
+        // If this block starts with a const-string instruction, then it should be passed as the
+        // second argument to equals().
+        if (optionalString != null && theInvoke.getArgument(1) != optionalString.outValue()) {
+          assert false; // This should generally not happen.
+          return false;
+        }
+
+        Value theString = theInvoke.getArgument(1).getAliasedValue();
+        if (!theString.isDefinedByInstructionSatisfying(Instruction::isConstString)) {
           return false;
         }
 
@@ -518,9 +538,10 @@
         }
 
         try {
-          if (theString.getValue().decodedHashCode() == hash) {
-            BasicBlock trueTarget = theIf.targetFromCondition(1).endOfGotoChain();
-            if (!addMappingForString(trueTarget, theString.getValue(), extension)) {
+          DexString theStringValue = theString.definition.asConstString().getValue();
+          if (theStringValue.decodedHashCode() == hash) {
+            BasicBlock trueTarget = theIf.targetFromCondition(1);
+            if (!addMappingForString(trueTarget, theStringValue, extension)) {
               return false;
             }
           }
@@ -545,7 +566,17 @@
       private boolean addMappingForString(
           BasicBlock block, DexString string, Reference2IntMap<DexString> extension) {
         InstructionIterator instructionIterator = block.iterator();
-        ConstNumber constNumberInstruction = instructionIterator.next().asConstNumber();
+        ConstNumber constNumberInstruction;
+        if (block.isTrivialGoto()) {
+          if (block.getUniqueNormalSuccessor() != idValue.getBlock()) {
+            return false;
+          }
+          int predecessorIndex = idValue.getBlock().getPredecessors().indexOf(block);
+          constNumberInstruction =
+              asConstNumberOrNull(idValue.getOperand(predecessorIndex).definition);
+        } else {
+          constNumberInstruction = instructionIterator.next().asConstNumber();
+        }
         if (constNumberInstruction == null
             || !idValue.getOperands().contains(constNumberInstruction.outValue())) {
           return false;
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/switches/ConvertRemovedStringSwitchTest.java b/src/test/java/com/android/tools/r8/ir/optimize/switches/ConvertRemovedStringSwitchTest.java
new file mode 100644
index 0000000..183a3b5
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/switches/ConvertRemovedStringSwitchTest.java
@@ -0,0 +1,128 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.switches;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.Instruction;
+import com.android.tools.r8.utils.InternalOptions;
+import com.android.tools.r8.utils.Reporter;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject.JumboStringMode;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import it.unimi.dsi.fastutil.objects.Reference2IntMap;
+import it.unimi.dsi.fastutil.objects.Reference2IntOpenHashMap;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class ConvertRemovedStringSwitchTest extends TestBase {
+
+  private final TestParameters parameters;
+
+  @Parameterized.Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  public ConvertRemovedStringSwitchTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  // TODO(b/135721688): We only introduce string-switches when there is a comparison to the hash
+  //  code of a string. Thus, we won't be able to recognize the string-switch in the output until we
+  //  have a hash-based string-switch elimination.
+  @Ignore("b/135721688")
+  @Test
+  public void test() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(ConvertRemovedStringSwitchTest.class)
+        .addKeepMainRule(TestClass.class)
+        .addOptionsModification(
+            options -> {
+              assert !options.enableStringSwitchConversion;
+              options.enableStringSwitchConversion = true;
+            })
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .inspect(this::inspect)
+        .run(parameters.getRuntime(), TestClass.class, "A", "B", "C", "D", "E")
+        .assertSuccessWithOutputLines("A", "B", "C", "D", "E!");
+  }
+
+  private void inspect(CodeInspector inspector) {
+    ClassSubject classSubject = inspector.clazz(TestClass.class);
+    assertThat(classSubject, isPresent());
+
+    MethodSubject mainMethodSubject = classSubject.mainMethod();
+    assertThat(mainMethodSubject, isPresent());
+
+    DexItemFactory dexItemFactory = new DexItemFactory();
+    InternalOptions options = new InternalOptions(dexItemFactory, new Reporter());
+    assert !options.enableStringSwitchConversion;
+    options.enableStringSwitchConversion = true;
+
+    // Verify that the keys were canonicalized.
+    Reference2IntMap<String> stringCounts = countStrings(mainMethodSubject);
+    assertEquals(1, stringCounts.getInt("A"));
+    assertEquals(1, stringCounts.getInt("B"));
+    assertEquals(1, stringCounts.getInt("B"));
+    assertEquals(1, stringCounts.getInt("D"));
+    assertEquals(1, stringCounts.getInt("E"));
+    assertEquals(1, stringCounts.getInt("E!"));
+
+    // Verify that we can rebuild the StringSwitch instruction.
+    IRCode code = mainMethodSubject.buildIR(options);
+    assertTrue(code.streamInstructions().anyMatch(Instruction::isStringSwitch));
+  }
+
+  private static Reference2IntOpenHashMap<String> countStrings(MethodSubject methodSubject) {
+    Reference2IntMap<String> result = new Reference2IntOpenHashMap<>();
+    methodSubject
+        .streamInstructions()
+        .filter(instruction -> instruction.isConstString(JumboStringMode.ALLOW))
+        .map(InstructionSubject::getConstString)
+        .forEach(string -> result.put(string, result.getInt(string) + 1));
+    return result;
+  }
+
+  static class TestClass {
+
+    public static void main(String[] args) {
+      for (String arg : args) {
+        switch (arg) {
+          case "A":
+            System.out.println("A");
+            break;
+          case "B":
+            System.out.println("B");
+            break;
+          case "C":
+            System.out.println("C");
+            break;
+          case "D":
+            System.out.println("D");
+            break;
+          case "E":
+            // Intentionally "E!" to prevent canonicalization of this key.
+            System.out.println("E!");
+            break;
+        }
+      }
+    }
+  }
+}