Make StringOptimizer a CodeRewriterPass

Change-Id: I4974531b657470edd9b46e1310783f4bce5ef8ff
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 21c252b..98d46dd 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -111,7 +111,6 @@
   protected final CfInstructionDesugaringCollection instructionDesugaring;
   protected final FieldAccessAnalysis fieldAccessAnalysis;
   protected final LibraryMethodOverrideAnalysis libraryMethodOverrideAnalysis;
-  protected final StringOptimizer stringOptimizer;
   protected final IdempotentFunctionCallCanonicalizer idempotentFunctionCallCanonicalizer;
   private final ClassInliner classInliner;
   protected final InternalOptions options;
@@ -171,7 +170,6 @@
             : null;
     this.classInitializerDefaultsOptimization =
         new ClassInitializerDefaultsOptimization(appView, this);
-    this.stringOptimizer = new StringOptimizer(appView);
     this.deadCodeRemover = new DeadCodeRemover(appView);
     this.assertionsRewriter = new AssertionsRewriter(appView);
     this.idempotentFunctionCallCanonicalizer = new IdempotentFunctionCallCanonicalizer(appView);
@@ -693,18 +691,7 @@
     }
 
     if (!isDebugMode) {
-      // Reflection optimization 2. get*Name() with const-class -> const-string
-      if (options.enableNameReflectionOptimization
-          || options.testing.forceNameReflectionOptimization) {
-        timing.begin("Rewrite Class.getName");
-        stringOptimizer.rewriteClassGetName(appView, code);
-        timing.end();
-      }
-      // Reflection/string optimization 3. trivial conversion/computation on const-string
-      timing.begin("Optimize const strings");
-      stringOptimizer.computeTrivialOperationsOnConstString(code);
-      stringOptimizer.removeTrivialConversions(code);
-      timing.end();
+      new StringOptimizer(appView).run(code, timing);
       timing.begin("Optimize library methods");
       appView
           .libraryMethodOptimizer()
@@ -765,7 +752,6 @@
       assert options.inlinerOptions().enableInlining && inliner != null;
       classInliner.processMethodCode(
           appView.withLiveness(),
-          stringOptimizer,
           code.context(),
           code,
           feedback,
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/classinliner/ClassInliner.java b/src/main/java/com/android/tools/r8/ir/optimize/classinliner/ClassInliner.java
index 112697a..a518875 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/classinliner/ClassInliner.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/classinliner/ClassInliner.java
@@ -124,7 +124,6 @@
   //
   public final void processMethodCode(
       AppView<AppInfoWithLiveness> appView,
-      StringOptimizer stringOptimizer,
       ProgramMethod method,
       IRCode code,
       OptimizationFeedback feedback,
@@ -246,13 +245,7 @@
       // If a method was inlined we may be able to prune additional branches.
       new BranchSimplifier(appView).run(code, Timing.empty());
       // If a method was inlined we may see more trivial computation/conversion of String.
-      boolean isDebugMode =
-          appView.options().debug || method.getOrComputeReachabilitySensitive(appView);
-      if (!isDebugMode) {
-        // Reflection/string optimization 3. trivial conversion/computation on const-string
-        stringOptimizer.computeTrivialOperationsOnConstString(code);
-        stringOptimizer.removeTrivialConversions(code);
-      }
+      new StringOptimizer(appView).run(code, Timing.empty());
     }
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/string/StringOptimizer.java b/src/main/java/com/android/tools/r8/ir/optimize/string/StringOptimizer.java
index fb50e54..0b2f0d8 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/string/StringOptimizer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/string/StringOptimizer.java
@@ -11,9 +11,9 @@
 import static com.android.tools.r8.utils.DescriptorUtils.INNER_CLASS_SEPARATOR;
 
 import com.android.tools.r8.errors.Unreachable;
+import com.android.tools.r8.graph.AppInfo;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexClass;
-import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexString;
 import com.android.tools.r8.graph.DexType;
@@ -31,20 +31,42 @@
 import com.android.tools.r8.ir.code.InvokeStatic;
 import com.android.tools.r8.ir.code.InvokeVirtual;
 import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.ir.conversion.passes.CodeRewriterPass;
+import com.android.tools.r8.ir.conversion.passes.result.CodeRewriterResult;
 import com.android.tools.r8.ir.optimize.AffectedValues;
 import com.android.tools.r8.naming.dexitembasedstring.ClassNameComputationInfo;
 import java.io.UTFDataFormatException;
 import java.util.function.BiFunction;
 import java.util.function.Function;
 
-public class StringOptimizer {
-
-  private final AppView<?> appView;
-  private final DexItemFactory factory;
+public class StringOptimizer extends CodeRewriterPass<AppInfo> {
 
   public StringOptimizer(AppView<?> appView) {
-    this.appView = appView;
-    this.factory = appView.dexItemFactory();
+    super(appView);
+  }
+
+  @Override
+  protected String getTimingId() {
+    return "StringOptimizer";
+  }
+
+  @Override
+  protected boolean shouldRewriteCode(IRCode code) {
+    return !isDebugMode(code.context());
+  }
+
+  @Override
+  protected CodeRewriterResult rewriteCode(IRCode code) {
+    boolean hasChanged = false;
+    if (options.enableNameReflectionOptimization
+        || options.testing.forceNameReflectionOptimization) {
+      hasChanged |= rewriteClassGetName(code);
+    }
+    if (code.metadata().mayHaveConstString()) {
+      hasChanged |= computeTrivialOperationsOnConstString(code);
+    }
+    hasChanged |= removeTrivialConversions(code);
+    return CodeRewriterResult.hasChanged(hasChanged);
   }
 
   // boolean String#isEmpty()
@@ -65,10 +87,8 @@
   // String String#substring(int)
   // String String#substring(int, int)
   // String String#trim()
-  public void computeTrivialOperationsOnConstString(IRCode code) {
-    if (!code.metadata().mayHaveConstString()) {
-      return;
-    }
+  private boolean computeTrivialOperationsOnConstString(IRCode code) {
+    boolean hasChanged = false;
     AffectedValues affectedValues = new AffectedValues();
     InstructionListIterator it = code.instructionListIterator();
     while (it.hasNext()) {
@@ -81,10 +101,10 @@
         continue;
       }
       DexMethod invokedMethod = invoke.getInvokedMethod();
-      if (invokedMethod.getHolderType() != factory.stringType) {
+      if (invokedMethod.getHolderType() != dexItemFactory.stringType) {
         continue;
       }
-      if (invokedMethod.getName() == factory.substringName) {
+      if (invokedMethod.getName() == dexItemFactory.substringName) {
         assert invoke.inValues().size() == 2 || invoke.inValues().size() == 3;
         Value rcv = invoke.getReceiver().getAliasedValue();
         if (rcv.definition == null
@@ -123,18 +143,20 @@
         Value stringValue =
             code.createValue(
                 TypeElement.stringClassType(appView, definitelyNotNull()), invoke.getLocalInfo());
+        affectedValues.addAll(invoke.outValue().affectedValues());
         it.replaceCurrentInstruction(
-            new ConstString(stringValue, factory.createString(sub)), affectedValues);
+            new ConstString(stringValue, dexItemFactory.createString(sub)));
         continue;
       }
 
-      if (invokedMethod == factory.stringMembers.trim) {
+      if (invokedMethod == dexItemFactory.stringMembers.trim) {
         Value receiver = invoke.getReceiver().getAliasedValue();
         if (receiver.hasLocalInfo() || receiver.isPhi() || !receiver.definition.isConstString()) {
           continue;
         }
         DexString resultString =
-            factory.createString(receiver.definition.asConstString().getValue().toString().trim());
+            dexItemFactory.createString(
+                receiver.definition.asConstString().getValue().toString().trim());
         Value newOutValue =
             code.createValue(
                 TypeElement.stringClassType(appView, definitelyNotNull()), invoke.getLocalInfo());
@@ -145,7 +167,7 @@
       Function<DexString, Integer> operatorWithNoArg = null;
       BiFunction<DexString, DexString, Integer> operatorWithString = null;
       BiFunction<DexString, Integer, Integer> operatorWithInt = null;
-      if (invokedMethod == factory.stringMembers.hashCode) {
+      if (invokedMethod == dexItemFactory.stringMembers.hashCode) {
         operatorWithNoArg = rcv -> {
           try {
             return rcv.decodedHashCode();
@@ -154,33 +176,33 @@
             throw new Unreachable();
           }
         };
-      } else if (invokedMethod == factory.stringMembers.length) {
+      } else if (invokedMethod == dexItemFactory.stringMembers.length) {
         operatorWithNoArg = rcv -> rcv.size;
-      } else if (invokedMethod == factory.stringMembers.isEmpty) {
+      } else if (invokedMethod == dexItemFactory.stringMembers.isEmpty) {
         operatorWithNoArg = rcv -> rcv.size == 0 ? 1 : 0;
-      } else if (invokedMethod == factory.stringMembers.contains) {
+      } else if (invokedMethod == dexItemFactory.stringMembers.contains) {
         operatorWithString = (rcv, arg) -> rcv.toString().contains(arg.toString()) ? 1 : 0;
-      } else if (invokedMethod == factory.stringMembers.startsWith) {
+      } else if (invokedMethod == dexItemFactory.stringMembers.startsWith) {
         operatorWithString = (rcv, arg) -> rcv.startsWith(arg) ? 1 : 0;
-      } else if (invokedMethod == factory.stringMembers.endsWith) {
+      } else if (invokedMethod == dexItemFactory.stringMembers.endsWith) {
         operatorWithString = (rcv, arg) -> rcv.endsWith(arg) ? 1 : 0;
-      } else if (invokedMethod == factory.stringMembers.equals) {
+      } else if (invokedMethod == dexItemFactory.stringMembers.equals) {
         operatorWithString = (rcv, arg) -> rcv.equals(arg) ? 1 : 0;
-      } else if (invokedMethod == factory.stringMembers.equalsIgnoreCase) {
+      } else if (invokedMethod == dexItemFactory.stringMembers.equalsIgnoreCase) {
         operatorWithString = (rcv, arg) -> rcv.toString().equalsIgnoreCase(arg.toString()) ? 1 : 0;
-      } else if (invokedMethod == factory.stringMembers.contentEqualsCharSequence) {
+      } else if (invokedMethod == dexItemFactory.stringMembers.contentEqualsCharSequence) {
         operatorWithString = (rcv, arg) -> rcv.toString().contentEquals(arg.toString()) ? 1 : 0;
-      } else if (invokedMethod == factory.stringMembers.indexOfInt) {
+      } else if (invokedMethod == dexItemFactory.stringMembers.indexOfInt) {
         operatorWithInt = (rcv, idx) -> rcv.toString().indexOf(idx);
-      } else if (invokedMethod == factory.stringMembers.indexOfString) {
+      } else if (invokedMethod == dexItemFactory.stringMembers.indexOfString) {
         operatorWithString = (rcv, arg) -> rcv.toString().indexOf(arg.toString());
-      } else if (invokedMethod == factory.stringMembers.lastIndexOfInt) {
+      } else if (invokedMethod == dexItemFactory.stringMembers.lastIndexOfInt) {
         operatorWithInt = (rcv, idx) -> rcv.toString().lastIndexOf(idx);
-      } else if (invokedMethod == factory.stringMembers.lastIndexOfString) {
+      } else if (invokedMethod == dexItemFactory.stringMembers.lastIndexOfString) {
         operatorWithString = (rcv, arg) -> rcv.toString().lastIndexOf(arg.toString());
-      } else if (invokedMethod == factory.stringMembers.compareTo) {
+      } else if (invokedMethod == dexItemFactory.stringMembers.compareTo) {
         operatorWithString = (rcv, arg) -> rcv.toString().compareTo(arg.toString());
-      } else if (invokedMethod == factory.stringMembers.compareToIgnoreCase) {
+      } else if (invokedMethod == dexItemFactory.stringMembers.compareToIgnoreCase) {
         operatorWithString = (rcv, arg) -> rcv.toString().compareToIgnoreCase(arg.toString());
       } else {
         continue;
@@ -222,15 +244,18 @@
         constNumber = code.createIntConstant(v);
       }
 
+      hasChanged = true;
       it.replaceCurrentInstruction(constNumber);
     }
     // Computed substring is not null, and thus propagate that information.
     affectedValues.narrowingWithAssumeRemoval(appView, code);
     assert code.isConsistentSSA(appView);
+    return hasChanged;
   }
 
   // Find Class#get*Name() with a constant-class and replace it with a const-string if possible.
-  public void rewriteClassGetName(AppView<?> appView, IRCode code) {
+  private boolean rewriteClassGetName(IRCode code) {
+    boolean hasChanged = false;
     AffectedValues affectedValues = new AffectedValues();
     InstructionListIterator it = code.instructionListIterator();
     while (it.hasNext()) {
@@ -240,7 +265,7 @@
       }
       InvokeVirtual invoke = instr.asInvokeVirtual();
       DexMethod invokedMethod = invoke.getInvokedMethod();
-      if (!factory.classMethods.isReflectiveNameLookup(invokedMethod)) {
+      if (!dexItemFactory.classMethods.isReflectiveNameLookup(invokedMethod)) {
         continue;
       }
 
@@ -272,7 +297,7 @@
       ConstClass constClass = in.definition.asConstClass();
       DexType type = constClass.getValue();
       int arrayDepth = type.getNumberOfLeadingSquareBrackets();
-      DexType baseType = type.toBaseType(factory);
+      DexType baseType = type.toBaseType(dexItemFactory);
       // Make sure base type is a class type.
       if (!baseType.isClassType()) {
         continue;
@@ -294,7 +319,7 @@
         if (mayBeRenamed) {
           continue;
         }
-        if (invokedMethod != factory.classMethods.getSimpleName) {
+        if (invokedMethod != dexItemFactory.classMethods.getSimpleName) {
           EscapeAnalysis escapeAnalysis =
               new EscapeAnalysis(appView, StringOptimizerEscapeAnalysisConfiguration.getInstance());
           if (escapeAnalysis.isEscaping(code, out)) {
@@ -307,7 +332,7 @@
       boolean assumeTopLevel = descriptor.indexOf(INNER_CLASS_SEPARATOR) < 0;
       DexItemBasedConstString deferred = null;
       DexString name = null;
-      if (invokedMethod == factory.classMethods.getName) {
+      if (invokedMethod == dexItemFactory.classMethods.getName) {
         if (mayBeRenamed) {
           Value stringValue =
               code.createValue(
@@ -316,12 +341,12 @@
               new DexItemBasedConstString(
                   stringValue, baseType, ClassNameComputationInfo.create(NAME, arrayDepth));
         } else {
-          name = NAME.map(descriptor, holder, factory, arrayDepth);
+          name = NAME.map(descriptor, holder, dexItemFactory, arrayDepth);
         }
-      } else if (invokedMethod == factory.classMethods.getTypeName) {
+      } else if (invokedMethod == dexItemFactory.classMethods.getTypeName) {
         // TODO(b/119426668): desugar Type#getTypeName
         continue;
-      } else if (invokedMethod == factory.classMethods.getCanonicalName) {
+      } else if (invokedMethod == dexItemFactory.classMethods.getCanonicalName) {
         // Always returns null if the target type is local or anonymous class.
         if (holder.isLocalClass() || holder.isAnonymousClass()) {
           ConstNumber constNull = code.createConstNull();
@@ -343,13 +368,13 @@
                     baseType,
                     ClassNameComputationInfo.create(CANONICAL_NAME, arrayDepth));
           } else {
-            name = CANONICAL_NAME.map(descriptor, holder, factory, arrayDepth);
+            name = CANONICAL_NAME.map(descriptor, holder, dexItemFactory, arrayDepth);
           }
         }
-      } else if (invokedMethod == factory.classMethods.getSimpleName) {
+      } else if (invokedMethod == dexItemFactory.classMethods.getSimpleName) {
         // Always returns an empty string if the target type is an anonymous class.
         if (holder.isAnonymousClass()) {
-          name = factory.createString("");
+          name = dexItemFactory.createString("");
         } else {
           // b/120130435: If an outer class is shrunk, we may compute a wrong simple name.
           // Leave it as-is so that the class's simple name is consistent across the app.
@@ -367,7 +392,7 @@
                     baseType,
                     ClassNameComputationInfo.create(SIMPLE_NAME, arrayDepth));
           } else {
-            name = SIMPLE_NAME.map(descriptor, holder, factory, arrayDepth);
+            name = SIMPLE_NAME.map(descriptor, holder, dexItemFactory, arrayDepth);
           }
         }
       }
@@ -377,20 +402,24 @@
                 TypeElement.stringClassType(appView, definitelyNotNull()), invoke.getLocalInfo());
         ConstString constString = new ConstString(stringValue, name);
         it.replaceCurrentInstruction(constString, affectedValues);
+        hasChanged = true;
       } else if (deferred != null) {
         it.replaceCurrentInstruction(deferred, affectedValues);
+        hasChanged = true;
       }
     }
     // Computed name is not null or literally null (for canonical name of local/anonymous class).
     // In either way, that is narrower information, and thus propagate that.
     affectedValues.narrowingWithAssumeRemoval(appView, code);
     assert code.isConsistentSSA(appView);
+    return hasChanged;
   }
 
   // String#valueOf(null) -> "null"
   // String#valueOf(String s) -> s
   // str.toString() -> str
-  public void removeTrivialConversions(IRCode code) {
+  private boolean removeTrivialConversions(IRCode code) {
+    boolean hasChanged = false;
     AffectedValues affectedValues = new AffectedValues();
     InstructionListIterator it = code.instructionListIterator();
     while (it.hasNext()) {
@@ -398,7 +427,7 @@
       if (instr.isInvokeStatic()) {
         InvokeStatic invoke = instr.asInvokeStatic();
         DexMethod invokedMethod = invoke.getInvokedMethod();
-        if (invokedMethod != factory.stringMembers.valueOf) {
+        if (invokedMethod != dexItemFactory.stringMembers.valueOf) {
           continue;
         }
         assert invoke.inValues().size() == 1;
@@ -410,10 +439,11 @@
         TypeElement inType = in.getType();
         if (out != null && in.isAlwaysNull(appView)) {
           affectedValues.addAll(out.affectedValues());
-          it.replaceCurrentInstructionWithConstString(appView, code, factory.createString("null"));
+          it.replaceCurrentInstructionWithConstString(
+              appView, code, dexItemFactory.createString("null"));
         } else if (inType.nullability().isDefinitelyNotNull()
             && inType.isClassType()
-            && inType.asClassType().getClassType().equals(factory.stringType)) {
+            && inType.asClassType().getClassType().equals(dexItemFactory.stringType)) {
           if (out != null) {
             affectedValues.addAll(out.affectedValues());
             removeOrReplaceByDebugLocalWrite(invoke, it, in, out);
@@ -421,10 +451,11 @@
             it.removeOrReplaceByDebugLocalRead();
           }
         }
+        hasChanged = true;
       } else if (instr.isInvokeVirtual()) {
         InvokeVirtual invoke = instr.asInvokeVirtual();
         DexMethod invokedMethod = invoke.getInvokedMethod();
-        if (invokedMethod != factory.stringMembers.toString) {
+        if (invokedMethod != dexItemFactory.stringMembers.toString) {
           continue;
         }
         assert invoke.inValues().size() == 1;
@@ -432,7 +463,7 @@
         TypeElement inType = in.getType();
         if (inType.nullability().isDefinitelyNotNull()
             && inType.isClassType()
-            && inType.asClassType().getClassType().equals(factory.stringType)) {
+            && inType.asClassType().getClassType().equals(dexItemFactory.stringType)) {
           Value out = invoke.outValue();
           if (out != null) {
             affectedValues.addAll(out.affectedValues());
@@ -441,12 +472,14 @@
             it.removeOrReplaceByDebugLocalRead();
           }
         }
+        hasChanged = true;
       }
     }
     // Newly added "null" string is not null, and thus propagate that information.
     affectedValues.narrowingWithAssumeRemoval(appView, code);
     code.removeRedundantBlocks();
     assert code.isConsistentSSA(appView);
+    return hasChanged;
   }
 
   static class StringOptimizerEscapeAnalysisConfiguration