Split MoveResultRewriter off CodeRewriter

Bug: b/284304606
Change-Id: I2157b612e60dd52233c36305263d4156457bf269
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 2a2813f..6007ea9 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -33,6 +33,7 @@
 import com.android.tools.r8.ir.conversion.passes.CommonSubexpressionElimination;
 import com.android.tools.r8.ir.conversion.passes.DexConstantOptimizer;
 import com.android.tools.r8.ir.conversion.passes.KnownArrayLengthRewriter;
+import com.android.tools.r8.ir.conversion.passes.MoveResultRewriter;
 import com.android.tools.r8.ir.conversion.passes.NaturalIntLoopRemover;
 import com.android.tools.r8.ir.conversion.passes.ParentConstructorHoistingCodeRewriter;
 import com.android.tools.r8.ir.conversion.passes.SplitBranch;
@@ -182,7 +183,7 @@
     this.classInitializerDefaultsOptimization =
         new ClassInitializerDefaultsOptimization(appView, this);
     this.stringOptimizer = new StringOptimizer(appView);
-    this.deadCodeRemover = new DeadCodeRemover(appView, codeRewriter);
+    this.deadCodeRemover = new DeadCodeRemover(appView);
     this.assertionsRewriter = new AssertionsRewriter(appView);
     this.idempotentFunctionCallCanonicalizer = new IdempotentFunctionCallCanonicalizer(appView);
     this.neverMerge =
@@ -383,7 +384,7 @@
   private void processAndFinalizeSimpleSynthesizedMethod(ProgramMethod method) {
     IRCode code = method.buildIR(appView);
     assert code != null;
-    codeRewriter.rewriteMoveResult(code);
+    new MoveResultRewriter(appView).run(code, Timing.empty());
     removeDeadCodeAndFinalizeIR(code, OptimizationFeedbackIgnore.getInstance(), Timing.empty());
   }
 
@@ -750,9 +751,7 @@
     }
     commonSubexpressionElimination.run(code, timing);
     new ArrayConstructionSimplifier(appView).run(code, timing);
-    timing.begin("Rewrite move result");
-    codeRewriter.rewriteMoveResult(code);
-    timing.end();
+    new MoveResultRewriter(appView).run(code, timing);
     if (options.enableStringConcatenationOptimization && !isDebugMode) {
       timing.begin("Rewrite string concat");
       StringBuilderAppendOptimizer.run(appView, code);
@@ -930,7 +929,7 @@
       assert code.isConsistentSSA(appView);
 
       // TODO(b/214496607): Remove when dynamic types are safe w.r.t. interface assignment rules.
-      codeRewriter.rewriteMoveResult(code);
+      new MoveResultRewriter(appView).run(code, timing);
     }
 
     // Assert that we do not have unremoved non-sense code in the output, e.g., v <- non-null NULL.
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/ArrayConstructionSimplifier.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/ArrayConstructionSimplifier.java
index 63e6b36..61cc2c1 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/ArrayConstructionSimplifier.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/ArrayConstructionSimplifier.java
@@ -20,6 +20,7 @@
 import com.android.tools.r8.ir.code.NewArrayEmpty;
 import com.android.tools.r8.ir.code.NewArrayFilledData;
 import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.ir.conversion.passes.result.CodeRewriterResult;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.InternalOptions.RewriteArrayOptions;
 import com.android.tools.r8.utils.SetUtils;
@@ -92,12 +93,13 @@
   }
 
   @Override
