diff --git a/src/main/java/com/android/tools/r8/graph/GraphLens.java b/src/main/java/com/android/tools/r8/graph/GraphLens.java
index 0c946dd..7bfa6cb 100644
--- a/src/main/java/com/android/tools/r8/graph/GraphLens.java
+++ b/src/main/java/com/android/tools/r8/graph/GraphLens.java
@@ -8,10 +8,12 @@
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.horizontalclassmerging.ClassMerger;
 import com.android.tools.r8.ir.code.Invoke.Type;
+import com.android.tools.r8.ir.conversion.LensCodeRewriterUtils;
 import com.android.tools.r8.ir.desugar.InterfaceProcessor.InterfaceProcessorNestedGraphLens;
 import com.android.tools.r8.shaking.KeepInfoCollection;
 import com.android.tools.r8.utils.Action;
 import com.android.tools.r8.utils.SetUtils;
+import com.android.tools.r8.utils.collections.ProgramMethodSet;
 import com.google.common.collect.BiMap;
 import com.google.common.collect.HashBiMap;
 import com.google.common.collect.ImmutableMap;
@@ -465,6 +467,22 @@
     return true;
   }
 
+  public Map<DexCallSite, ProgramMethodSet> rewriteCallSites(
+      Map<DexCallSite, ProgramMethodSet> callSites, DexDefinitionSupplier definitions) {
+    Map<DexCallSite, ProgramMethodSet> result = new IdentityHashMap<>();
+    LensCodeRewriterUtils rewriter = new LensCodeRewriterUtils(definitions, this);
+    callSites.forEach(
+        (callSite, contexts) -> {
+          for (ProgramMethod context : contexts.rewrittenWithLens(definitions, this)) {
+            DexCallSite rewrittenCallSite = rewriter.rewriteCallSite(callSite, context);
+            result
+                .computeIfAbsent(rewrittenCallSite, ignore -> ProgramMethodSet.create())
+                .add(context);
+          }
+        });
+    return result;
+  }
+
   public DexReference rewriteReference(DexReference reference) {
     if (reference.isDexField()) {
       return getRenamedFieldSignature(reference.asDexField());
diff --git a/src/main/java/com/android/tools/r8/naming/InterfaceMethodNameMinifier.java b/src/main/java/com/android/tools/r8/naming/InterfaceMethodNameMinifier.java
index 060f171..448a8fa 100644
--- a/src/main/java/com/android/tools/r8/naming/InterfaceMethodNameMinifier.java
+++ b/src/main/java/com/android/tools/r8/naming/InterfaceMethodNameMinifier.java
@@ -441,7 +441,7 @@
     // desugared lambdas this is a conservative estimate, as we don't track if the generated
     // lambda classes survive into the output. As multi-interface lambda expressions are rare
     // this is not a big deal.
-    Set<DexCallSite> liveCallSites = appView.appInfo().callSites;
+    Set<DexCallSite> liveCallSites = appView.appInfo().callSites.keySet();
     // If the input program contains a multi-interface lambda expression that implements
     // interface methods with different protos, we need to make sure tha the implemented lambda
     // methods are renamed to the same name.
diff --git a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
index 92722cd..311e10d 100644
--- a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
+++ b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
@@ -127,7 +127,7 @@
    * Set of live call sites in the code. Note that if desugaring has taken place call site objects
    * will have been removed from the code.
    */
-  public final Set<DexCallSite> callSites;
+  public final Map<DexCallSite, ProgramMethodSet> callSites;
   /** Collection of keep requirements for the program. */
   private final KeepInfoCollection keepInfo;
   /** All items with assumemayhavesideeffects rule. */
@@ -207,7 +207,7 @@
       FieldAccessInfoCollectionImpl fieldAccessInfoCollection,
       MethodAccessInfoCollection methodAccessInfoCollection,
       ObjectAllocationInfoCollectionImpl objectAllocationInfoCollection,
-      Set<DexCallSite> callSites,
+      Map<DexCallSite, ProgramMethodSet> callSites,
       KeepInfoCollection keepInfo,
       Map<DexReference, ProguardMemberRule> mayHaveSideEffects,
       Map<DexReference, ProguardMemberRule> noSideEffects,
@@ -288,12 +288,7 @@
       FieldAccessInfoCollectionImpl fieldAccessInfoCollection,
       MethodAccessInfoCollection methodAccessInfoCollection,
       ObjectAllocationInfoCollectionImpl objectAllocationInfoCollection,
-      SortedMap<DexMethod, ProgramMethodSet> virtualInvokes,
-      SortedMap<DexMethod, ProgramMethodSet> interfaceInvokes,
-      SortedMap<DexMethod, ProgramMethodSet> superInvokes,
-      SortedMap<DexMethod, ProgramMethodSet> directInvokes,
-      SortedMap<DexMethod, ProgramMethodSet> staticInvokes,
-      Set<DexCallSite> callSites,
+      Map<DexCallSite, ProgramMethodSet> callSites,
       KeepInfoCollection keepInfo,
       Map<DexReference, ProguardMemberRule> mayHaveSideEffects,
       Map<DexReference, ProguardMemberRule> noSideEffects,
@@ -623,7 +618,7 @@
     for (DexProgramClass clazz : classes()) {
       worklist.add(clazz.type);
     }
-    for (DexCallSite callSite : callSites) {
+    for (DexCallSite callSite : callSites.keySet()) {
       for (DexEncodedMethod method : lookupLambdaImplementedMethods(callSite)) {
         worklist.add(method.holder());
       }
@@ -998,9 +993,7 @@
         fieldAccessInfoCollection.rewrittenWithLens(definitionSupplier, lens),
         methodAccessInfoCollection.rewrittenWithLens(definitionSupplier, lens),
         objectAllocationInfoCollection.rewrittenWithLens(definitionSupplier, lens),
-        // TODO(sgjesse): Rewrite call sites as well? Right now they are only used by minification
-        //   after second tree shaking.
-        callSites,
+        lens.rewriteCallSites(callSites, definitionSupplier),
         keepInfo.rewrite(lens, application.options),
         lens.rewriteReferenceKeys(mayHaveSideEffects),
         lens.rewriteReferenceKeys(noSideEffects),
@@ -1261,7 +1254,7 @@
   static void forEachTypeInHierarchyOfLiveProgramClasses(
       Consumer<DexClass> fn,
       Collection<DexProgramClass> liveProgramClasses,
-      Set<DexCallSite> callSites,
+      Map<DexCallSite, ProgramMethodSet> callSites,
       AppInfoWithClassHierarchy appInfo) {
     Set<DexType> seen = Sets.newIdentityHashSet();
     Deque<DexType> worklist = new ArrayDeque<>();
@@ -1278,7 +1271,7 @@
         }
       }
     }
-    for (DexCallSite callSite : callSites) {
+    for (DexCallSite callSite : callSites.keySet()) {
       List<DexType> interfaces = LambdaDescriptor.getInterfaces(callSite, appInfo);
       if (interfaces != null) {
         for (DexType iface : interfaces) {
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index fad27eb..0afccd0 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -206,7 +206,7 @@
   private final MethodAccessInfoCollection.SortedBuilder methodAccessInfoCollection =
       MethodAccessInfoCollection.sortedBuilder();
   private final ObjectAllocationInfoCollectionImpl.Builder objectAllocationInfoCollection;
-  private final Set<DexCallSite> callSites = Sets.newIdentityHashSet();
+  private final Map<DexCallSite, ProgramMethodSet> callSites = new IdentityHashMap<>();
 
   private final Set<DexReference> identifierNameStrings = Sets.newIdentityHashSet();
 
@@ -881,7 +881,7 @@
     } else {
       markLambdaAsInstantiated(descriptor, context);
       transitionMethodsForInstantiatedLambda(descriptor);
-      callSites.add(callSite);
+      callSites.computeIfAbsent(callSite, ignore -> ProgramMethodSet.create()).add(context);
     }
 
     // For call sites representing a lambda, we link the targeted method
