diff --git a/src/main/java/com/android/tools/r8/ir/code/IRCode.java b/src/main/java/com/android/tools/r8/ir/code/IRCode.java
index 1a34adf..9594d10 100644
--- a/src/main/java/com/android/tools/r8/ir/code/IRCode.java
+++ b/src/main/java/com/android/tools/r8/ir/code/IRCode.java
@@ -1114,25 +1114,27 @@
     return true;
   }
 
-  public void removeAllTrivialPhis() {
-    removeAllTrivialPhis(null, null);
+  public boolean removeAllTrivialPhis() {
+    return removeAllTrivialPhis(null, null);
   }
 
-  public void removeAllTrivialPhis(IRBuilder builder) {
-    removeAllTrivialPhis(builder, null);
+  public boolean removeAllTrivialPhis(IRBuilder builder) {
+    return removeAllTrivialPhis(builder, null);
   }
 
-  public void removeAllTrivialPhis(Set<Value> affectedValues) {
-    removeAllTrivialPhis(null, affectedValues);
+  public boolean removeAllTrivialPhis(Set<Value> affectedValues) {
+    return removeAllTrivialPhis(null, affectedValues);
   }
 
-  public void removeAllTrivialPhis(IRBuilder builder, Set<Value> affectedValues) {
+  public boolean removeAllTrivialPhis(IRBuilder builder, Set<Value> affectedValues) {
+    boolean anyTrivialPhisRemoved = false;
     for (BasicBlock block : blocks) {
       List<Phi> phis = new ArrayList<>(block.getPhis());
       for (Phi phi : phis) {
-        phi.removeTrivialPhi(builder, affectedValues);
+        anyTrivialPhisRemoved |= phi.removeTrivialPhi(builder, affectedValues);
       }
     }
+    return anyTrivialPhisRemoved;
   }
 
   public int reserveMarkingColor() {
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index ace131b..3547e12 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -193,7 +193,7 @@
     this.printer = printer;
     this.mainDexClasses = mainDexClasses.getClasses();
     this.codeRewriter = new CodeRewriter(appView, this);
-    this.constantCanonicalizer = new ConstantCanonicalizer();
+    this.constantCanonicalizer = new ConstantCanonicalizer(codeRewriter);
     this.classInitializerDefaultsOptimization =
         options.debug ? null : new ClassInitializerDefaultsOptimization(appView, this);
     this.stringConcatRewriter = new StringConcatRewriter(appView);
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/ConstantCanonicalizer.java b/src/main/java/com/android/tools/r8/ir/optimize/ConstantCanonicalizer.java
index 48f7b09..eff0d07 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/ConstantCanonicalizer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/ConstantCanonicalizer.java
@@ -30,8 +30,10 @@
 import it.unimi.dsi.fastutil.objects.Object2IntArrayMap;
 import it.unimi.dsi.fastutil.objects.Object2IntMap;
 import it.unimi.dsi.fastutil.objects.Object2ObjectLinkedOpenCustomHashMap;
+import it.unimi.dsi.fastutil.objects.Object2ObjectMap;
 import it.unimi.dsi.fastutil.objects.Object2ObjectSortedMap.FastSortedEntrySet;
 import java.util.ArrayList;
+import java.util.Iterator;
 import java.util.List;
 
 /**
@@ -41,6 +43,8 @@
   // Threshold to limit the number of constant canonicalization.
   private static final int MAX_CANONICALIZED_CONSTANT = 22;
 
+  private final CodeRewriter codeRewriter;
+
   private int numberOfConstNumberCanonicalization = 0;
   private int numberOfConstStringCanonicalization = 0;
   private int numberOfDexItemBasedConstStringCanonicalization = 0;
@@ -48,7 +52,8 @@
   private int numberOfEffectivelyFinalFieldCanonicalization = 0;
   private final Object2IntMap<Long> histogramOfCanonicalizationCandidatesPerMethod;
 
-  public ConstantCanonicalizer() {
+  public ConstantCanonicalizer(CodeRewriter codeRewriter) {
+    this.codeRewriter = codeRewriter;
     if (Log.ENABLED) {
       histogramOfCanonicalizationCandidatesPerMethod = new Object2IntArrayMap<>();
     } else {
@@ -180,59 +185,69 @@
         histogramOfCanonicalizationCandidatesPerMethod.put(numOfCandidates, count + 1);
       }
     }
-    entries.stream()
-        .filter(a -> a.getValue().size() > 1)
-        .sorted((a, b) -> Integer.compare(b.getValue().size(), a.getValue().size()))
-        .limit(MAX_CANONICALIZED_CONSTANT)
-        .forEach(
-            (entry) -> {
-              Instruction canonicalizedConstant = entry.getKey();
-              assert canonicalizedConstant.instructionTypeCanBeCanonicalized();
-              Instruction newConst;
-              switch (canonicalizedConstant.opcode()) {
-                case CONST_CLASS:
-                  if (Log.ENABLED) {
-                    numberOfConstClassCanonicalization++;
-                  }
-                  newConst = ConstClass.copyOf(code, canonicalizedConstant.asConstClass());
-                  break;
-                case CONST_NUMBER:
-                  if (Log.ENABLED) {
-                    numberOfConstNumberCanonicalization++;
-                  }
-                  newConst = ConstNumber.copyOf(code, canonicalizedConstant.asConstNumber());
-                  break;
-                case CONST_STRING:
-                  if (Log.ENABLED) {
-                    numberOfConstStringCanonicalization++;
-                  }
-                  newConst = ConstString.copyOf(code, canonicalizedConstant.asConstString());
-                  break;
-                case DEX_ITEM_BASED_CONST_STRING:
-                  if (Log.ENABLED) {
-                    numberOfDexItemBasedConstStringCanonicalization++;
-                  }
-                  newConst =
-                      DexItemBasedConstString.copyOf(
-                          code, canonicalizedConstant.asDexItemBasedConstString());
-                  break;
-                case STATIC_GET:
-                  if (Log.ENABLED) {
-                    numberOfEffectivelyFinalFieldCanonicalization++;
-                  }
-                  newConst = StaticGet.copyOf(code, canonicalizedConstant.asStaticGet());
-                  break;
-                default:
-                  throw new Unreachable();
-              }
-              newConst.setPosition(firstNonNonePosition);
-              insertCanonicalizedConstant(code, newConst);
-              for (Value outValue : entry.getValue()) {
-                outValue.replaceUsers(newConst.outValue());
-              }
-            });
 
-    code.removeAllTrivialPhis();
+    Iterator<Object2ObjectMap.Entry<Instruction, List<Value>>> iterator =
+        entries.stream()
+            .filter(a -> a.getValue().size() > 1)
+            .sorted((a, b) -> Integer.compare(b.getValue().size(), a.getValue().size()))
+            .limit(MAX_CANONICALIZED_CONSTANT)
+            .iterator();
+
+    if (!iterator.hasNext()) {
+      return;
+    }
+    do {
+      Object2ObjectMap.Entry<Instruction, List<Value>> entry = iterator.next();
+      Instruction canonicalizedConstant = entry.getKey();
+      assert canonicalizedConstant.instructionTypeCanBeCanonicalized();
+      Instruction newConst;
+      switch (canonicalizedConstant.opcode()) {
+        case CONST_CLASS:
+          if (Log.ENABLED) {
+            numberOfConstClassCanonicalization++;
+          }
+          newConst = ConstClass.copyOf(code, canonicalizedConstant.asConstClass());
+          break;
+        case CONST_NUMBER:
+          if (Log.ENABLED) {
+            numberOfConstNumberCanonicalization++;
+          }
+          newConst = ConstNumber.copyOf(code, canonicalizedConstant.asConstNumber());
+          break;
+        case CONST_STRING:
+          if (Log.ENABLED) {
+            numberOfConstStringCanonicalization++;
+          }
+          newConst = ConstString.copyOf(code, canonicalizedConstant.asConstString());
+          break;
+        case DEX_ITEM_BASED_CONST_STRING:
+          if (Log.ENABLED) {
+            numberOfDexItemBasedConstStringCanonicalization++;
+          }
+          newConst =
+              DexItemBasedConstString.copyOf(
+                  code, canonicalizedConstant.asDexItemBasedConstString());
+          break;
+        case STATIC_GET:
+          if (Log.ENABLED) {
+            numberOfEffectivelyFinalFieldCanonicalization++;
+          }
+          newConst = StaticGet.copyOf(code, canonicalizedConstant.asStaticGet());
+          break;
+        default:
+          throw new Unreachable();
+      }
+      newConst.setPosition(firstNonNonePosition);
+      insertCanonicalizedConstant(code, newConst);
+      for (Value outValue : entry.getValue()) {
+        outValue.replaceUsers(newConst.outValue());
+      }
+    } while (iterator.hasNext());
+
+    if (code.removeAllTrivialPhis()) {
+      codeRewriter.simplifyControlFlow(code);
+    }
+
     assert code.isConsistentSSA();
   }
 
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/ifs/SemiTrivialPhiBranchTest.java b/src/test/java/com/android/tools/r8/ir/optimize/ifs/SemiTrivialPhiBranchTest.java
index 8b84dab..a8bccea 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/ifs/SemiTrivialPhiBranchTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/ifs/SemiTrivialPhiBranchTest.java
@@ -5,6 +5,7 @@
 package com.android.tools.r8.ir.optimize.ifs;
 
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.CoreMatchers.not;
 import static org.hamcrest.MatcherAssert.assertThat;
 
 import com.android.tools.r8.NeverInline;
@@ -50,8 +51,7 @@
     assertThat(testClassSubject, isPresent());
     assertThat(testClassSubject.mainMethod(), isPresent());
     assertThat(testClassSubject.uniqueMethodWithName("live"), isPresent());
-    // TODO(christofferqa): Should not be present.
-    assertThat(testClassSubject.uniqueMethodWithName("dead"), isPresent());
+    assertThat(testClassSubject.uniqueMethodWithName("dead"), not(isPresent()));
   }
 
   static class TestClass {
