Towards the propagation of call site optimization info.

Bug: 139246447, 80455722
Change-Id: I8d5ed18900d67ed9be483cdd06593edd92fdbaf8
diff --git a/src/main/java/com/android/tools/r8/graph/AppView.java b/src/main/java/com/android/tools/r8/graph/AppView.java
index 623da28..0817a03 100644
--- a/src/main/java/com/android/tools/r8/graph/AppView.java
+++ b/src/main/java/com/android/tools/r8/graph/AppView.java
@@ -9,6 +9,7 @@
 import com.android.tools.r8.ir.analysis.proto.GeneratedExtensionRegistryShrinker;
 import com.android.tools.r8.ir.analysis.proto.GeneratedMessageLiteShrinker;
 import com.android.tools.r8.ir.analysis.proto.ProtoShrinker;
+import com.android.tools.r8.ir.conversion.CallSiteOptimizationInfoPropagator;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.shaking.RootSetBuilder.RootSet;
 import com.android.tools.r8.shaking.VerticalClassMerger.VerticallyMergedClasses;
@@ -36,6 +37,7 @@
   private RootSet rootSet;
 
   // Optimizations.
+  private final CallSiteOptimizationInfoPropagator callSiteOptimizationInfoPropagator;
   private final ProtoShrinker protoShrinker;
 
   // Optimization results.
@@ -52,6 +54,13 @@
     this.graphLense = GraphLense.getIdentityLense();
     this.options = options;
 
+    if (enableWholeProgramOptimizations() && options.enableCallSiteOptimizationInfoPropagation) {
+      this.callSiteOptimizationInfoPropagator =
+          new CallSiteOptimizationInfoPropagator(withLiveness());
+    } else {
+      this.callSiteOptimizationInfoPropagator = null;
+    }
+
     if (enableWholeProgramOptimizations() && options.isProtoShrinkingEnabled()) {
       this.protoShrinker = new ProtoShrinker(withLiveness());
     } else {
@@ -153,6 +162,10 @@
     return wholeProgramOptimizations == WholeProgramOptimizations.ON;
   }
 
+  public CallSiteOptimizationInfoPropagator callSiteOptimizationInfoPropagator() {
+    return callSiteOptimizationInfoPropagator;
+  }
+
   public ProtoShrinker protoShrinker() {
     return protoShrinker;
   }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/CallSiteOptimizationInfo.java b/src/main/java/com/android/tools/r8/ir/conversion/CallSiteOptimizationInfo.java
