Avoid inliner reprocessing of singler caller inlined methods

Bug: 202419103
Change-Id: I5bd5e2fe8ba84facb02df00b18ba7817aee50129
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/CallSiteInformation.java b/src/main/java/com/android/tools/r8/ir/conversion/CallSiteInformation.java
index 993315b..162e0bb 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/CallSiteInformation.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/CallSiteInformation.java
@@ -18,13 +18,13 @@
    *
    * <p>For pinned methods (methods kept through Proguard keep rules) this will always answer <code>
    * false</code>.
-   *
-   * @param method
    */
   public abstract boolean hasSingleCallSite(ProgramMethod method);
 
   public abstract boolean hasDoubleCallSite(ProgramMethod method);
 
+  public abstract void unsetCallSiteInformation(ProgramMethod method);
+
   public static CallSiteInformation empty() {
     return EmptyCallSiteInformation.EMPTY_INFO;
   }
@@ -42,6 +42,11 @@
     public boolean hasDoubleCallSite(ProgramMethod method) {
       return false;
     }
+
+    @Override
+    public void unsetCallSiteInformation(ProgramMethod method) {
+      // Intentionally empty.
+    }
   }
 
   static class CallGraphBasedCallSiteInformation extends CallSiteInformation {
@@ -95,5 +100,11 @@
     public boolean hasDoubleCallSite(ProgramMethod method) {
       return doubleCallSite.contains(method.getReference());
     }
+
+    @Override
+    public void unsetCallSiteInformation(ProgramMethod method) {
+      singleCallSite.remove(method.getReference());
+      doubleCallSite.remove(method.getReference());
+    }
   }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 6359283..1d1a698 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -677,8 +677,7 @@
     {
       timing.begin("Build primary method processor");
       PrimaryMethodProcessor primaryMethodProcessor =
-          PrimaryMethodProcessor.create(
-              appView.withLiveness(), postMethodProcessorBuilder, executorService, timing);
+          PrimaryMethodProcessor.create(appView.withLiveness(), executorService, timing);
       timing.end();
       timing.begin("IR conversion phase 1");
       assert appView.graphLens() == graphLensForPrimaryOptimizationPass;
@@ -1977,4 +1976,16 @@
     }
     return previous;
   }
+
+  /**
+   * Called when a method is pruned as a result of optimizations during IR processing in R8, to
+   * allow optimizations that track sets of methods to fixup their state.
+   */
+  public void pruneMethod(ProgramMethod method) {
+    assert appView.enableWholeProgramOptimizations();
+    assert method.getHolder().lookupMethod(method.getReference()) == null;
+    if (inliner != null) {
+      inliner.pruneMethod(method);
+    }
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/MethodOptimizationFeedback.java b/src/main/java/com/android/tools/r8/ir/conversion/MethodOptimizationFeedback.java
index 5219a50..80b9cbd 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/MethodOptimizationFeedback.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/MethodOptimizationFeedback.java
@@ -37,7 +37,7 @@
   void methodReturnsAbstractValue(
       DexEncodedMethod method, AppView<AppInfoWithLiveness> appView, AbstractValue abstractValue);
 
-  void unsetAbstractReturnValue(DexEncodedMethod method);
+  void unsetAbstractReturnValue(ProgramMethod method);
 
   void methodReturnsObjectWithUpperBoundType(
       DexEncodedMethod method, AppView<?> appView, TypeElement type);
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/PrimaryMethodProcessor.java b/src/main/java/com/android/tools/r8/ir/conversion/PrimaryMethodProcessor.java
index 9bd75fe..bea221a 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/PrimaryMethodProcessor.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/PrimaryMethodProcessor.java
@@ -38,29 +38,25 @@
 
   private final AppView<?> appView;
   private final CallSiteInformation callSiteInformation;
-  private final PostMethodProcessor.Builder postMethodProcessorBuilder;
   private final Deque<SortedProgramMethodSet> waves;
 
   private ProcessorContext processorContext;
 
   private PrimaryMethodProcessor(
       AppView<AppInfoWithLiveness> appView,
-      PostMethodProcessor.Builder postMethodProcessorBuilder,
       CallGraph callGraph) {
     this.appView = appView;
     this.callSiteInformation = callGraph.createCallSiteInformation(appView);
-    this.postMethodProcessorBuilder = postMethodProcessorBuilder;
-    this.waves = createWaves(appView, callGraph, callSiteInformation);
+    this.waves = createWaves(appView, callGraph);
   }
 
   static PrimaryMethodProcessor create(
       AppView<AppInfoWithLiveness> appView,
-      PostMethodProcessor.Builder postMethodProcessorBuilder,
       ExecutorService executorService,
       Timing timing)
       throws ExecutionException {
     CallGraph callGraph = CallGraph.builder(appView).build(executorService, timing);
-    return new PrimaryMethodProcessor(appView, postMethodProcessorBuilder, callGraph);
+    return new PrimaryMethodProcessor(appView, callGraph);
   }
 
   @Override
@@ -84,29 +80,18 @@
     return callSiteInformation;
   }
 
