Transform ConstString into DexItemBasedConstString when compared to a class name

Change-Id: Ie485b47c00429e54fbff7cca5034b2af5921a409
diff --git a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
index ab5a2dd..2bd69e5 100644
--- a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
+++ b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
@@ -270,6 +270,7 @@
   public final DexType enumType = createType(enumDescriptor);
   public final DexType annotationType = createType(annotationDescriptor);
   public final DexType iterableType = createType(iterableDescriptor);
+  public final DexType referenceFieldUpdaterType = createType(referenceFieldUpdaterDescriptor);
 
   public final DexType classType = createType(classDescriptor);
   public final DexType classLoaderType = createType(classLoaderDescriptor);
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 6f55b72..9964c7f 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -228,7 +228,7 @@
       this.outliner = new Outliner(appViewWithLiveness, this);
       this.memberValuePropagation =
           options.enableValuePropagation ? new MemberValuePropagation(appViewWithLiveness) : null;
-      if (!appInfoWithLiveness.identifierNameStrings.isEmpty() && options.isMinifying()) {
+      if (options.isMinifying()) {
         this.identifierNameStringMarker = new IdentifierNameStringMarker(appViewWithLiveness);
       } else {
         this.identifierNameStringMarker = null;
diff --git a/src/main/java/com/android/tools/r8/naming/IdentifierNameStringMarker.java b/src/main/java/com/android/tools/r8/naming/IdentifierNameStringMarker.java
index b2c5675..4e54043 100644
--- a/src/main/java/com/android/tools/r8/naming/IdentifierNameStringMarker.java
+++ b/src/main/java/com/android/tools/r8/naming/IdentifierNameStringMarker.java
@@ -3,8 +3,10 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.naming;
 
+import static com.android.tools.r8.naming.IdentifierNameStringUtils.getPositionOfFirstConstString;
 import static com.android.tools.r8.naming.IdentifierNameStringUtils.identifyIdentifier;
 import static com.android.tools.r8.naming.IdentifierNameStringUtils.inferMemberOrTypeFromNameString;
+import static com.android.tools.r8.naming.IdentifierNameStringUtils.isClassNameComparison;
 import static com.android.tools.r8.naming.IdentifierNameStringUtils.isReflectionMethod;
 
 import com.android.tools.r8.graph.AppView;
@@ -22,6 +24,7 @@
 import com.android.tools.r8.graph.DexValue.DexValueString;
 import com.android.tools.r8.ir.code.BasicBlock;
 import com.android.tools.r8.ir.code.BasicBlock.ThrowingInfo;
+import com.android.tools.r8.ir.code.ConstString;
 import com.android.tools.r8.ir.code.DexItemBasedConstString;
 import com.android.tools.r8.ir.code.FieldInstruction;
 import com.android.tools.r8.ir.code.IRCode;
@@ -187,53 +190,68 @@
       InstructionListIterator iterator,
       InvokeMethod invoke) {
     DexMethod invokedMethod = invoke.getInvokedMethod();
-    if (!identifierNameStrings.containsKey(invokedMethod)) {
+    boolean isClassNameComparison = isClassNameComparison(invoke, appView.dexItemFactory());
+    if (!identifierNameStrings.containsKey(invokedMethod) && !isClassNameComparison) {
       return iterator;
     }
     List<Value> ins = invoke.arguments();
     Value[] changes = new Value[ins.size()];
-    if (isReflectionMethod(appView.dexItemFactory(), invokedMethod)) {
+    if (isReflectionMethod(appView.dexItemFactory(), invokedMethod) || isClassNameComparison) {
       DexReference itemBasedString = identifyIdentifier(invoke, appView);
       if (itemBasedString == null) {
         DexType context = method.method.holder;
         warnUndeterminedIdentifierIfNecessary(invokedMethod, context, invoke, null);
         return iterator;
       }
-      DexType returnType = invoke.getReturnType();
-      boolean isClassForName = returnType.descriptor == appView.dexItemFactory().classDescriptor;
-      boolean isReferenceFieldUpdater =
-          returnType.descriptor == appView.dexItemFactory().referenceFieldUpdaterDescriptor;
-      int positionOfIdentifier = isClassForName ? 0 : (isReferenceFieldUpdater ? 2 : 1);
-      Value in = invoke.arguments().get(positionOfIdentifier);
-      // Move the cursor back to $invoke
-      assert iterator.peekPrevious() == invoke;
-      iterator.previous();
+      int identifierPosition = getIdentifierPositionInArguments(invoke);
+      Value in = invoke.arguments().get(identifierPosition);
       // Prepare $decoupled just before $invoke
       Value newIn = code.createValue(in.getTypeLattice(), in.getLocalInfo());
       DexItemBasedConstString decoupled =
           new DexItemBasedConstString(newIn, itemBasedString, throwingInfo);
-      decoupled.setPosition(invoke.getPosition());
-      changes[positionOfIdentifier] = newIn;
-      // If the current block has catch handler, split into two blocks.
-      // Because const-string we're about to add is also a throwing instr, we need to split
-      // before adding it.
-      BasicBlock block = invoke.getBlock();
-      BasicBlock blockWithInvoke = block.hasCatchHandlers() ? iterator.split(code, blocks) : block;
-      if (blockWithInvoke != block) {
-        // If we split, add const-string at the end of the currently visiting block.
-        iterator = block.listIterator(block.getInstructions().size());
-        iterator.previous();
-        iterator.add(decoupled);
-        // Restore the cursor and block.
-        iterator = blockWithInvoke.listIterator();
-        assert iterator.peekNext() == invoke;
-        iterator.next();
+      changes[identifierPosition] = newIn;
+
+      if (in.numberOfAllUsers() == 1) {
+        // Simply replace the existing ConstString by a DexItemBasedConstString. No need to check
+        // for catch handlers, as this is replacing one throwing instruction with another.
+        ConstString constString = in.definition.asConstString();
+        if (constString.getBlock() == invoke.getBlock()) {
+          iterator.previousUntil(instruction -> instruction == constString);
+          Instruction current = iterator.next();
+          assert current == constString;
+          iterator.replaceCurrentInstruction(decoupled);
+          iterator.nextUntil(instruction -> instruction == invoke);
+        } else {
+          in.definition.replace(decoupled);
+        }
       } else {
-        // Otherwise, just add it to the current block at the position of the iterator.
-        iterator.add(decoupled);
-        // Restore the cursor.
-        assert iterator.peekNext() == invoke;
-        iterator.next();
+        decoupled.setPosition(invoke.getPosition());
+
+        // Move the cursor back to $invoke
+        assert iterator.peekPrevious() == invoke;
+        iterator.previous();
+        // If the current block has catch handler, split into two blocks.
+        // Because const-string we're about to add is also a throwing instr, we need to split
+        // before adding it.
+        BasicBlock block = invoke.getBlock();
+        BasicBlock blockWithInvoke =
+            block.hasCatchHandlers() ? iterator.split(code, blocks) : block;
+        if (blockWithInvoke != block) {
+          // If we split, add const-string at the end of the currently visiting block.
+          iterator = block.listIterator(block.getInstructions().size());
+          iterator.previous();
+          iterator.add(decoupled);
+          // Restore the cursor and block.
+          iterator = blockWithInvoke.listIterator();
+          assert iterator.peekNext() == invoke;
+          iterator.next();
+        } else {
+          // Otherwise, just add it to the current block at the position of the iterator.
+          iterator.add(decoupled);
+          // Restore the cursor.
+          assert iterator.peekNext() == invoke;
+          iterator.next();
+        }
       }
     } else {
       // For general invoke. Multiple arguments can be string literals to be renamed.
@@ -297,6 +315,29 @@
     return iterator;
   }
 
+  private int getIdentifierPositionInArguments(InvokeMethod invoke) {
+    DexType returnType = invoke.getReturnType();
+    if (isClassNameComparison(invoke, appView.dexItemFactory())) {
+      return getPositionOfFirstConstString(invoke);
+    }
+
+    boolean isClassForName = returnType == appView.dexItemFactory().classType;
+    if (isClassForName) {
+      assert invoke.getInvokedMethod() == appView.dexItemFactory().classMethods.forName;
+      return 0;
+    }
+
+    boolean isReferenceFieldUpdater =
+        returnType == appView.dexItemFactory().referenceFieldUpdaterType;
+    if (isReferenceFieldUpdater) {
+      assert invoke.getInvokedMethod()
+          == appView.dexItemFactory().atomicFieldUpdaterMethods.referenceUpdater;
+      return 2;
+    }
+
+    return 1;
+  }
+
   private void warnUndeterminedIdentifierIfNecessary(
       DexReference member, DexType originHolder, Instruction instruction, DexString original) {
     assert member.isDexField() || member.isDexMethod();
diff --git a/src/main/java/com/android/tools/r8/naming/IdentifierNameStringUtils.java b/src/main/java/com/android/tools/r8/naming/IdentifierNameStringUtils.java
index 3a759c4..37ac36b 100644
--- a/src/main/java/com/android/tools/r8/naming/IdentifierNameStringUtils.java
+++ b/src/main/java/com/android/tools/r8/naming/IdentifierNameStringUtils.java
@@ -24,6 +24,7 @@
 import com.android.tools.r8.ir.code.InstructionListIterator;
 import com.android.tools.r8.ir.code.InvokeMethod;
 import com.android.tools.r8.ir.code.InvokeStatic;
+import com.android.tools.r8.ir.code.InvokeVirtual;
 import com.android.tools.r8.ir.code.NewArrayEmpty;
 import com.android.tools.r8.ir.code.Value;
 import com.google.common.collect.Sets;
@@ -138,6 +139,30 @@
   }
 
   /**
+   * Returns true if the given invoke instruction is calling `boolean java.lang.String.equals(
+   * java.lang.String)`, and one of the arguments is defined by an invoke-instruction that calls
+   * `java.lang.String java.lang.Class.getName()`.
+   */
+  public static boolean isClassNameComparison(InvokeMethod invoke, DexItemFactory dexItemFactory) {
+    return invoke.isInvokeVirtual()
+        && isClassNameComparison(invoke.asInvokeVirtual(), dexItemFactory);
+  }
+
+  public static boolean isClassNameComparison(InvokeVirtual invoke, DexItemFactory dexItemFactory) {
+    return invoke.getInvokedMethod() == dexItemFactory.stringMethods.equals
+        && (isClassNameValue(invoke.getReceiver(), dexItemFactory)
+            || isClassNameValue(invoke.inValues().get(1), dexItemFactory));
+  }
+
+  private static boolean isClassNameValue(Value value, DexItemFactory dexItemFactory) {
+    Value root = value.getAliasedValue();
+    return !root.isPhi()
+        && root.definition.isInvokeVirtual()
+        && root.definition.asInvokeVirtual().getInvokedMethod()
+            == dexItemFactory.classMethods.getName;
+  }
+
+  /**
    * Returns a {@link DexReference} if one of the arguments to the invoke instruction is a constant
    * string that corresponds to either a class or member name (i.e., an identifier).
    *
@@ -157,6 +182,17 @@
       }
     }
 
+    if (invoke.isInvokeVirtual()) {
+      InvokeVirtual invokeVirtual = invoke.asInvokeVirtual();
+      if (isClassNameComparison(invokeVirtual, definitions.dexItemFactory())) {
+        int argumentIndex = getPositionOfFirstConstString(invokeVirtual);
+        if (argumentIndex >= 0) {
+          return inferTypeFromConstStringValue(
+              definitions, invokeVirtual.inValues().get(argumentIndex));
+        }
+      }
+    }
+
     // All the other cases receive either (Class, String) or (Class, String, Class[]) as ins.
     if (ins.size() == 1) {
       return null;
@@ -210,6 +246,17 @@
     return null;
   }
 
+  static int getPositionOfFirstConstString(Instruction instruction) {
+    List<Value> inValues = instruction.inValues();
+    for (int i = 0; i < inValues.size(); i++) {
+      Value value = inValues.get(i);
+      if (value.getAliasedValue().isConstString()) {
+        return i;
+      }
+    }
+    return -1;
+  }
+
   static DexReference inferMemberOrTypeFromNameString(
       DexDefinitionSupplier definitions, DexString dexString) {
     // "fully.qualified.ClassName.fieldOrMethodName"
@@ -231,6 +278,14 @@
     return null;
   }
 
+  public static DexType inferTypeFromConstStringValue(
+      DexDefinitionSupplier definitions, Value value) {
+    assert value.definition != null;
+    assert value.getAliasedValue().isConstString();
+    return inferTypeFromNameString(
+        definitions, value.getAliasedValue().definition.asConstString().getValue());
+  }
+
   private static DexReference inferMemberFromNameString(
       DexDefinitionSupplier definitions, DexString dexString) {
     String identifier = dexString.toString();
diff --git a/src/test/java/com/android/tools/r8/internal/proto/Proto2ShrinkingTest.java b/src/test/java/com/android/tools/r8/internal/proto/Proto2ShrinkingTest.java
index 485d90a..1f27829 100644
--- a/src/test/java/com/android/tools/r8/internal/proto/Proto2ShrinkingTest.java
+++ b/src/test/java/com/android/tools/r8/internal/proto/Proto2ShrinkingTest.java
@@ -33,8 +33,6 @@
         .addProgramFiles(PROTO2_EXAMPLES_JAR, PROTO2_PROTO_JAR, PROTOBUF_LITE_JAR)
         .addKeepMainRule("proto2.TestClass")
         .addKeepRules(
-            // TODO(b/112437944): Fix -identifiernamestring support.
-            "-keepnames class * extends com.google.protobuf.GeneratedMessageLite",
             // TODO(b/112437944): Use dex item based const strings for proto schema definitions.
             "-keepclassmembernames class * extends com.google.protobuf.GeneratedMessageLite {",
             "  <fields>;",
@@ -46,6 +44,7 @@
             "}",
             allowAccessModification ? "-allowaccessmodification" : "")
         .addKeepRuleFiles(PROTOBUF_LITE_PROGUARD_RULES)
+        .addOptionsModification(options -> options.enableStringSwitchConversion = true)
         .setMinApi(parameters.getRuntime())
         .compile()
         .run(parameters.getRuntime(), "proto2.TestClass")
diff --git a/src/test/java/com/android/tools/r8/naming/identifiernamestring/ClassNameComparisonTest.java b/src/test/java/com/android/tools/r8/naming/identifiernamestring/ClassNameComparisonTest.java
new file mode 100644
index 0000000..5aaf063
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/naming/identifiernamestring/ClassNameComparisonTest.java
@@ -0,0 +1,70 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.naming.identifiernamestring;
+
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class ClassNameComparisonTest extends TestBase {
+
+  private final TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimes().build();
+  }
+
+  public ClassNameComparisonTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(ClassNameComparisonTest.class)
+        .addKeepMainRule(TestClass.class)
+        .setMinApi(parameters.getRuntime())
+        .compile()
+        .run(parameters.getRuntime(), TestClass.class)
+        .assertSuccessWithOutputLines("Hello!", "Hello " + B.class.getName() + "!");
+  }
+
+  @Test
+  public void testCorrectnessOfNames() {
+    assertEquals(A.class.getName(), TestClass.NAME_A);
+    assertEquals(B.class.getName(), TestClass.NAME_B);
+  }
+
+  static class TestClass {
+
+    private static final String NAME_A =
+        "com.android.tools.r8.naming.identifiernamestring.ClassNameComparisonTest$A";
+
+    private static final String NAME_B =
+        "com.android.tools.r8.naming.identifiernamestring.ClassNameComparisonTest$B";
+
+    public static void main(String[] args) {
+      if (A.class.getName().equals(NAME_A)) {
+        System.out.println("Hello!");
+      }
+      String name = NAME_B;
+      if (B.class.getName().equals(name)) {
+        System.out.println("Hello " + name + "!");
+      }
+    }
+  }
+
+  static class A {}
+
+  static class B {}
+}