-  protected void rewriteCode(IRCode code) {
+  protected CodeRewriterResult rewriteCode(IRCode code) {
     WorkList<BasicBlock> worklist = WorkList.newIdentityWorkList(code.blocks);
     while (worklist.hasNext()) {
       BasicBlock block = worklist.next();
       simplifyArrayConstructionBlock(block, worklist, code, appView.options());
     }
+    return CodeRewriterResult.NONE;
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/BinopRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/BinopRewriter.java
index 24fad2e..f684ff5 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/BinopRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/BinopRewriter.java
@@ -27,6 +27,7 @@
 import com.android.tools.r8.ir.code.Ushr;
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.code.Xor;
+import com.android.tools.r8.ir.conversion.passes.result.CodeRewriterResult;
 import com.android.tools.r8.utils.WorkList;
 import com.google.common.collect.ImmutableMap;
 import java.util.Map;
@@ -248,7 +249,7 @@
   }
 
   @Override
-  public void rewriteCode(IRCode code) {
+  public CodeRewriterResult rewriteCode(IRCode code) {
     InstructionListIterator iterator = code.instructionListIterator();
     while (iterator.hasNext()) {
       Instruction next = iterator.next();
@@ -268,6 +269,7 @@
     code.removeAllDeadAndTrivialPhis();
     code.removeRedundantBlocks();
     assert code.isConsistentSSA(appView);
+    return CodeRewriterResult.NONE;
   }
 
   private void successiveSimplification(
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/CodeRewriterPass.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/CodeRewriterPass.java
index 94e30d3..aba955c 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/CodeRewriterPass.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/CodeRewriterPass.java
@@ -11,6 +11,7 @@
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.conversion.MethodProcessor;
+import com.android.tools.r8.ir.conversion.passes.result.CodeRewriterResult;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.Timing;
 
@@ -31,38 +32,43 @@
     return (AppView<? extends T>) appView;
   }
 
-  public final void run(
+  public final CodeRewriterResult run(
       IRCode code,
       MethodProcessor methodProcessor,
       MethodProcessingContext methodProcessingContext,
       Timing timing) {
-    timing.time(getTimingId(), () -> run(code, methodProcessor, methodProcessingContext));
+    return timing.time(getTimingId(), () -> run(code, methodProcessor, methodProcessingContext));
   }
 
-  public final void run(IRCode code, Timing timing) {
-    timing.time(getTimingId(), () -> run(code, null, null));
+  public final CodeRewriterResult run(IRCode code, Timing timing) {
+    return timing.time(getTimingId(), () -> run(code, null, null));
   }
 
-  private void run(
+  private CodeRewriterResult run(
       IRCode code,
       MethodProcessor methodProcessor,
       MethodProcessingContext methodProcessingContext) {
     if (shouldRewriteCode(code)) {
-      rewriteCode(code, methodProcessor, methodProcessingContext);
+      return rewriteCode(code, methodProcessor, methodProcessingContext);
     }
+    return noChange();
+  }
+
+  protected CodeRewriterResult noChange() {
+    return CodeRewriterResult.NO_CHANGE;
   }
 
   protected abstract String getTimingId();
 
-  protected void rewriteCode(IRCode code) {
+  protected CodeRewriterResult rewriteCode(IRCode code) {
     throw new Unreachable("Should Override or use overload");
   }
 
-  protected void rewriteCode(
+  protected CodeRewriterResult rewriteCode(
       IRCode code,
       MethodProcessor methodProcessor,
       MethodProcessingContext methodProcessingContext) {
-    rewriteCode(code);
+    return rewriteCode(code);
   }
 
   protected abstract boolean shouldRewriteCode(IRCode code);
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/CommonSubexpressionElimination.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/CommonSubexpressionElimination.java
index 5d88ded..7057abc 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/CommonSubexpressionElimination.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/CommonSubexpressionElimination.java
@@ -15,6 +15,7 @@
 import com.android.tools.r8.ir.code.InstructionListIterator;
 import com.android.tools.r8.ir.code.Phi;
 import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.ir.conversion.passes.result.CodeRewriterResult;
 import com.android.tools.r8.utils.InternalOptions;
 import com.google.common.base.Equivalence;
 import com.google.common.base.Equivalence.Wrapper;
@@ -39,7 +40,7 @@
   }
 
   @Override
-  protected void rewriteCode(IRCode code) {
+  protected CodeRewriterResult rewriteCode(IRCode code) {
     int noCandidate = code.reserveMarkingColor();
     if (hasCSECandidate(code, noCandidate)) {
       final ListMultimap<Wrapper<Instruction>, Value> instructionToValue =
@@ -79,6 +80,7 @@
     code.returnMarkingColor(noCandidate);
     code.removeRedundantBlocks();
     assert code.isConsistentSSA(appView);
+    return CodeRewriterResult.NONE;
   }
 
   private static class CSEExpressionEquivalence extends Equivalence<Instruction> {
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/DexConstantOptimizer.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/DexConstantOptimizer.java
index c6e07ed..bcfd88c 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/DexConstantOptimizer.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/DexConstantOptimizer.java
@@ -31,6 +31,7 @@
 import com.android.tools.r8.ir.code.Position;
 import com.android.tools.r8.ir.code.StaticGet;
 import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.ir.conversion.passes.result.CodeRewriterResult;
 import com.android.tools.r8.ir.optimize.ConstantCanonicalizer;
 import com.android.tools.r8.utils.LazyBox;
 import com.google.common.collect.Iterables;
@@ -63,9 +64,10 @@
   }
 
   @Override
-  protected void rewriteCode(IRCode code) {
+  protected CodeRewriterResult rewriteCode(IRCode code) {
     useDedicatedConstantForLitInstruction(code);
     shortenLiveRanges(code, constantCanonicalizer);
+    return CodeRewriterResult.NONE;
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/KnownArrayLengthRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/KnownArrayLengthRewriter.java
index 9ae7aab..291f8b9 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/KnownArrayLengthRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/KnownArrayLengthRewriter.java
@@ -13,6 +13,7 @@
 import com.android.tools.r8.ir.code.InstructionListIterator;
 import com.android.tools.r8.ir.code.Phi;
 import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.ir.conversion.passes.result.CodeRewriterResult;
 import java.util.Set;
 
 public class KnownArrayLengthRewriter extends CodeRewriterPass<AppInfo> {
@@ -32,7 +33,7 @@
   }
 
   @Override
-  protected void rewriteCode(IRCode code) {
+  protected CodeRewriterResult rewriteCode(IRCode code) {
     InstructionListIterator iterator = code.instructionListIterator();
     while (iterator.hasNext()) {
       Instruction current = iterator.next();
@@ -78,5 +79,6 @@
     }
     code.removeRedundantBlocks();
     assert code.isConsistentSSA(appView);
+    return CodeRewriterResult.NONE;
   }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/MoveResultRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/MoveResultRewriter.java
new file mode 100644
index 0000000..4481aa0
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/MoveResultRewriter.java
@@ -0,0 +1,145 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.conversion.passes;
+
+import static com.android.tools.r8.ir.analysis.type.Nullability.maybeNull;
+
+import com.android.tools.r8.graph.AppInfo;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexClassAndMethod;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.ir.analysis.type.TypeAnalysis;
+import com.android.tools.r8.ir.analysis.type.TypeElement;
+import com.android.tools.r8.ir.code.BasicBlock;
+import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.InstructionListIterator;
+import com.android.tools.r8.ir.code.InvokeMethod;
+import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.ir.conversion.passes.result.CodeRewriterResult;
+import com.android.tools.r8.ir.optimize.AssumeRemover;
+import com.android.tools.r8.ir.optimize.info.MethodOptimizationInfo;
+import com.google.common.collect.Sets;
+import java.util.Collections;
+import java.util.ListIterator;
+import java.util.Set;
+
+public class MoveResultRewriter extends CodeRewriterPass<AppInfo> {
+
+  public MoveResultRewriter(AppView<?> appView) {
+    super(appView);
+  }
+
+  @Override
+  protected String getTimingId() {
+    return "MoveResultRewriter";
+  }
+
+  @Override
+  protected boolean shouldRewriteCode(IRCode code) {
+    return options.isGeneratingDex() && code.metadata().mayHaveInvokeMethod();
+  }
+
+  // Replace result uses for methods where something is known about what is returned.
+  @Override
+  protected CodeRewriterResult rewriteCode(IRCode code) {
+    AssumeRemover assumeRemover = new AssumeRemover(appView, code);
+    boolean changed = false;
+    boolean mayHaveRemovedTrivialPhi = false;
+    Set<BasicBlock> blocksToBeRemoved = Sets.newIdentityHashSet();
+    ListIterator<BasicBlock> blockIterator = code.listIterator();
+    while (blockIterator.hasNext()) {
+      BasicBlock block = blockIterator.next();
+      if (blocksToBeRemoved.contains(block)) {
+        continue;
+      }
+
+      InstructionListIterator iterator = block.listIterator(code);
+      while (iterator.hasNext()) {
+        InvokeMethod invoke = iterator.next().asInvokeMethod();
+        if (invoke == null || !invoke.hasOutValue() || invoke.outValue().hasLocalInfo()) {
+          continue;
+        }
+
+        // Check if the invoked method is known to return one of its arguments.
+        DexClassAndMethod target = invoke.lookupSingleTarget(appView, code.context());
+        if (target == null) {
+          continue;
+        }
+
+        MethodOptimizationInfo optimizationInfo = target.getDefinition().getOptimizationInfo();
+        if (!optimizationInfo.returnsArgument()) {
+          continue;
+        }
+
+        int argumentIndex = optimizationInfo.getReturnedArgument();
+        // Replace the out value of the invoke with the argument and ignore the out value.
+        if (argumentIndex < 0 || !checkArgumentType(invoke, argumentIndex)) {
+          continue;
+        }
+
+        Value argument = invoke.arguments().get(argumentIndex);
+        Value outValue = invoke.outValue();
+        assert outValue.verifyCompatible(argument.outType());
+
+        // Make sure that we are only narrowing information here. Note, in cases where we cannot
+        // find the definition of types, computing lessThanOrEqual will return false unless it is
+        // object.
+        if (!argument.getType().lessThanOrEqual(outValue.getType(), appView)) {
+          continue;
+        }
+
+        Set<Value> affectedValues =
+            argument.getType().equals(outValue.getType())
+                ? Collections.emptySet()
+                : outValue.affectedValues();
+
+        assumeRemover.markAssumeDynamicTypeUsersForRemoval(outValue);
+        mayHaveRemovedTrivialPhi |= outValue.numberOfPhiUsers() > 0;
+        outValue.replaceUsers(argument);
+        invoke.setOutValue(null);
+        changed = true;
+
+        if (!affectedValues.isEmpty()) {
+          new TypeAnalysis(appView).narrowing(affectedValues);
+        }
+      }
+    }
+    assumeRemover.removeMarkedInstructions(blocksToBeRemoved).finish();
+    Set<Value> affectedValues = Sets.newIdentityHashSet();
+    if (!blocksToBeRemoved.isEmpty()) {
+      code.removeBlocks(blocksToBeRemoved);
+      code.removeAllDeadAndTrivialPhis(affectedValues);
+      assert code.getUnreachableBlocks().isEmpty();
+    } else if (mayHaveRemovedTrivialPhi || assumeRemover.mayHaveIntroducedTrivialPhi()) {
+      code.removeAllDeadAndTrivialPhis(affectedValues);
+    }
+    if (!affectedValues.isEmpty()) {
+      new TypeAnalysis(appView).narrowing(affectedValues);
+    }
+    assert code.isConsistentSSA(appView);
+    return CodeRewriterResult.hasChanged(changed);
+  }
+
+  private boolean checkArgumentType(InvokeMethod invoke, int argumentIndex) {
+    // TODO(sgjesse): Insert cast if required.
+    TypeElement returnType =
+        TypeElement.fromDexType(invoke.getInvokedMethod().proto.returnType, maybeNull(), appView);
+    TypeElement argumentType =
+        TypeElement.fromDexType(getArgumentType(invoke, argumentIndex), maybeNull(), appView);
+    return appView.enableWholeProgramOptimizations()
+        ? argumentType.lessThanOrEqual(returnType, appView)
+        : argumentType.equals(returnType);
+  }
+
+  private DexType getArgumentType(InvokeMethod invoke, int argumentIndex) {
+    if (invoke.isInvokeStatic()) {
+      return invoke.getInvokedMethod().proto.parameters.values[argumentIndex];
+    }
+    if (argumentIndex == 0) {
+      return invoke.getInvokedMethod().holder;
+    }
+    return invoke.getInvokedMethod().proto.parameters.values[argumentIndex - 1];
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/NaturalIntLoopRemover.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/NaturalIntLoopRemover.java
index be7840f..2238b82 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/NaturalIntLoopRemover.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/NaturalIntLoopRemover.java
@@ -15,6 +15,7 @@
 import com.android.tools.r8.ir.code.Phi;
 import com.android.tools.r8.ir.code.Sub;
 import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.ir.conversion.passes.result.CodeRewriterResult;
 import com.android.tools.r8.utils.WorkList;
 import com.google.common.collect.Sets;
 import java.util.Set;
@@ -40,7 +41,7 @@
   }
 
   @Override
-  protected void rewriteCode(IRCode code) {
+  protected CodeRewriterResult rewriteCode(IRCode code) {
     boolean loopRemoved = false;
     for (BasicBlock comparisonBlockCandidate : code.blocks) {
       if (isComparisonBlock(comparisonBlockCandidate)) {
@@ -52,6 +53,7 @@
       code.removeRedundantBlocks();
       assert code.isConsistentSSA(appView);
     }
+    return CodeRewriterResult.NONE;
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/ParentConstructorHoistingCodeRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/ParentConstructorHoistingCodeRewriter.java
index 4f96ce7..6a7c040 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/ParentConstructorHoistingCodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/ParentConstructorHoistingCodeRewriter.java
@@ -14,6 +14,7 @@
 import com.android.tools.r8.ir.code.InstructionListIterator;
 import com.android.tools.r8.ir.code.InvokeDirect;
 import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.ir.conversion.passes.result.CodeRewriterResult;
 import com.android.tools.r8.shaking.KeepMethodInfo;
 import com.android.tools.r8.utils.CollectionUtils;
 import com.android.tools.r8.utils.IterableUtils;
@@ -48,10 +49,11 @@
   }
 
   @Override
-  protected void rewriteCode(IRCode code) {
+  protected CodeRewriterResult rewriteCode(IRCode code) {
     for (InvokeDirect invoke : getOrComputeSideEffectFreeConstructorCalls(code)) {
       hoistSideEffectFreeConstructorCall(code, invoke);
     }
+    return CodeRewriterResult.NONE;
   }
 
   private void hoistSideEffectFreeConstructorCall(IRCode code, InvokeDirect invoke) {
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranch.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranch.java
index fbb3afd..e4bac3e 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranch.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/SplitBranch.java
@@ -14,6 +14,7 @@
 import com.android.tools.r8.ir.code.If;
 import com.android.tools.r8.ir.code.Phi;
 import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.ir.conversion.passes.result.CodeRewriterResult;
 import com.android.tools.r8.utils.ListUtils;
 import com.android.tools.r8.utils.WorkList;
 import com.google.common.collect.Sets;
@@ -43,24 +44,22 @@
   }
 
   /**
-   * Simplify Boolean branches for example: <code>
-   * boolean b = i == j; if (b) { ... } else { ... }
+   * Simplify Boolean branches for example: <code> boolean b = i == j; if (b) { ... } else { ... }
    * </code> ends up first creating a branch for the boolean b, then a second branch on b. D8/R8
-   * rewrites to: <code>
-   * if (i == j) { ... } else { ... }
+   * rewrites to: <code> if (i == j) { ... } else { ... }
    * </code> More complex control flow are also supported to some extent, including cases where the
    * input of the second branch comes from a set of dependent phis, and a subset of the inputs are
    * known boolean values.
    */
   @Override
-  protected void rewriteCode(IRCode code) {
+  protected CodeRewriterResult rewriteCode(IRCode code) {
     List<BasicBlock> candidates = computeCandidates(code);
     if (candidates.isEmpty()) {
-      return;
+      return CodeRewriterResult.NONE;
     }
     Map<Goto, BasicBlock> newTargets = findGotosToRetarget(candidates);
     if (newTargets.isEmpty()) {
-      return;
+      return CodeRewriterResult.NONE;
     }
     retargetGotos(newTargets);
     Set<Value> affectedValues = Sets.newIdentityHashSet();
@@ -74,6 +73,7 @@
     }
     code.removeRedundantBlocks();
     assert code.isConsistentSSA(appView);
+    return CodeRewriterResult.NONE;
   }
 
   private void retargetGotos(Map<Goto, BasicBlock> newTargets) {
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/TrivialCheckCastAndInstanceOfRemover.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/TrivialCheckCastAndInstanceOfRemover.java
index cbce46f..77104d5 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/TrivialCheckCastAndInstanceOfRemover.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/TrivialCheckCastAndInstanceOfRemover.java
@@ -28,6 +28,7 @@
 import com.android.tools.r8.ir.code.InvokeStatic;
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.conversion.MethodProcessor;
+import com.android.tools.r8.ir.conversion.passes.result.CodeRewriterResult;
 import com.android.tools.r8.ir.optimize.CodeRewriter;
 import com.android.tools.r8.ir.optimize.UtilityMethodsForCodeOptimizations;
 import com.android.tools.r8.ir.optimize.UtilityMethodsForCodeOptimizations.UtilityMethodForCodeOptimizations;
@@ -54,7 +55,7 @@
   }
 
   @Override
-  protected void rewriteCode(
+  protected CodeRewriterResult rewriteCode(
       IRCode code,
       MethodProcessor methodProcessor,
       MethodProcessingContext methodProcessingContext) {
@@ -124,6 +125,7 @@
     }
     code.removeRedundantBlocks();
     assert code.isConsistentSSA(appView);
+    return CodeRewriterResult.NONE;
   }
 
   enum RemoveCheckCastInstructionIfTrivialResult {
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/TrivialGotosCollapser.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/TrivialGotosCollapser.java
index 4f8ee2e..a08d82a 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/TrivialGotosCollapser.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/TrivialGotosCollapser.java
@@ -10,6 +10,7 @@
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.If;
 import com.android.tools.r8.ir.code.Switch;
+import com.android.tools.r8.ir.conversion.passes.result.CodeRewriterResult;
 import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
@@ -33,7 +34,7 @@
   }
 
   @Override
-  protected void rewriteCode(IRCode code) {
+  protected CodeRewriterResult rewriteCode(IRCode code) {
     assert code.isConsistentGraph(appView);
     List<BasicBlock> blocksToRemove = new ArrayList<>();
     // Rewrite all non-fallthrough targets to the end of trivial goto chains and remove
@@ -73,6 +74,7 @@
     }
     assert removedTrivialGotos(code);
     assert code.isConsistentGraph(appView);
+    return CodeRewriterResult.NONE;
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/result/CodeRewriterResult.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/result/CodeRewriterResult.java
new file mode 100644
index 0000000..68d5c0d
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/result/CodeRewriterResult.java
@@ -0,0 +1,40 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.conversion.passes.result;
+
+import com.android.tools.r8.errors.Unreachable;
+
+public interface CodeRewriterResult {
+
+  CodeRewriterResult NO_CHANGE = new DefaultCodeRewriterResult(false);
+  CodeRewriterResult HAS_CHANGED = new DefaultCodeRewriterResult(true);
+  CodeRewriterResult NONE =
+      new CodeRewriterResult() {
+        @Override
+        public boolean hasChanged() {
+          throw new Unreachable();
+        }
+      };
+
+  static CodeRewriterResult hasChanged(boolean hasChanged) {
+    return hasChanged ? HAS_CHANGED : NO_CHANGE;
+  }
+
+  class DefaultCodeRewriterResult implements CodeRewriterResult {
+
+    private final boolean hasChanged;
+
+    public DefaultCodeRewriterResult(boolean hasChanged) {
+      this.hasChanged = hasChanged;
+    }
+
+    @Override
+    public boolean hasChanged() {
+      return hasChanged;
+    }
+  }
+
+  boolean hasChanged();
+}
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index 4984711..5b78fcf 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -5,13 +5,11 @@
 package com.android.tools.r8.ir.optimize;
 
 import static com.android.tools.r8.ir.analysis.type.Nullability.definitelyNotNull;
-import static com.android.tools.r8.ir.analysis.type.Nullability.maybeNull;
 
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DebugLocalInfo;
 import com.android.tools.r8.graph.DexClass;
-import com.android.tools.r8.graph.DexClassAndMethod;
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
@@ -35,14 +33,12 @@
 import com.android.tools.r8.ir.code.InstructionIterator;
 import com.android.tools.r8.ir.code.InstructionListIterator;
 import com.android.tools.r8.ir.code.InvokeInterface;
-import com.android.tools.r8.ir.code.InvokeMethod;
 import com.android.tools.r8.ir.code.InvokeVirtual;
 import com.android.tools.r8.ir.code.Move;
 import com.android.tools.r8.ir.code.Position;
 import com.android.tools.r8.ir.code.Position.SyntheticPosition;
 import com.android.tools.r8.ir.code.StaticGet;
 import com.android.tools.r8.ir.code.Value;
-import com.android.tools.r8.ir.optimize.info.MethodOptimizationInfo;
 import com.android.tools.r8.ir.regalloc.LinearScanRegisterAllocator;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.LazyBox;
@@ -59,7 +55,6 @@
 import it.unimi.dsi.fastutil.longs.Long2ReferenceMap;
 import it.unimi.dsi.fastutil.longs.Long2ReferenceOpenHashMap;
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.List;
 import java.util.ListIterator;
 import java.util.Set;
@@ -68,11 +63,9 @@
 
   private final AppView<?> appView;
   private final DexItemFactory dexItemFactory;
-  private final InternalOptions options;
 
   public CodeRewriter(AppView<?> appView) {
     this.appView = appView;
-    this.options = appView.options();
     this.dexItemFactory = appView.dexItemFactory();
   }
 
@@ -134,112 +127,6 @@
     assert Streams.stream(code.instructions()).noneMatch(Instruction::isAssume);
   }
 
-  private boolean checkArgumentType(InvokeMethod invoke, int argumentIndex) {
-    // TODO(sgjesse): Insert cast if required.
-    TypeElement returnType =
-        TypeElement.fromDexType(invoke.getInvokedMethod().proto.returnType, maybeNull(), appView);
-    TypeElement argumentType =
-        TypeElement.fromDexType(getArgumentType(invoke, argumentIndex), maybeNull(), appView);
-    return appView.enableWholeProgramOptimizations()
-        ? argumentType.lessThanOrEqual(returnType, appView)
-        : argumentType.equals(returnType);
-  }
-
-  private DexType getArgumentType(InvokeMethod invoke, int argumentIndex) {
-    if (invoke.isInvokeStatic()) {
-      return invoke.getInvokedMethod().proto.parameters.values[argumentIndex];
-    }
-    if (argumentIndex == 0) {
-      return invoke.getInvokedMethod().holder;
-    }
-    return invoke.getInvokedMethod().proto.parameters.values[argumentIndex - 1];
-  }
-
-  // Replace result uses for methods where something is known about what is returned.
-  public boolean rewriteMoveResult(IRCode code) {
-    if (options.isGeneratingClassFiles() || !code.metadata().mayHaveInvokeMethod()) {
-      return false;
-    }
-
-    AssumeRemover assumeRemover = new AssumeRemover(appView, code);
-    boolean changed = false;
-    boolean mayHaveRemovedTrivialPhi = false;
-    Set<BasicBlock> blocksToBeRemoved = Sets.newIdentityHashSet();
-    ListIterator<BasicBlock> blockIterator = code.listIterator();
-    while (blockIterator.hasNext()) {
-      BasicBlock block = blockIterator.next();
-      if (blocksToBeRemoved.contains(block)) {
-        continue;
-      }
-
-      InstructionListIterator iterator = block.listIterator(code);
-      while (iterator.hasNext()) {
-        InvokeMethod invoke = iterator.next().asInvokeMethod();
-        if (invoke == null || !invoke.hasOutValue() || invoke.outValue().hasLocalInfo()) {
-          continue;
-        }
-
-        // Check if the invoked method is known to return one of its arguments.
-        DexClassAndMethod target = invoke.lookupSingleTarget(appView, code.context());
-        if (target == null) {
-          continue;
-        }
-
-        MethodOptimizationInfo optimizationInfo = target.getDefinition().getOptimizationInfo();
-        if (!optimizationInfo.returnsArgument()) {
-          continue;
-        }
-
-        int argumentIndex = optimizationInfo.getReturnedArgument();
-        // Replace the out value of the invoke with the argument and ignore the out value.
-        if (argumentIndex < 0 || !checkArgumentType(invoke, argumentIndex)) {
-          continue;
-        }
-
-        Value argument = invoke.arguments().get(argumentIndex);
-        Value outValue = invoke.outValue();
-        assert outValue.verifyCompatible(argument.outType());
-
-        // Make sure that we are only narrowing information here. Note, in cases where we cannot
-        // find the definition of types, computing lessThanOrEqual will return false unless it is
-        // object.
-        if (!argument.getType().lessThanOrEqual(outValue.getType(), appView)) {
-          continue;
-        }
-
-        Set<Value> affectedValues =
-            argument.getType().equals(outValue.getType())
-                ? Collections.emptySet()
-                : outValue.affectedValues();
-
-        assumeRemover.markAssumeDynamicTypeUsersForRemoval(outValue);
-        mayHaveRemovedTrivialPhi |= outValue.numberOfPhiUsers() > 0;
-        outValue.replaceUsers(argument);
-        invoke.setOutValue(null);
-        changed = true;
-
-        if (!affectedValues.isEmpty()) {
-          new TypeAnalysis(appView).narrowing(affectedValues);
-        }
-      }
-    }
-    assumeRemover.removeMarkedInstructions(blocksToBeRemoved).finish();
-    Set<Value> affectedValues = Sets.newIdentityHashSet();
-    if (!blocksToBeRemoved.isEmpty()) {
-      code.removeBlocks(blocksToBeRemoved);
-      code.removeAllDeadAndTrivialPhis(affectedValues);
-      assert code.getUnreachableBlocks().isEmpty();
-    } else if (mayHaveRemovedTrivialPhi || assumeRemover.mayHaveIntroducedTrivialPhi()) {
-      code.removeAllDeadAndTrivialPhis(affectedValues);
-    }
-    if (!affectedValues.isEmpty()) {
-      new TypeAnalysis(appView).narrowing(affectedValues);
-    }
-    code.removeRedundantBlocks();
-    assert code.isConsistentSSA(appView);
-    return changed;
-  }
-
   public static void removeOrReplaceByDebugLocalWrite(
       Instruction currentInstruction, InstructionListIterator it, Value inValue, Value outValue) {
     if (outValue.hasLocalInfo() && outValue.getLocalInfo() != inValue.getLocalInfo()) {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/DeadCodeRemover.java b/src/main/java/com/android/tools/r8/ir/optimize/DeadCodeRemover.java
index a207486..e9b180c 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/DeadCodeRemover.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/DeadCodeRemover.java
@@ -21,6 +21,7 @@
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.code.ValueIsDeadAnalysis;
 import com.android.tools.r8.ir.conversion.passes.BranchSimplifier;
+import com.android.tools.r8.ir.conversion.passes.MoveResultRewriter;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.Box;
 import com.android.tools.r8.utils.IterableUtils;
@@ -35,21 +36,15 @@
 public class DeadCodeRemover {
 
   private final AppView<?> appView;
-  private final CodeRewriter codeRewriter;
 
-  public DeadCodeRemover(AppView<?> appView, CodeRewriter codeRewriter) {
+  public DeadCodeRemover(AppView<?> appView) {
     this.appView = appView;
-    this.codeRewriter = codeRewriter;
-  }
-
-  public CodeRewriter getCodeRewriter() {
-    return codeRewriter;
   }
 
   public void run(IRCode code, Timing timing) {
     timing.begin("Remove dead code");
 
-    codeRewriter.rewriteMoveResult(code);
+    new MoveResultRewriter(appView).run(code, timing);
 
     BranchSimplifier branchSimplifier = new BranchSimplifier(appView);
 
@@ -75,7 +70,7 @@
   }
 
   public boolean verifyNoDeadCode(IRCode code) {
-    assert !codeRewriter.rewriteMoveResult(code);
+    assert !new MoveResultRewriter(appView).run(code, Timing.empty()).hasChanged();
     assert !removeUnneededCatchHandlers(code);
     ValueIsDeadAnalysis valueIsDeadAnalysis = new ValueIsDeadAnalysis(appView, code);
     for (BasicBlock block : code.blocks) {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumValueOptimizer.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumValueOptimizer.java
index c94b03e..46ed560 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumValueOptimizer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumValueOptimizer.java
@@ -34,6 +34,7 @@
 import com.android.tools.r8.ir.code.StaticGet;
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.conversion.passes.CodeRewriterPass;
+import com.android.tools.r8.ir.conversion.passes.result.CodeRewriterResult;
 import com.android.tools.r8.ir.optimize.info.FieldOptimizationInfo;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.ArrayUtils;
@@ -60,8 +61,9 @@
   }
 
   @Override
-  protected void rewriteCode(IRCode code) {
+  protected CodeRewriterResult rewriteCode(IRCode code) {
     rewriteConstantEnumMethodCalls(code);
+    return CodeRewriterResult.NONE;
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/outliner/OutlinerImpl.java b/src/main/java/com/android/tools/r8/ir/optimize/outliner/OutlinerImpl.java
index 9cd00c7..1cedcc3 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/outliner/OutlinerImpl.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/outliner/OutlinerImpl.java
@@ -62,6 +62,7 @@
 import com.android.tools.r8.ir.conversion.MethodConversionOptions.MutableMethodConversionOptions;
 import com.android.tools.r8.ir.conversion.MethodProcessorEventConsumer;
 import com.android.tools.r8.ir.conversion.SourceCode;
+import com.android.tools.r8.ir.conversion.passes.MoveResultRewriter;
 import com.android.tools.r8.ir.optimize.CodeRewriter;
 import com.android.tools.r8.ir.optimize.Inliner.ConstraintWithTarget;
 import com.android.tools.r8.ir.optimize.InliningConstraints;
@@ -1368,7 +1369,7 @@
             applyOutliningCandidate(code);
             converter.printMethod(code, "IR after outlining (SSA)", null);
             converter.memberValuePropagation.run(code);
-            converter.codeRewriter.rewriteMoveResult(code);
+            new MoveResultRewriter(appView).run(code, Timing.empty());
             converter.removeDeadCodeAndFinalizeIR(
                 code, OptimizationFeedbackIgnore.getInstance(), Timing.empty());
           },
@@ -1399,7 +1400,7 @@
           // optimizations needed for outlining: rewriteMoveResult() to remove out-values on
           // StringBuilder/StringBuffer method invocations, and removeDeadCode() to remove
           // unused out-values.
-          converter.codeRewriter.rewriteMoveResult(code);
+          new MoveResultRewriter(appView).run(code, Timing.empty());
           converter.deadCodeRemover.run(code, Timing.empty());
           CodeRewriter.removeAssumeInstructions(appView, code);
           consumer.accept(code);
diff --git a/src/main/java/com/android/tools/r8/shaking/EnqueuerDeferredTracingRewriter.java b/src/main/java/com/android/tools/r8/shaking/EnqueuerDeferredTracingRewriter.java
index 79165d1..48312bd 100644
--- a/src/main/java/com/android/tools/r8/shaking/EnqueuerDeferredTracingRewriter.java
+++ b/src/main/java/com/android/tools/r8/shaking/EnqueuerDeferredTracingRewriter.java
@@ -47,7 +47,7 @@
   EnqueuerDeferredTracingRewriter(AppView<? extends AppInfoWithClassHierarchy> appView) {
     this.appView = appView;
     this.codeRewriter = new CodeRewriter(appView);
-    this.deadCodeRemover = new DeadCodeRemover(appView, codeRewriter);
+    this.deadCodeRemover = new DeadCodeRemover(appView);
   }
 
   public CodeRewriter getCodeRewriter() {