-  private Deque<SortedProgramMethodSet> createWaves(
-      AppView<?> appView, CallGraph callGraph, CallSiteInformation callSiteInformation) {
+  private Deque<SortedProgramMethodSet> createWaves(AppView<?> appView, CallGraph callGraph) {
     InternalOptions options = appView.options();
     Deque<SortedProgramMethodSet> waves = new ArrayDeque<>();
     Set<Node> nodes = callGraph.nodes;
-    ProgramMethodSet reprocessing = ProgramMethodSet.create();
     int waveCount = 1;
     while (!nodes.isEmpty()) {
       SortedProgramMethodSet wave = callGraph.extractLeaves();
-      wave.forEach(
-          method -> {
-            if (callSiteInformation.hasSingleCallSite(method) && options.enableInlining) {
-              callGraph.cycleEliminationResult.forEachRemovedCaller(method, reprocessing::add);
-            }
-          });
       waves.addLast(wave);
       if (Log.ENABLED && Log.isLoggingEnabledFor(PrimaryMethodProcessor.class)) {
         Log.info(getClass(), "Wave #%d: %d", waveCount++, wave.size());
       }
     }
-    if (!reprocessing.isEmpty()) {
-      postMethodProcessorBuilder.put(reprocessing);
-    }
     options.testing.waveModifier.accept(waves);
     return waves;
   }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java b/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java
index 4a91c4c..264eda0 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java
@@ -32,8 +32,10 @@
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.conversion.MethodProcessor;
 import com.android.tools.r8.ir.optimize.Inliner.InlineAction;
+import com.android.tools.r8.ir.optimize.Inliner.InlineResult;
 import com.android.tools.r8.ir.optimize.Inliner.InlineeWithReason;
 import com.android.tools.r8.ir.optimize.Inliner.Reason;
+import com.android.tools.r8.ir.optimize.Inliner.RetryAction;
 import com.android.tools.r8.ir.optimize.info.OptimizationFeedback;
 import com.android.tools.r8.ir.optimize.inliner.InliningReasonStrategy;
 import com.android.tools.r8.ir.optimize.inliner.WhyAreYouNotInliningReporter;
@@ -258,7 +260,7 @@
   }
 
   @Override
