Recognize string-switches with intermediate phi id value

Fixes: 188788708
Change-Id: I8bf579d9a38fca9603f7db56fb8558d0889fd5dd
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/StringSwitchConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/StringSwitchConverter.java
index 50bec23..a377238 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/StringSwitchConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/StringSwitchConverter.java
@@ -311,7 +311,7 @@
 
       private final BasicBlock continuationBlock;
       private final DexItemFactory dexItemFactory;
-      private final Phi idValue;
+      private final Phi intermediateIdValue;
       private final Value stringValue;
 
       Builder(
@@ -321,10 +321,38 @@
           Value stringValue) {
         this.continuationBlock = continuationBlock;
         this.dexItemFactory = dexItemFactory;
-        this.idValue = idValue;
+        this.intermediateIdValue = getIntermediateIdValueOrElse(idValue, idValue);
         this.stringValue = stringValue;
       }
 
+      // Finds the intermediate id value from the given (non-intermediate) id value. If the non-
+      // intermediate id value v_id has the following structure, then v_id_intermediate is returned.
+      //
+      //   v_id_intermediate = phi(v1(0), v2(1), v_n(n-1))
+      //   v_id = phi(v0(-1), v_id_intermediate)
+      //
+      // Normally, this intermediate value is not present, and the code will have the following
+      // structure:
+      //
+      //   v_id = phi(v0(-1), v1(0), v2(1), v_n(n-1))
+      private Phi getIntermediateIdValueOrElse(Phi idValue, Phi defaultValue) {
+        if (idValue.getOperands().size() != 2) {
+          return defaultValue;
+        }
+        Phi intermediateIdValue = null;
+        for (Value operand : idValue.getOperands()) {
+          if (operand.isPhi()) {
+            if (intermediateIdValue != null) {
+              return defaultValue;
+            }
+            intermediateIdValue = operand.asPhi();
+          }
+        }
+        assert intermediateIdValue == null
+            || intermediateIdValue.getOperands().stream().noneMatch(Value::isPhi);
+        return intermediateIdValue != null ? intermediateIdValue : defaultValue;
+      }
+
       // Attempts to build a mapping from strings to their ids starting from the given block. The
       // mapping is built by traversing the control flow graph upwards, so the given block is
       // expected to be the last block in the sequence of blocks that compare the hash code of the
@@ -569,17 +597,17 @@
         InstructionIterator instructionIterator = block.iterator();
         ConstNumber constNumberInstruction;
         if (block.isTrivialGoto()) {
-          if (block.getUniqueNormalSuccessor() != idValue.getBlock()) {
+          if (block.getUniqueNormalSuccessor() != intermediateIdValue.getBlock()) {
             return false;
           }
-          int predecessorIndex = idValue.getBlock().getPredecessors().indexOf(block);
+          int predecessorIndex = intermediateIdValue.getBlock().getPredecessors().indexOf(block);
           constNumberInstruction =
-              asConstNumberOrNull(idValue.getOperand(predecessorIndex).definition);
+              asConstNumberOrNull(intermediateIdValue.getOperand(predecessorIndex).definition);
         } else {
           constNumberInstruction = instructionIterator.next().asConstNumber();
         }
         if (constNumberInstruction == null
-            || !idValue.getOperands().contains(constNumberInstruction.outValue())) {
+            || !intermediateIdValue.getOperands().contains(constNumberInstruction.outValue())) {
           return false;
         }
 
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/switches/StringSwitchWitNonIntermediateIdValueTest.java b/src/test/java/com/android/tools/r8/ir/optimize/switches/StringSwitchWitNonIntermediateIdValueTest.java
new file mode 100644
index 0000000..213f02b
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/switches/StringSwitchWitNonIntermediateIdValueTest.java
@@ -0,0 +1,130 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.switches;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assume.assumeTrue;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class StringSwitchWitNonIntermediateIdValueTest extends TestBase {
+
+  private final TestParameters parameters;
+
+  @Parameterized.Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  public StringSwitchWitNonIntermediateIdValueTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void testD8() throws Exception {
+    assumeTrue(parameters.isDexRuntime());
+    testForD8()
+        .addProgramClasses(Main.class)
+        .addOptionsModification(options -> options.minimumStringSwitchSize = 4)
+        .release()
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .inspect(this::verifyRewrittenToIfs)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("Foo", "Bar", "Baz", "Qux");
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    testForR8(parameters.getBackend())
+        .addProgramClasses(Main.class)
+        .addKeepMainRule(Main.class)
+        .addOptionsModification(options -> options.minimumStringSwitchSize = 4)
+        .enableInliningAnnotations()
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .inspect(this::verifyRewrittenToIfs)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("Foo", "Bar", "Baz", "Qux");
+  }
+
+  private void verifyRewrittenToIfs(CodeInspector inspector) {
+    MethodSubject testMethodSubject = inspector.clazz(Main.class).uniqueMethodWithName("test");
+    assertThat(testMethodSubject, isPresent());
+    assertTrue(testMethodSubject.streamInstructions().noneMatch(InstructionSubject::isSwitch));
+    assertEquals(
+        3, testMethodSubject.streamInstructions().filter(InstructionSubject::isIf).count());
+  }
+
+  static class Main {
+
+    public static void main(String[] args) {
+      test("Foo");
+      test("Bar");
+      test("Baz");
+      test("Qux");
+    }
+
+    @NeverInline
+    static void test(String str) {
+      int hashCode = str.hashCode();
+      int id = 0;
+      outer:
+      {
+        int nonZeroId;
+        switch (hashCode) {
+          case 70822: // "Foo".hashCode()
+            if (str.equals("Foo")) {
+              nonZeroId = 1;
+              break;
+            }
+            break outer;
+          case 66547: // "Bar".hashCode()
+            if (str.equals("Bar")) {
+              nonZeroId = 2;
+              break;
+            }
+            break outer;
+          case 66555: // "Baz".hashCode()
+            if (str.equals("Baz")) {
+              nonZeroId = 3;
+              break;
+            }
+            break outer;
+          default:
+            break outer;
+        }
+        id = nonZeroId;
+      }
+      switch (id) {
+        case 1:
+          System.out.println("Foo");
+          break;
+        case 2:
+          System.out.println("Bar");
+          break;
+        case 3:
+          System.out.println("Baz");
+          break;
+        default:
+          System.out.println("Qux");
+          break;
+      }
+    }
+  }
+}