index 06b8500..4f5c08c 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/CallSiteOptimizationInfo.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/CallSiteOptimizationInfo.java
@@ -24,7 +24,7 @@
   }
 
   /**
-   * CallSiteOptimizationInfoPropagator will reprocess the call target if its collected call
+   * {@link CallSiteOptimizationInfoPropagator} will reprocess the call target if its collected call
    * site optimization info has something useful that can trigger more optimizations. For example,
    * if a certain argument is guaranteed to be definitely not null for all call sites, null-check on
    * that argument can be simplified during the reprocessing of the method.
@@ -33,6 +33,7 @@
     return false;
   }
 
+  // The index exactly matches with in values of invocation, i.e., even including receiver.
   public abstract Nullability getNullability(int argIndex);
 
   // TODO(b/139246447): extend it to TypeLattice and insert AssumeDynamicType if the join of all
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/CallSiteOptimizationInfoPropagator.java b/src/main/java/com/android/tools/r8/ir/conversion/CallSiteOptimizationInfoPropagator.java
new file mode 100644
index 0000000..5cdabd3
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/conversion/CallSiteOptimizationInfoPropagator.java
@@ -0,0 +1,234 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.conversion;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.ir.analysis.type.Nullability;
+import com.android.tools.r8.ir.analysis.type.TypeAnalysis;
+import com.android.tools.r8.ir.code.Assume;
+import com.android.tools.r8.ir.code.Assume.NonNullAssumption;
+import com.android.tools.r8.ir.code.ConstInstruction;
+import com.android.tools.r8.ir.code.ConstNumber;
+import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.Instruction;
+import com.android.tools.r8.ir.code.InstructionListIterator;
+import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.logging.Log;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.ThreadUtils;
+import com.android.tools.r8.utils.ThrowingBiConsumer;
+import com.google.common.collect.Sets;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.function.Predicate;
+
+public class CallSiteOptimizationInfoPropagator {
+
+  // TODO(b/139246447): should we revisit new targets over and over again?
+  //   Maybe piggy-back on MethodProcessor's wave/batch processing?
+  // For now, before revisiting methods with more precise argument info, we switch the mode.
+  // Then, revisiting a target at a certain level will not improve call site information of
+  // callees in lower levels.
+  private enum Mode {
+    COLLECT, // Set until the end of the 1st round of IR processing. CallSiteOptimizationInfo will
+             // be updated in this mode only.
+    REVISIT, // Set once the all methods are processed. IRBuilder will add other instructions that
+             // reflect collected CallSiteOptimizationInfo.
+    FINISH;  // Set once the 2nd round of IR processing is done. Other optimizations that need post
+             // IR processing, e.g., outliner, are still using IRBuilder, and this will isolate the
+             // impact of IR manipulation due to this optimization.
+  }
+
+  private final AppView<AppInfoWithLiveness> appView;
+  private Set<DexEncodedMethod> revisitedMethods = null;
+  private Mode mode = Mode.COLLECT;
+
+  public CallSiteOptimizationInfoPropagator(AppView<AppInfoWithLiveness> appView) {
+    assert appView.enableWholeProgramOptimizations();
+    this.appView = appView;
+    if (Log.isLoggingEnabledFor(CallSiteOptimizationInfoPropagator.class)) {
+      revisitedMethods = Sets.newIdentityHashSet();
+    }
+  }
+
+  void logResults() {
+    assert Log.ENABLED;
+    if (revisitedMethods != null) {
+      Log.info(getClass(), "# of methods to revisit: %s", revisitedMethods.size());
+      for (DexEncodedMethod m : revisitedMethods) {
+        Log.info(getClass(), "%s: %s",
+            m.toSourceString(), m.getMutableCallSiteOptimizationInfo().toString());
+      }
+    }
+  }
+
+  void collectCallSiteOptimizationInfo(IRCode code) {
+    // TODO(b/139246447): we could collect call site optimization during REVISIT mode as well,
+    //   but that may require a separate copy of CallSiteOptimizationInfo.
+    if (mode != Mode.COLLECT) {
+      return;
+    }
+    DexEncodedMethod context = code.method;
+    for (Instruction instruction : code.instructions()) {
+      if (!instruction.isInvokeMethod() && !instruction.isInvokeCustom()) {
+        continue;
+      }
+      if (!MutableCallSiteOptimizationInfo.hasArgumentsToRecord(instruction.inValues())) {
+        continue;
+      }
+      if (instruction.isInvokeMethod()) {
+        Collection<DexEncodedMethod> targets =
+            instruction.asInvokeMethod().lookupTargets(appView, context.method.holder);
+        if (targets == null) {
+          continue;
+        }
+        for (DexEncodedMethod target : targets) {
+          recordArgumentsIfNecessary(context, target, instruction.inValues());
+        }
+      }
+      // TODO(b/129458850): if lambda desugaring happens before IR processing, seeing invoke-custom
+      //  means we can't find matched methods in the app, hence safe to ignore (only for DEX).
+      if (instruction.isInvokeCustom()) {
+        // Conservatively register argument info for all possible lambda implemented methods.
+        Collection<DexEncodedMethod> targets =
+            appView.appInfo().lookupLambdaImplementedMethods(
+                instruction.asInvokeCustom().getCallSite());
+        if (targets == null) {
+          continue;
+        }
+        for (DexEncodedMethod target : targets) {
+          recordArgumentsIfNecessary(context, target, instruction.inValues());
+        }
+      }
+    }
+  }
+
+  private void recordArgumentsIfNecessary(
+      DexEncodedMethod context, DexEncodedMethod target, List<Value> inValues) {
+    assert !target.isObsolete();
+    if (target.shouldNotHaveCode() || target.method.getArity() == 0) {
+      return;
+    }
+    // If pinned, that method could be invoked via reflection.
+    if (appView.appInfo().isPinned(target.method)) {
+      return;
+    }
+    // If the program already has illegal accesses, method resolution results will reflect that too.
+    // We should avoid recording arguments in that case. E.g., b/139823850: static methods can be a
+    // result of virtual call targets, if that's the only method that matches name and signature.
+    int argumentOffset = target.isStatic() ? 0 : 1;
+    if (inValues.size() != argumentOffset + target.method.getArity()) {
+      return;
+    }
+    MutableCallSiteOptimizationInfo optimizationInfo = target.getMutableCallSiteOptimizationInfo();
+    optimizationInfo.recordArguments(context, inValues);
+  }
+
+  // If collected call site optimization info has something useful, e.g., non-null argument,
+  // insert corresponding assume instructions for arguments.
+  void applyCallSiteOptimizationInfo(
+      IRCode code, CallSiteOptimizationInfo callSiteOptimizationInfo) {
+    if (mode != Mode.REVISIT
+        || !callSiteOptimizationInfo.hasUsefulOptimizationInfo()) {
+      return;
+    }
+    Set<Value> affectedValues = Sets.newIdentityHashSet();
+    List<Assume<NonNullAssumption>> assumeInstructions = new LinkedList<>();
+    List<ConstInstruction> constants = new LinkedList<>();
+    int argumentsSeen = 0;
+    InstructionListIterator iterator = code.entryBlock().listIterator(code);
+    while (iterator.hasNext()) {
+      Instruction instr = iterator.next();
+      if (!instr.isArgument()) {
+        break;
+      }
+      Value arg = instr.asArgument().outValue();
+      if (!arg.hasLocalInfo() && arg.getTypeLattice().isReference()) {
+        Nullability nullability = callSiteOptimizationInfo.getNullability(argumentsSeen);
+        if (nullability.isDefinitelyNotNull()) {
+          // If we already knew `arg` is never null, e.g., receiver, skip adding non-null.
+          if (!arg.getTypeLattice().nullability().isDefinitelyNotNull()) {
+            Value nonNullValue = code.createValue(
+                arg.getTypeLattice().asReferenceTypeLatticeElement().asNotNull(),
+                arg.getLocalInfo());
+            affectedValues.addAll(arg.affectedValues());
+            arg.replaceUsers(nonNullValue);
+            Assume<NonNullAssumption> assumeNotNull =
+                Assume.createAssumeNonNullInstruction(nonNullValue, arg, instr, appView);
+            assumeNotNull.setPosition(instr.getPosition());
+            assumeInstructions.add(assumeNotNull);
+          }
+        } else if (nullability.isDefinitelyNull()) {
+          ConstNumber nullInstruction = code.createConstNull();
+          nullInstruction.setPosition(instr.getPosition());
+          affectedValues.addAll(arg.affectedValues());
+          arg.replaceUsers(nullInstruction.outValue());
+          constants.add(nullInstruction);
+        }
+      }
+      // TODO(b/69963623): Handle other kinds of constants, e.g. number, string, or class.
+      argumentsSeen++;
+    }
+    assert argumentsSeen == code.method.method.getArity() + (code.method.isStatic() ? 0 : 1);
+    // After packed Argument instructions, add Assume<?> and constant instructions.
+    assert !iterator.peekPrevious().isArgument();
+    iterator.previous();
+    assert iterator.peekPrevious().isArgument();
+    assumeInstructions.forEach(iterator::add);
+    // TODO(b/69963623): Can update method signature and save more on call sites.
+    constants.forEach(iterator::add);
+
+    if (!affectedValues.isEmpty()) {
+      new TypeAnalysis(appView).narrowing(affectedValues);
+    }
+  }
+
+  <E extends Exception> void revisitMethods(
+      ThrowingBiConsumer<DexEncodedMethod, Predicate<DexEncodedMethod>, E> consumer,
+      ExecutorService executorService)
+      throws ExecutionException {
+    Set<DexEncodedMethod> targetsToRevisit = Sets.newIdentityHashSet();
+    for (DexProgramClass clazz : appView.appInfo().classes()) {
+      for (DexEncodedMethod method : clazz.methods()) {
+        assert !method.isObsolete();
+        if (method.shouldNotHaveCode()
+            || method.getCallSiteOptimizationInfo().isDefaultCallSiteOptimizationInfo()) {
+          continue;
+        }
+        MutableCallSiteOptimizationInfo optimizationInfo =
+            method.getCallSiteOptimizationInfo().asMutableCallSiteOptimizationInfo();
+        if (optimizationInfo.hasUsefulOptimizationInfo()) {
+          targetsToRevisit.add(method);
+        }
+      }
+    }
+    if (targetsToRevisit.isEmpty()) {
+      mode = Mode.FINISH;
+      return;
+    }
+    if (revisitedMethods != null) {
+      revisitedMethods.addAll(targetsToRevisit);
+    }
+    mode = Mode.REVISIT;
+    List<Future<?>> futures = new ArrayList<>();
+    for (DexEncodedMethod method : targetsToRevisit) {
+      futures.add(
+          executorService.submit(
+              () -> {
+                consumer.accept(method, targetsToRevisit::contains);
+                return null; // we want a Callable not a Runnable to be able to throw
+              }));
+    }
+    ThreadUtils.awaitFutures(futures);
+    mode = Mode.FINISH;
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRBuilder.java b/src/main/java/com/android/tools/r8/ir/conversion/IRBuilder.java
index 9b31a58..902a795 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRBuilder.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRBuilder.java
@@ -618,6 +618,13 @@
       new TypeAnalysis(appView).widening(context, method, ir);
     }
 
+    // Update the IR code if collected call site optimization info has something useful.
+    CallSiteOptimizationInfo callSiteOptimizationInfo = method.getCallSiteOptimizationInfo();
+    if (appView.callSiteOptimizationInfoPropagator() != null) {
+      appView.callSiteOptimizationInfoPropagator()
+          .applyCallSiteOptimizationInfo(ir, callSiteOptimizationInfo);
+    }
+
     if (appView.options().isStringSwitchConversionEnabled()) {
       StringSwitchConverter.convertToStringSwitchInstructions(ir, appView.dexItemFactory());
     }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 87c174f..d278966 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -611,11 +611,12 @@
     collectLambdaMergingCandidates(application);
     collectStaticizerCandidates(application);
 
-    // The process is in two phases.
-    // 1) Subject all DexEncodedMethods to optimization (except outlining).
-    //    - a side effect is candidates for outlining are identified.
-    // 2) Perform outlining for the collected candidates.
-    // Ideally, we should outline eagerly when threshold for a template has been reached.
+    // The process is in two phases in general.
+    // 1) Subject all DexEncodedMethods to optimization, except some optimizations that require
+    //    reprocessing IR code of methods, e.g., outlining, double-inlining, class staticizer, etc.
+    //    - a side effect is candidates for those optimizations are identified.
+    // 2) Revisit DexEncodedMethods for the collected candidates.
+    // TODO(b/127694949): unified framework to reprocess methods only once.
 
     printPhase("Primary optimization pass");
 
@@ -649,6 +650,22 @@
       libraryMethodOverrideAnalysis.finish();
     }
 
+    // Second pass for methods whose collected call site information become more precise.
+    if (appView.callSiteOptimizationInfoPropagator() != null) {
+      printPhase("2nd round of method processing after inter-procedural analysis.");
+      timing.begin("IR conversion phase 2");
+      appView.callSiteOptimizationInfoPropagator().revisitMethods(
+          (method, isProcessedConcurrently) ->
+              processMethod(
+                  method,
+                  feedback,
+                  isProcessedConcurrently,
+                  CallSiteInformation.empty(),
+                  Outliner::noProcessing),
+          executorService);
+      timing.end();
+    }
+
     // Second inlining pass for dealing with double inline callers.
     if (inliner != null) {
       printPhase("Double caller inlining");
@@ -684,7 +701,7 @@
 
     if (outliner != null) {
       printPhase("Outlining");
-      timing.begin("IR conversion phase 2");
+      timing.begin("IR conversion phase 3");
       if (outliner.selectMethodsForOutlining()) {
         forEachSelectedOutliningMethod(
             executorService,
@@ -715,6 +732,9 @@
     }
 
     if (Log.ENABLED) {
+      if (appView.callSiteOptimizationInfoPropagator() != null) {
+        appView.callSiteOptimizationInfoPropagator().logResults();
+      }
       constantCanonicalizer.logResults();
       if (idempotentFunctionCallCanonicalizer != null) {
         idempotentFunctionCallCanonicalizer.logResults();
@@ -1165,6 +1185,13 @@
     }
 
     if (nonNullTracker != null) {
+      // TODO(b/139246447): Once we extend this optimization to, e.g., constants of primitive args,
+      //   this may not be the right place to collect call site optimization info.
+      // Collecting call-site optimization info depends on the existence of non-null IRs.
+      // Arguments can be changed during the debug mode.
+      if (!isDebugMode && appView.callSiteOptimizationInfoPropagator() != null) {
+        appView.callSiteOptimizationInfoPropagator().collectCallSiteOptimizationInfo(code);
+      }
       // Computation of non-null parameters on normal exits rely on the existence of non-null IRs.
       nonNullTracker.computeNonNullParamOnNormalExits(feedback, code);
       assert code.isConsistentSSA();
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/MutableCallSiteOptimizationInfo.java b/src/main/java/com/android/tools/r8/ir/conversion/MutableCallSiteOptimizationInfo.java
index e4adac1..76268b0 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/MutableCallSiteOptimizationInfo.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/MutableCallSiteOptimizationInfo.java
@@ -5,15 +5,194 @@
 
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.ir.analysis.type.Nullability;
+import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
+import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.utils.StringUtils;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 
 public class MutableCallSiteOptimizationInfo extends CallSiteOptimizationInfo {
 
+  // inValues() size == DexMethod.arity + (isStatic ? 0 : 1) // receiver
+  // That is, this information takes into account the receiver as well.
+  private final int size;
+  // Mappings from the calling context to argument collection. Note that, even in the same context,
+  // the corresponding method can be invoked multiple times with different arguments, hence *set* of
+  // argument collections.
+  private final Map<DexEncodedMethod, Set<ArgumentCollection>> callSiteInfos =
+      new ConcurrentHashMap<>();
+  private ArgumentCollection cachedRepresentative = null;
+
+  static class ArgumentCollection {
+
+    // TODO(b/139246447): extend it to TypeLattice as well as constants/ranges.
+    Nullability[] nullabilities;
+
+    private static final ArgumentCollection BOTTOM = new ArgumentCollection() {
+      @Override
+      Nullability getNullability(int index) {
+        return Nullability.maybeNull();
+      }
+
+      @Override
+      public int hashCode() {
+        return System.identityHashCode(this);
+      }
+
+      @Override
+      public String toString() {
+        return "(BOTTOM)";
+      }
+    };
+
+    private ArgumentCollection() {}
+
+    ArgumentCollection(int size) {
+      this.nullabilities = new Nullability[size];
+      for (int i = 0; i < size; i++) {
+        this.nullabilities[i] = Nullability.maybeNull();
+      }
+    }
+
+    Nullability getNullability(int index) {
+      assert nullabilities != null;
+      assert 0 <= index && index < nullabilities.length;
+      return nullabilities[index];
+    }
+
+    ArgumentCollection copy() {
+      ArgumentCollection copy = new ArgumentCollection();
+      copy.nullabilities = new Nullability[this.nullabilities.length];
+      System.arraycopy(this.nullabilities, 0, copy.nullabilities, 0, this.nullabilities.length);
+      return copy;
+    }
+
+    ArgumentCollection join(ArgumentCollection other) {
+      if (other == BOTTOM) {
+        return this;
+      }
+      if (this == BOTTOM) {
+        return other;
+      }
+      assert this.nullabilities.length == other.nullabilities.length;
+      ArgumentCollection result = this.copy();
+      for (int i = 0; i < result.nullabilities.length; i++) {
+        result.nullabilities[i] = result.nullabilities[i].join(other.nullabilities[i]);
+      }
+      return result;
+    }
+
+    static ArgumentCollection join(Collection<ArgumentCollection> collections) {
+      return collections.stream().reduce(BOTTOM, ArgumentCollection::join);
+    }
+
+    @Override
+    public boolean equals(Object other) {
+      if (!(other instanceof ArgumentCollection)) {
+        return false;
+      }
+      ArgumentCollection otherCollection = (ArgumentCollection) other;
+      if (this == BOTTOM || otherCollection == BOTTOM) {
+        return this == BOTTOM && otherCollection == BOTTOM;
+      }
+      if (this.nullabilities.length != otherCollection.nullabilities.length) {
+        return false;
+      }
+      for (int i = 0; i < this.nullabilities.length; i++) {
+        if (!this.nullabilities[i].equals(otherCollection.nullabilities[i])) {
+          return false;
+        }
+      }
+      return true;
+    }
+
+    @Override
+    public int hashCode() {
+      return Arrays.hashCode(nullabilities);
+    }
+
+    @Override
+    public String toString() {
+      return "(" + StringUtils.join(Arrays.asList(nullabilities), ", ") + ")";
+    }
+  }
+
   public MutableCallSiteOptimizationInfo(DexEncodedMethod encodedMethod) {
+    assert encodedMethod.method.getArity() > 0;
+    this.size = encodedMethod.method.getArity() + (encodedMethod.isStatic() ? 0 : 1);
+  }
+
+  private void computeCachedRepresentativeIfNecessary() {
+    if (cachedRepresentative == null && !callSiteInfos.isEmpty()) {
+      synchronized (callSiteInfos) {
+        // Make sure collected information is not flushed out by other threads.
+        if (!callSiteInfos.isEmpty()) {
+          cachedRepresentative =
+              callSiteInfos.values().stream()
+                  .reduce(
+                      ArgumentCollection.BOTTOM,
+                      (prev, collections) -> prev.join(ArgumentCollection.join(collections)),
+                      ArgumentCollection::join);
+          // After creating a cached representative, flush out the collected information.
+          callSiteInfos.clear();
+        } else {
+          // If collected information is gone while waiting for the lock, make sure it's used to
+          // compute the cached representative.
+          assert cachedRepresentative != null;
+        }
+      }
+    }
+  }
+
+  @Override
+  public boolean hasUsefulOptimizationInfo() {
+    computeCachedRepresentativeIfNecessary();
+    for (int i = 0; i < size; i++) {
+      Nullability nullability = getNullability(i);
+      if (nullability.isDefinitelyNull() || nullability.isDefinitelyNotNull()) {
+        return true;
+      }
+    }
+    return false;
   }
 
   @Override
   public Nullability getNullability(int argIndex) {
-    return Nullability.maybeNull();
+    assert 0 <= argIndex && argIndex < size;
+    if (cachedRepresentative == null) {
+      return Nullability.maybeNull();
+    }
+    return cachedRepresentative.getNullability(argIndex);
+  }
+
+  static boolean hasArgumentsToRecord(List<Value> inValues) {
+    // TODO(b/69963623): allow primitive types with compile-time constants.
+    for (Value v : inValues) {
+      if (v.getTypeLattice().isReference()) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  void recordArguments(DexEncodedMethod callingContext, List<Value> inValues) {
+    assert cachedRepresentative == null;
+    assert size == inValues.size();
+    Set<ArgumentCollection> collections =
+        callSiteInfos.computeIfAbsent(callingContext, ignore -> new HashSet<>());
+    ArgumentCollection newCallSiteInfo = new ArgumentCollection(size);
+    for (int i = 0; i < size; i++) {
+      TypeLatticeElement typeLatticeElement = inValues.get(i).getTypeLattice();
+      if (typeLatticeElement.isReference()) {
+        newCallSiteInfo.nullabilities[i] = typeLatticeElement.nullability();
+      }
+    }
+    collections.add(newCallSiteInfo);
   }
 
   @Override
@@ -25,4 +204,20 @@
   public MutableCallSiteOptimizationInfo asMutableCallSiteOptimizationInfo() {
     return this;
   }
+
+  @Override
+  public String toString() {
+    if (cachedRepresentative != null) {
+      return cachedRepresentative.toString();
+    }
+    StringBuilder builder = new StringBuilder();
+    for (Map.Entry<DexEncodedMethod, Set<ArgumentCollection>> entry : callSiteInfos.entrySet()) {
+      builder.append(entry.getKey().toSourceString());
+      builder.append(" -> {");
+      StringUtils.join(entry.getValue(), ", ");
+      builder.append("}");
+      builder.append(System.lineSeparator());
+    }
+    return builder.toString();
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index 1f19169..d3e5566 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -169,6 +169,7 @@
     enableValuePropagation = false;
     enableSideEffectAnalysis = false;
     enableTreeShakingOfLibraryMethodOverrides = false;
+    enableCallSiteOptimizationInfoPropagation = false;
   }
 
   public boolean printTimes = System.getProperty("com.android.tools.r8.printtimes") != null;
@@ -209,6 +210,8 @@
   public boolean enableNameReflectionOptimization = true;
   public boolean enableStringConcatenationOptimization = true;
   public boolean enableTreeShakingOfLibraryMethodOverrides = false;
+  // TODO(b/139246447): enable after branching.
+  public boolean enableCallSiteOptimizationInfoPropagation = false;
   public boolean encodeChecksums = false;
   public BiPredicate<String, Long> dexClassChecksumFilter = (name, checksum) -> true;
 
@@ -1019,6 +1022,13 @@
     enableNameReflectionOptimization = false;
   }
 
+  // TODO(b/139246447): Remove this once enabled.
+  @VisibleForTesting
+  public void enableCallSiteOptimizationInfoPropagation() {
+    assert !enableCallSiteOptimizationInfoPropagation;
+    enableCallSiteOptimizationInfoPropagation = true;
+  }
+
   private boolean hasMinApi(AndroidApiLevel level) {
     assert isGeneratingDex();
     return minApiLevel >= level.getLevel();
diff --git a/src/test/examples/shaking18/Options.java b/src/test/examples/shaking18/Options.java
index 7012340..edb9d77 100644
--- a/src/test/examples/shaking18/Options.java
+++ b/src/test/examples/shaking18/Options.java
@@ -4,7 +4,9 @@
 package shaking18;
 
 public class Options {
-  public boolean alwaysFalse = false;
+  // TODO(b/138913138): member value propagation can behave same with and without initialization.
+  // public boolean alwaysFalse = false;
+  public boolean alwaysFalse;
   public boolean dummy = false;
 
   public Options() {}
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeDirectPositiveTest.java b/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeDirectPositiveTest.java
index e919b7c..0ebf6b9 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeDirectPositiveTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeDirectPositiveTest.java
@@ -5,13 +5,14 @@
 
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
 import static org.hamcrest.MatcherAssert.assertThat;
-import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertEquals;
 
 import com.android.tools.r8.NeverClassInline;
 import com.android.tools.r8.NeverInline;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.codeinspector.ClassSubject;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
 import com.android.tools.r8.utils.codeinspector.InstructionSubject;
@@ -42,6 +43,7 @@
         .addKeepMainRule(MAIN)
         .enableClassInliningAnnotations()
         .enableInliningAnnotations()
+        .addOptionsModification(InternalOptions::enableCallSiteOptimizationInfoPropagation)
         .setMinApi(parameters.getRuntime())
         .run(parameters.getRuntime(), MAIN)
         .assertSuccessWithOutputLines("non-null")
@@ -54,8 +56,8 @@
 
     MethodSubject test = main.uniqueMethodWithName("test");
     assertThat(test, isPresent());
-    // TODO(b/139246447): Can optimize branches since `arg` is definitely not null.
-    assertNotEquals(0, test.streamInstructions().filter(InstructionSubject::isIf).count());
+    // Can optimize branches since `arg` is definitely not null.
+    assertEquals(0, test.streamInstructions().filter(InstructionSubject::isIf).count());
   }
 
   @NeverClassInline
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeInterfacePositiveTest.java b/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeInterfacePositiveTest.java
index 7f14dd3..bebc3eb 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeInterfacePositiveTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeInterfacePositiveTest.java
@@ -5,13 +5,14 @@
 
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
 import static org.hamcrest.MatcherAssert.assertThat;
-import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertEquals;
 
 import com.android.tools.r8.NeverClassInline;
 import com.android.tools.r8.NeverInline;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.codeinspector.ClassSubject;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
 import com.android.tools.r8.utils.codeinspector.InstructionSubject;
@@ -42,6 +43,7 @@
         .addKeepMainRule(MAIN)
         .enableClassInliningAnnotations()
         .enableInliningAnnotations()
+        .addOptionsModification(InternalOptions::enableCallSiteOptimizationInfoPropagation)
         .addOptionsModification(o -> {
           // To prevent invoke-interface from being rewritten to invoke-virtual w/ a single target.
           o.enableDevirtualization = false;
@@ -61,16 +63,16 @@
 
     MethodSubject a_m = a.uniqueMethodWithName("m");
     assertThat(a_m, isPresent());
-    // TODO(b/139246447): Can optimize branches since `arg` is definitely not null.
-    assertNotEquals(0, a_m.streamInstructions().filter(InstructionSubject::isIf).count());
+    // Can optimize branches since `arg` is definitely not null.
+    assertEquals(0, a_m.streamInstructions().filter(InstructionSubject::isIf).count());
 
     ClassSubject b = inspector.clazz(A.class);
     assertThat(b, isPresent());
 
     MethodSubject b_m = b.uniqueMethodWithName("m");
     assertThat(b_m, isPresent());
-    // TODO(b/139246447): Can optimize branches since `arg` is definitely not null.
-    assertNotEquals(0, b_m.streamInstructions().filter(InstructionSubject::isIf).count());
+    // Can optimize branches since `arg` is definitely not null.
+    assertEquals(0, b_m.streamInstructions().filter(InstructionSubject::isIf).count());
   }
 
   interface I {
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeStaticPositiveTest.java b/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeStaticPositiveTest.java
index 866f8e1..c37e990 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeStaticPositiveTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeStaticPositiveTest.java
@@ -5,12 +5,13 @@
 
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
 import static org.hamcrest.MatcherAssert.assertThat;
-import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertEquals;
 
 import com.android.tools.r8.NeverInline;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.codeinspector.ClassSubject;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
 import com.android.tools.r8.utils.codeinspector.InstructionSubject;
@@ -41,6 +42,7 @@
         .addKeepMainRule(MAIN)
         .enableInliningAnnotations()
         .setMinApi(parameters.getRuntime())
+        .addOptionsModification(InternalOptions::enableCallSiteOptimizationInfoPropagation)
         .run(parameters.getRuntime(), MAIN)
         .assertSuccessWithOutputLines("non-null")
         .inspect(this::inspect);
@@ -52,8 +54,8 @@
 
     MethodSubject test = main.uniqueMethodWithName("test");
     assertThat(test, isPresent());
-    // TODO(b/139246447): Can optimize branches since `arg` is definitely not null.
-    assertNotEquals(0, test.streamInstructions().filter(InstructionSubject::isIf).count());
+    // Can optimize branches since `arg` is definitely not null.
+    assertEquals(0, test.streamInstructions().filter(InstructionSubject::isIf).count());
   }
 
   static class Main {
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeVirtualPositiveTest.java b/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeVirtualPositiveTest.java
index 09883c3..deb3204 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeVirtualPositiveTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/callsites/nullability/InvokeVirtualPositiveTest.java
@@ -5,7 +5,7 @@
 
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
 import static org.hamcrest.MatcherAssert.assertThat;
-import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertEquals;
 
 import com.android.tools.r8.NeverClassInline;
 import com.android.tools.r8.NeverInline;
@@ -13,6 +13,7 @@
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.codeinspector.ClassSubject;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
 import com.android.tools.r8.utils.codeinspector.InstructionSubject;
@@ -44,6 +45,7 @@
         .enableMergeAnnotations()
         .enableClassInliningAnnotations()
         .enableInliningAnnotations()
+        .addOptionsModification(InternalOptions::enableCallSiteOptimizationInfoPropagation)
         .setMinApi(parameters.getRuntime())
         .run(parameters.getRuntime(), MAIN)
         .assertSuccessWithOutputLines("A", "B")
@@ -56,16 +58,16 @@
 
     MethodSubject a_m = a.uniqueMethodWithName("m");
     assertThat(a_m, isPresent());
-    // TODO(b/139246447): Can optimize branches since `arg` is definitely not null.
-    assertNotEquals(0, a_m.streamInstructions().filter(InstructionSubject::isIf).count());
+    // Can optimize branches since `arg` is definitely not null.
+    assertEquals(0, a_m.streamInstructions().filter(InstructionSubject::isIf).count());
 
     ClassSubject b = inspector.clazz(B.class);
     assertThat(b, isPresent());
 
     MethodSubject b_m = b.uniqueMethodWithName("m");
     assertThat(b_m, isPresent());
-    // TODO(b/139246447): Can optimize branches since `arg` is definitely not null.
-    assertNotEquals(0, b_m.streamInstructions().filter(InstructionSubject::isIf).count());
+    // Can optimize branches since `arg` is definitely not null.
+    assertEquals(0, b_m.streamInstructions().filter(InstructionSubject::isIf).count());
   }
 
   @NeverMerge
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/reflection/GetClassTest.java b/src/test/java/com/android/tools/r8/ir/optimize/reflection/GetClassTest.java
index 71063f4..90aeb7c 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/reflection/GetClassTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/reflection/GetClassTest.java
@@ -15,6 +15,7 @@
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.TestRunResult;
+import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.StringUtils;
 import com.android.tools.r8.utils.codeinspector.ClassSubject;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
@@ -130,6 +131,13 @@
     this.parameters = parameters;
   }
 
+  private void configure(InternalOptions options) {
+    // In `getMainClass`, a call with `null`, which will throw NPE, is replaced with null throwing
+    // code. Then, remaining call with non-null argument made getClass() replaceable.
+    // Disable the propagation of call site information to separate the tests.
+    options.enableCallSiteOptimizationInfoPropagation = false;
+  }
+
   @Test
   public void testJVMOutput() throws Exception {
     assumeTrue("Only run JVM reference on CF runtimes", parameters.isCfRuntime());
@@ -174,6 +182,7 @@
         testForD8()
             .debug()
             .addProgramClassesAndInnerClasses(MAIN)
+            .addOptionsModification(this::configure)
             .setMinApi(parameters.getRuntime())
             .run(parameters.getRuntime(), MAIN)
             .assertSuccessWithOutput(JAVA_OUTPUT);
@@ -184,6 +193,7 @@
         testForD8()
             .release()
             .addProgramClassesAndInnerClasses(MAIN)
+            .addOptionsModification(this::configure)
             .setMinApi(parameters.getRuntime())
             .run(parameters.getRuntime(), MAIN)
             .assertSuccessWithOutput(JAVA_OUTPUT);
@@ -200,14 +210,11 @@
             .enableInliningAnnotations()
             .addKeepMainRule(MAIN)
             .noMinification()
+            .addOptionsModification(this::configure)
             .setMinApi(parameters.getRuntime())
             .run(parameters.getRuntime(), MAIN);
     test(result, true, false);
 
-    // The number of expected const-class instructions differs because constant canonicalization is
-    // only enabled for the DEX backend.
-    int expectedConstClassCount = parameters.isCfRuntime() ? 7 : 5;
-
     // R8 release, no minification.
     result =
         testForR8(parameters.getBackend())
@@ -215,6 +222,7 @@
             .enableInliningAnnotations()
             .addKeepMainRule(MAIN)
             .noMinification()
+            .addOptionsModification(this::configure)
             .setMinApi(parameters.getRuntime())
             .run(parameters.getRuntime(), MAIN)
             .assertSuccessWithOutput(JAVA_OUTPUT);
@@ -226,6 +234,7 @@
             .addProgramClassesAndInnerClasses(MAIN)
             .enableInliningAnnotations()
             .addKeepMainRule(MAIN)
+            .addOptionsModification(this::configure)
             .setMinApi(parameters.getRuntime())
             // We are not checking output because it can't be matched due to minification. Just run.
             .run(parameters.getRuntime(), MAIN);
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/string/StringValueOfTest.java b/src/test/java/com/android/tools/r8/ir/optimize/string/StringValueOfTest.java
index b07d8f1..85cf999 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/string/StringValueOfTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/string/StringValueOfTest.java
@@ -55,6 +55,10 @@
   }
 
   private void configure(InternalOptions options) {
+    // Disable the propagation of call site information to test String#valueOf optimization with
+    // nullable argument. Otherwise, e.g., we know that only `null` is used for `hideNPE`, and then
+    // simplify everything in that method.
+    options.enableCallSiteOptimizationInfoPropagation = false;
     options.testing.forceNameReflectionOptimization = true;
   }
 
@@ -90,7 +94,6 @@
     assertEquals(expectedCount, countNullStringNumber(mainMethod));
 
     MethodSubject hideNPE = mainClass.uniqueMethodWithName("hideNPE");
-    assertThat(hideNPE, isPresent());
     // Due to the nullable argument, valueOf should remain.
     assertEquals(1, countCall(hideNPE, "String", "valueOf"));
 
diff --git a/src/test/java/com/android/tools/r8/kotlin/SimplifyIfNotNullKotlinTest.java b/src/test/java/com/android/tools/r8/kotlin/SimplifyIfNotNullKotlinTest.java
index 97ec186..a044d3f 100644
--- a/src/test/java/com/android/tools/r8/kotlin/SimplifyIfNotNullKotlinTest.java
+++ b/src/test/java/com/android/tools/r8/kotlin/SimplifyIfNotNullKotlinTest.java
@@ -7,6 +7,7 @@
 
 import com.android.tools.r8.ToolHelper.KotlinTargetVersion;
 import com.android.tools.r8.naming.MemberNaming.MethodSignature;
+import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.codeinspector.ClassSubject;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
 import com.android.tools.r8.utils.codeinspector.InstructionSubject;
@@ -33,25 +34,21 @@
     final String mainClassName = ex1.getClassName();
     final String extraRules =
         keepMainMethod(mainClassName) + neverInlineMethod(mainClassName, testMethodSignature);
-    runTest(FOLDER, mainClassName, extraRules, app -> {
-      CodeInspector codeInspector = new CodeInspector(app);
-      ClassSubject clazz = checkClassIsKept(codeInspector, ex1.getClassName());
+    runTest(FOLDER, mainClassName, extraRules,
+        InternalOptions::enableCallSiteOptimizationInfoPropagation,
+        app -> {
+          CodeInspector codeInspector = new CodeInspector(app);
+          ClassSubject clazz = checkClassIsKept(codeInspector, ex1.getClassName());
 
-      MethodSubject testMethod = checkMethodIsKept(clazz, testMethodSignature);
-      long ifzCount =
-          testMethod.streamInstructions().filter(i -> i.isIfEqz() || i.isIfNez()).count();
-      long paramNullCheckCount =
-          countCall(testMethod, "ArrayIteratorKt", "checkParameterIsNotNull");
-      if (allowAccessModification) {
-        // Three null-check's from inlined checkParameterIsNotNull for receiver and two arguments.
-        assertEquals(5, ifzCount);
-        assertEquals(0, paramNullCheckCount);
-      } else {
-        // One after Iterator#hasNext, and another in the filter predicate: sinceYear != null.
-        assertEquals(2, ifzCount);
-        assertEquals(5, paramNullCheckCount);
-      }
-    });
+          MethodSubject testMethod = checkMethodIsKept(clazz, testMethodSignature);
+          long ifzCount =
+              testMethod.streamInstructions().filter(i -> i.isIfEqz() || i.isIfNez()).count();
+          long paramNullCheckCount =
+              countCall(testMethod, "ArrayIteratorKt", "checkParameterIsNotNull");
+          // One after Iterator#hasNext, and another in the filter predicate: sinceYear != null.
+          assertEquals(2, ifzCount);
+          assertEquals(allowAccessModification ? 0 : 5, paramNullCheckCount);
+        });
   }
 
   @Test
@@ -63,24 +60,20 @@
     final String mainClassName = ex2.getClassName();
     final String extraRules =
         keepMainMethod(mainClassName) + neverInlineMethod(mainClassName, testMethodSignature);
-    runTest(FOLDER, mainClassName, extraRules, app -> {
-      CodeInspector codeInspector = new CodeInspector(app);
-      ClassSubject clazz = checkClassIsKept(codeInspector, ex2.getClassName());
+    runTest(FOLDER, mainClassName, extraRules,
+        InternalOptions::enableCallSiteOptimizationInfoPropagation,
+        app -> {
+          CodeInspector codeInspector = new CodeInspector(app);
+          ClassSubject clazz = checkClassIsKept(codeInspector, ex2.getClassName());
 
-      MethodSubject testMethod = checkMethodIsKept(clazz, testMethodSignature);
-      long ifzCount = testMethod.streamInstructions().filter(InstructionSubject::isIfEqz).count();
-      long paramNullCheckCount =
-          countCall(testMethod, "Intrinsics", "checkParameterIsNotNull");
-      if (allowAccessModification) {
-        // One null-check from inlined checkParameterIsNotNull.
-        assertEquals(2, ifzCount);
-        assertEquals(0, paramNullCheckCount);
-      } else {
-        // ?: in aOrDefault
-        assertEquals(1, ifzCount);
-        assertEquals(1, paramNullCheckCount);
-      }
-    });
+          MethodSubject testMethod = checkMethodIsKept(clazz, testMethodSignature);
+          long ifzCount = testMethod.streamInstructions().filter(InstructionSubject::isIfEqz).count();
+          long paramNullCheckCount =
+              countCall(testMethod, "Intrinsics", "checkParameterIsNotNull");
+          // ?: in aOrDefault
+          assertEquals(1, ifzCount);
+          assertEquals(allowAccessModification ? 0 : 1, paramNullCheckCount);
+        });
   }
 
   @Test
@@ -92,15 +85,17 @@
     final String mainClassName = ex3.getClassName();
     final String extraRules =
         keepMainMethod(mainClassName) + neverInlineMethod(mainClassName, testMethodSignature);
-    runTest(FOLDER, mainClassName, extraRules, app -> {
-      CodeInspector codeInspector = new CodeInspector(app);
-      ClassSubject clazz = checkClassIsKept(codeInspector, ex3.getClassName());
+    runTest(FOLDER, mainClassName, extraRules,
+        InternalOptions::enableCallSiteOptimizationInfoPropagation,
+        app -> {
+          CodeInspector codeInspector = new CodeInspector(app);
+          ClassSubject clazz = checkClassIsKept(codeInspector, ex3.getClassName());
 
-      MethodSubject testMethod = checkMethodIsKept(clazz, testMethodSignature);
-      long ifzCount = testMethod.streamInstructions().filter(InstructionSubject::isIfEqz).count();
-      // !! operator inside explicit null check should be gone.
-      // One explicit null-check as well as 4 bar? accesses.
-      assertEquals(5, ifzCount);
-    });
+          MethodSubject testMethod = checkMethodIsKept(clazz, testMethodSignature);
+          long ifzCount = testMethod.streamInstructions().filter(InstructionSubject::isIfEqz).count();
+          // !! operator inside explicit null check should be gone.
+          // One explicit null-check as well as 4 bar? accesses.
+          assertEquals(5, ifzCount);
+        });
   }
 }
diff --git a/src/test/java/com/android/tools/r8/shaking/ReturnTypeTest.java b/src/test/java/com/android/tools/r8/shaking/ReturnTypeTest.java
index 244199e..ecd6635 100644
--- a/src/test/java/com/android/tools/r8/shaking/ReturnTypeTest.java
+++ b/src/test/java/com/android/tools/r8/shaking/ReturnTypeTest.java
@@ -79,6 +79,11 @@
         .addKeepMainRule(MAIN)
         .setMinApi(parameters.getRuntime())
         .addOptionsModification(o -> {
+          // No actual implementation of B112517039I, rather invoked with `null`.
+          // Call site optimization propagation will conclude that the input of B...Caller#call is
+          // always null, and replace the last call with null-throwing instruction.
+          // However, we want to test return type and parameter type are kept in this scenario.
+          o.enableCallSiteOptimizationInfoPropagation = false;
           o.enableInlining = false;
         })
         .run(parameters.getRuntime(), MAIN)
diff --git a/src/test/java/com/android/tools/r8/shaking/examples/TreeShaking18Test.java b/src/test/java/com/android/tools/r8/shaking/examples/TreeShaking18Test.java
index fa76296..f2a924d 100644
--- a/src/test/java/com/android/tools/r8/shaking/examples/TreeShaking18Test.java
+++ b/src/test/java/com/android/tools/r8/shaking/examples/TreeShaking18Test.java
@@ -3,13 +3,15 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.shaking.examples;
 
+import static org.junit.Assert.assertFalse;
+
 import com.android.tools.r8.shaking.TreeShakingTest;
+import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
 import com.google.common.collect.ImmutableList;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
-import org.junit.Assert;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -39,12 +41,12 @@
         TreeShaking18Test::unusedRemoved,
         null,
         null,
-        ImmutableList.of("src/test/examples/shaking18/keep-rules.txt"));
+        ImmutableList.of("src/test/examples/shaking18/keep-rules.txt"),
+        InternalOptions::enableCallSiteOptimizationInfoPropagation);
   }
 
   private static void unusedRemoved(CodeInspector inspector) {
-    // TODO(b/80455722): Change to assertFalse when tree-shaking detects this case.
-    Assert.assertTrue(
+    assertFalse(
         "DerivedUnused should be removed", inspector.clazz("shaking18.DerivedUnused").isPresent());
   }
 }