-  public InlineAction computeInlining(
+  public InlineResult computeInlining(
       InvokeMethod invoke,
       SingleResolutionResult resolutionResult,
       ProgramMethod singleTarget,
@@ -287,6 +289,15 @@
       return null;
     }
 
+    if (reason == Reason.SIMPLE
+        && !singleTarget.getDefinition().isProcessed()
+        && methodProcessor.isPrimaryMethodProcessor()) {
+      // The single target has this method as single caller, but the single target is not yet
+      // processed. Enqueue the context for processing in the secondary optimization pass to allow
+      // the single caller inlining to happen.
+      return new RetryAction();
+    }
+
     if (!singleTarget
         .getDefinition()
         .isInliningCandidate(method, reason, appView.appInfo(), whyAreYouNotInliningReporter)) {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/ForcedInliningOracle.java b/src/main/java/com/android/tools/r8/ir/optimize/ForcedInliningOracle.java
index 672d9a7..0e26fa3 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/ForcedInliningOracle.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/ForcedInliningOracle.java
@@ -14,6 +14,7 @@
 import com.android.tools.r8.ir.code.InvokeDirect;
 import com.android.tools.r8.ir.code.InvokeMethod;
 import com.android.tools.r8.ir.optimize.Inliner.InlineAction;
+import com.android.tools.r8.ir.optimize.Inliner.InlineResult;
 import com.android.tools.r8.ir.optimize.Inliner.InlineeWithReason;
 import com.android.tools.r8.ir.optimize.Inliner.Reason;
 import com.android.tools.r8.ir.optimize.info.OptimizationFeedback;
@@ -61,7 +62,7 @@
   }
 
   @Override
-  public InlineAction computeInlining(
+  public InlineResult computeInlining(
       InvokeMethod invoke,
       SingleResolutionResult resolutionResult,
       ProgramMethod singleTarget,
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java b/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
index 605896f..2f61964 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
@@ -86,6 +86,10 @@
   private final LensCodeRewriter lensCodeRewriter;
   final MainDexInfo mainDexInfo;
 
+  // The set of callers of single caller methods where the single caller method could not be inlined
+  // due to not being processed at the time of inlining.
+  private final LongLivedProgramMethodSetBuilder<ProgramMethodSet> singleInlineCallers;
+
   // State for inlining methods which are known to be called twice.
   private LongLivedProgramMethodSetBuilder<ProgramMethodSet> doubleInlineCallers;
   private final ProgramMethodSet doubleInlineSelectedTargets = ProgramMethodSet.create();
@@ -105,6 +109,8 @@
             : ImmutableSet.of(intrinsics.throwNpe, intrinsics.throwParameterIsNullException);
     this.lensCodeRewriter = lensCodeRewriter;
     this.mainDexInfo = appView.appInfo().getMainDexInfo();
+    this.singleInlineCallers =
+        LongLivedProgramMethodSetBuilder.createConcurrentForIdentitySet(appView.graphLens());
     availableApiExceptions =
         appView.options().canHaveDalvikCatchHandlerVerificationBug()
             ? new AvailableApiExceptions(appView.options())
@@ -250,8 +256,23 @@
     // The double inline callers are always rewritten up until the graph lens of the primary
     // optimization pass, so we can safely merge them into the methods to reprocess (which may be
     // rewritten with a newer graph lens).
-    postMethodProcessorBuilder.getMethodsToReprocessBuilder().merge(doubleInlineCallers);
+    postMethodProcessorBuilder
+        .getMethodsToReprocessBuilder()
+        .rewrittenWithLens(appView)
+        .merge(
+            doubleInlineCallers
+                .rewrittenWithLens(appView)
+                .removeIf(
+                    appView,
+                    method -> method.getOptimizationInfo().hasBeenInlinedIntoSingleCallSite()))
+        .merge(
+            singleInlineCallers
+                .rewrittenWithLens(appView)
+                .removeIf(
+                    appView,
+                    method -> method.getOptimizationInfo().hasBeenInlinedIntoSingleCallSite()));
     doubleInlineCallers = null;
+    singleInlineCallers.clear();
   }
 
   /**
@@ -570,7 +591,18 @@
     }
   }
 
-  public static class InlineAction {
+  public abstract static class InlineResult {
+
+    InlineAction asInlineAction() {
+      return null;
+    }
+
+    boolean isRetryAction() {
+      return false;
+    }
+  }
+
+  public static class InlineAction extends InlineResult {
 
     public final ProgramMethod target;
     public final Invoke invoke;
@@ -585,6 +617,11 @@
       this.reason = reason;
     }
 
+    @Override
+    InlineAction asInlineAction() {
+      return this;
+    }
+
     void setShouldSynthesizeInitClass() {
       assert !shouldSynthesizeNullCheckForReceiver;
       shouldSynthesizeInitClass = true;
@@ -798,6 +835,14 @@
     }
   }
 
+  public static class RetryAction extends InlineResult {
+
+    @Override
+    boolean isRetryAction() {
+      return true;
+    }
+  }
+
   static class InlineeWithReason {
 
     final Reason reason;
@@ -1003,7 +1048,7 @@
               oracle.isForcedInliningOracle()
                   ? NopWhyAreYouNotInliningReporter.getInstance()
                   : WhyAreYouNotInliningReporter.createFor(singleTarget, appView, context);
-          InlineAction action =
+          InlineResult inlineResult =
               oracle.computeInlining(
                   invoke,
                   resolutionResult,
@@ -1011,11 +1056,18 @@
                   context,
                   classInitializationAnalysis,
                   whyAreYouNotInliningReporter);
-          if (action == null) {
+          if (inlineResult == null) {
             assert whyAreYouNotInliningReporter.unsetReasonHasBeenReportedFlag();
             continue;
           }
 
+          if (inlineResult.isRetryAction()) {
+            enqueueMethodForReprocessing(context);
+            continue;
+          }
+
+          InlineAction action = inlineResult.asInlineAction();
+
           DexProgramClass downcastClass = getDowncastTypeIfNeeded(strategy, invoke, singleTarget);
           if (downcastClass != null
               && AccessControl.isClassAccessible(downcastClass, context, appView)
@@ -1239,6 +1291,14 @@
     assert IteratorUtils.peekNext(blockIterator) == firstInlineeBlock;
   }
 
+  public void enqueueMethodForReprocessing(ProgramMethod method) {
+    singleInlineCallers.add(method, appView.graphLens());
+  }
+
+  public void pruneMethod(ProgramMethod method) {
+    singleInlineCallers.remove(method.getReference(), appView.graphLens());
+  }
+
   public static boolean verifyNoMethodsInlinedDueToSingleCallSite(AppView<?> appView) {
     for (DexProgramClass clazz : appView.appInfo().classes()) {
       for (DexEncodedMethod method : clazz.methods()) {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/InliningOracle.java b/src/main/java/com/android/tools/r8/ir/optimize/InliningOracle.java
index 1e5c3ae..f839991 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/InliningOracle.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/InliningOracle.java
@@ -8,7 +8,7 @@
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.ir.analysis.ClassInitializationAnalysis;
 import com.android.tools.r8.ir.code.InvokeMethod;
-import com.android.tools.r8.ir.optimize.Inliner.InlineAction;
+import com.android.tools.r8.ir.optimize.Inliner.InlineResult;
 import com.android.tools.r8.ir.optimize.Inliner.Reason;
 import com.android.tools.r8.ir.optimize.inliner.WhyAreYouNotInliningReporter;
 
@@ -29,7 +29,7 @@
       Reason reason,
       WhyAreYouNotInliningReporter whyAreYouNotInliningReporter);
 
-  InlineAction computeInlining(
+  InlineResult computeInlining(
       InvokeMethod invoke,
       SingleResolutionResult resolutionResult,
       ProgramMethod singleTarget,
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/info/MethodOptimizationInfoCollector.java b/src/main/java/com/android/tools/r8/ir/optimize/info/MethodOptimizationInfoCollector.java
index e28b6ad..f93dc2a 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/info/MethodOptimizationInfoCollector.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/info/MethodOptimizationInfoCollector.java
@@ -137,7 +137,7 @@
       Timing timing) {
     DexEncodedMethod definition = method.getDefinition();
     identifyBridgeInfo(definition, code, feedback, timing);
-    analyzeReturns(code, feedback, timing);
+    analyzeReturns(code, feedback, methodProcessor, timing);
     if (options.enableInlining) {
       identifyInvokeSemanticsForInlining(definition, code, feedback, timing);
     }
@@ -166,13 +166,15 @@
     timing.end();
   }
 
-  private void analyzeReturns(IRCode code, OptimizationFeedback feedback, Timing timing) {
+  private void analyzeReturns(
+      IRCode code, OptimizationFeedback feedback, MethodProcessor methodProcessor, Timing timing) {
     timing.begin("Identify returns argument");
-    analyzeReturns(code, feedback);
+    analyzeReturns(code, feedback, methodProcessor);
     timing.end();
   }
 
-  private void analyzeReturns(IRCode code, OptimizationFeedback feedback) {
+  private void analyzeReturns(
+      IRCode code, OptimizationFeedback feedback, MethodProcessor methodProcessor) {
     ProgramMethod context = code.context();
     DexEncodedMethod method = context.getDefinition();
     List<BasicBlock> normalExits = code.computeNormalExitBlocks();
@@ -204,7 +206,7 @@
           feedback.methodReturnsAbstractValue(method, appView, abstractReturnValue);
           if (checkCastAndInstanceOfMethodSpecialization != null) {
             checkCastAndInstanceOfMethodSpecialization.addCandidateForOptimization(
-                context, abstractReturnValue);
+                context, abstractReturnValue, methodProcessor);
           }
         }
       }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/info/MutableMethodOptimizationInfo.java b/src/main/java/com/android/tools/r8/ir/optimize/info/MutableMethodOptimizationInfo.java
index e43573a..9f1ea2d 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/info/MutableMethodOptimizationInfo.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/info/MutableMethodOptimizationInfo.java
@@ -361,6 +361,10 @@
     return isFlagSet(HAS_BEEN_INLINED_INTO_SINGLE_CALL_SITE_FLAG);
   }
 
+  void unsetInlinedIntoSingleCallSite() {
+    clearFlag(HAS_BEEN_INLINED_INTO_SINGLE_CALL_SITE_FLAG);
+  }
+
   void markInlinedIntoSingleCallSite() {
     setFlag(HAS_BEEN_INLINED_INTO_SINGLE_CALL_SITE_FLAG);
   }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackDelayed.java b/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackDelayed.java
index 472a65b..1b4acdd 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackDelayed.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackDelayed.java
@@ -199,7 +199,7 @@
   }
 
   @Override
-  public synchronized void unsetAbstractReturnValue(DexEncodedMethod method) {
+  public synchronized void unsetAbstractReturnValue(ProgramMethod method) {
     getMethodOptimizationInfoForUpdating(method).unsetAbstractReturnValue();
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackIgnore.java b/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackIgnore.java
index 366a1ac..8b13771 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackIgnore.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackIgnore.java
@@ -79,7 +79,7 @@
       DexEncodedMethod method, AppView<AppInfoWithLiveness> appView, AbstractValue value) {}
 
   @Override
-  public void unsetAbstractReturnValue(DexEncodedMethod method) {}
+  public void unsetAbstractReturnValue(ProgramMethod method) {}
 
   @Override
   public void methodReturnsObjectWithUpperBoundType(
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackSimple.java b/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackSimple.java
index f33bf52..86ea927 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackSimple.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackSimple.java
@@ -107,8 +107,10 @@
   }
 
   @Override
-  public void unsetAbstractReturnValue(DexEncodedMethod method) {
-    method.getMutableOptimizationInfo().unsetAbstractReturnValue();
+  public void unsetAbstractReturnValue(ProgramMethod method) {
+    if (method.getOptimizationInfo().isMutableOptimizationInfo()) {
+      method.getDefinition().getMutableOptimizationInfo().unsetAbstractReturnValue();
+    }
   }
 
   @Override
@@ -242,4 +244,13 @@
   public void setUnusedArguments(ProgramMethod method, BitSet unusedArguments) {
     method.getDefinition().getMutableOptimizationInfo().setUnusedArguments(unusedArguments);
   }
+
+  public void unsetInlinedIntoSingleCallSite(ProgramMethod method) {
+    if (method.getOptimizationInfo().isMutableOptimizationInfo()) {
+      method
+          .getOptimizationInfo()
+          .asMutableMethodOptimizationInfo()
+          .unsetInlinedIntoSingleCallSite();
+    }
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/typechecks/CheckCastAndInstanceOfMethodSpecialization.java b/src/main/java/com/android/tools/r8/ir/optimize/typechecks/CheckCastAndInstanceOfMethodSpecialization.java
index 906ac7f..18a964f 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/typechecks/CheckCastAndInstanceOfMethodSpecialization.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/typechecks/CheckCastAndInstanceOfMethodSpecialization.java
@@ -15,10 +15,10 @@
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.conversion.IRConverter;
+import com.android.tools.r8.ir.conversion.MethodProcessor;
 import com.android.tools.r8.ir.optimize.info.MethodOptimizationInfo;
 import com.android.tools.r8.ir.optimize.info.OptimizationFeedbackSimple;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
-import com.android.tools.r8.utils.Action;
 import com.android.tools.r8.utils.collections.ProgramMethodSet;
 import com.android.tools.r8.utils.collections.SortedProgramMethodSet;
 
@@ -34,7 +34,7 @@
  *
  * <p>TODO(b/151596599): Also handle methods that implement casts.
  */
-public class CheckCastAndInstanceOfMethodSpecialization implements Action {
+public class CheckCastAndInstanceOfMethodSpecialization {
 
   private static final OptimizationFeedbackSimple feedback =
       OptimizationFeedbackSimple.getInstance();
@@ -52,14 +52,16 @@
     this.converter = converter;
   }
 
-  public void addCandidateForOptimization(ProgramMethod method, AbstractValue abstractReturnValue) {
+  public void addCandidateForOptimization(
+      ProgramMethod method, AbstractValue abstractReturnValue, MethodProcessor methodProcessor) {
     if (!converter.isInWave()) {
       return;
     }
+    assert methodProcessor.isPrimaryMethodProcessor();
     if (isCandidateForInstanceOfOptimization(method, abstractReturnValue)) {
       synchronized (this) {
         if (candidatesForInstanceOfOptimization.isEmpty()) {
-          converter.addWaveDoneAction(this);
+          converter.addWaveDoneAction(() -> execute(methodProcessor));
         }
         candidatesForInstanceOfOptimization.add(method);
       }
@@ -72,18 +74,18 @@
         && abstractReturnValue.isSingleBoolean();
   }
 
-  @Override
-  public void execute() {
+  public void execute(MethodProcessor methodProcessor) {
     assert !candidatesForInstanceOfOptimization.isEmpty();
     ProgramMethodSet processed = ProgramMethodSet.create();
     for (ProgramMethod method : candidatesForInstanceOfOptimization) {
       if (!processed.contains(method)) {
-        processCandidateForInstanceOfOptimization(method);
+        processCandidateForInstanceOfOptimization(method, methodProcessor);
       }
     }
   }
 
-  private void processCandidateForInstanceOfOptimization(ProgramMethod method) {
+  private void processCandidateForInstanceOfOptimization(
+      ProgramMethod method, MethodProcessor methodProcessor) {
     DexEncodedMethod definition = method.getDefinition();
     if (!definition.isNonPrivateVirtualMethod()) {
       return;
@@ -140,16 +142,27 @@
           parentMethodDefinition.getCode().buildIR(parentMethod, appView, parentMethod.getOrigin());
       converter.markProcessed(code, feedback);
       // Fixup method optimization info (the method no longer returns a constant).
-      feedback.unsetAbstractReturnValue(parentMethod.getDefinition());
+      feedback.unsetAbstractReturnValue(parentMethod);
       feedback.unsetClassInlinerMethodConstraint(parentMethod);
     } else {
       return;
     }
 
-    method.getHolder().removeMethod(method.getReference());
-
     appView.withArgumentPropagator(
         argumentPropagator -> argumentPropagator.transferArgumentInformation(method, parentMethod));
+
+    // Each call to the override is now a call to the parent method, so it is no longer true that
+    // parent method has been inlined into its single call site.
+    feedback.unsetInlinedIntoSingleCallSite(parentMethod);
+
+    // For the same reason, we no longer have single or dual caller information for the parent
+    // method.
+    methodProcessor.getCallSiteInformation().unsetCallSiteInformation(parentMethod);
+
+    // Remove the method and notify other optimizations that the override has been removed to allow
+    // the optimizations to fixup their state.
+    method.getHolder().removeMethod(method.getReference());
+    converter.pruneMethod(method);
   }
 
   private ProgramMethod resolveOnSuperClass(ProgramMethod method) {
diff --git a/src/main/java/com/android/tools/r8/utils/collections/LongLivedProgramMethodSetBuilder.java b/src/main/java/com/android/tools/r8/utils/collections/LongLivedProgramMethodSetBuilder.java
index 7687a88..911290b 100644
--- a/src/main/java/com/android/tools/r8/utils/collections/LongLivedProgramMethodSetBuilder.java
+++ b/src/main/java/com/android/tools/r8/utils/collections/LongLivedProgramMethodSetBuilder.java
@@ -4,7 +4,10 @@
 
 package com.android.tools.r8.utils.collections;
 
+import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
+
 import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexDefinitionSupplier;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.GraphLens;
@@ -13,6 +16,7 @@
 import com.android.tools.r8.utils.SetUtils;
 import java.util.Set;
 import java.util.function.IntFunction;
+import java.util.function.Predicate;
 
 public class LongLivedProgramMethodSetBuilder<T extends ProgramMethodSet> {
 
@@ -98,15 +102,37 @@
     return this;
   }
 
+  @Deprecated
   public void remove(DexMethod method) {
     methods.remove(method);
   }
 
+  public void remove(DexMethod method, GraphLens currentGraphLens) {
+    assert isEmpty() || verifyIsRewrittenWithLens(currentGraphLens);
+    methods.remove(method);
+  }
+
   public LongLivedProgramMethodSetBuilder<T> removeAll(Iterable<DexMethod> methods) {
     methods.forEach(this::remove);
     return this;
   }
 
+  public LongLivedProgramMethodSetBuilder<T> removeIf(
+      DexDefinitionSupplier definitions, Predicate<ProgramMethod> predicate) {
+    methods.removeIf(
+        method -> {
+          DexProgramClass holder =
+              asProgramClassOrNull(definitions.definitionFor(method.getHolderType()));
+          ProgramMethod definition = method.lookupOnProgramClass(holder);
+          if (definition == null) {
+            assert false;
+            return true;
+          }
+          return predicate.test(definition);
+        });
+    return this;
+  }
+
   public LongLivedProgramMethodSetBuilder<T> rewrittenWithLens(
       AppView<AppInfoWithLiveness> appView) {
     return rewrittenWithLens(appView.graphLens());
diff --git a/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java b/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java
index 08dd682..5d3c306 100644
--- a/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java
+++ b/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java
@@ -204,7 +204,7 @@
             b ->
                 b.addProguardConfiguration(
                     getProguardOptionsNPlus(enableProguardCompatibilityMode), Origin.unknown()))
-        .withDexCheck(inspector -> checkLambdaCount(inspector, 3, "lambdadesugaringnplus"))
+        .withDexCheck(inspector -> checkLambdaCount(inspector, 2, "lambdadesugaringnplus"))
         .run();
   }
 
@@ -244,7 +244,7 @@
             b ->
                 b.addProguardConfiguration(
                     getProguardOptionsNPlus(enableProguardCompatibilityMode), Origin.unknown()))
-        .withDexCheck(inspector -> checkLambdaCount(inspector, 3, "lambdadesugaringnplus"))
+        .withDexCheck(inspector -> checkLambdaCount(inspector, 0, "lambdadesugaringnplus"))
         .run();
   }