Merge "Apply graph lense to DexValueMethodType and DexValueType in LensCodeRewriter"
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
index 9a88674..1877053 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
@@ -15,6 +15,8 @@
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.DexValue;
 import com.android.tools.r8.graph.DexValue.DexValueMethodHandle;
+import com.android.tools.r8.graph.DexValue.DexValueMethodType;
+import com.android.tools.r8.graph.DexValue.DexValueType;
 import com.android.tools.r8.graph.GraphLense;
 import com.android.tools.r8.graph.GraphLense.GraphLenseLookupResult;
 import com.android.tools.r8.ir.code.BasicBlock;
@@ -36,11 +38,11 @@
 import com.android.tools.r8.ir.code.StaticGet;
 import com.android.tools.r8.ir.code.StaticPut;
 import com.android.tools.r8.ir.code.Value;
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.ListIterator;
 import java.util.Map;
-import java.util.stream.Collectors;
 
 public class LensCodeRewriter {
 
@@ -80,16 +82,7 @@
                   callSite.methodProto, graphLense::lookupType, protoFixupCache);
           DexMethodHandle newBootstrapMethod =
               rewriteDexMethodHandle(callSite.bootstrapMethod, method);
-          List<DexValue> newArgs = callSite.bootstrapArgs.stream().map(
-              (arg) -> {
-                if (arg instanceof DexValueMethodHandle) {
-                  return new DexValueMethodHandle(
-                      rewriteDexMethodHandle(((DexValueMethodHandle) arg).value, method));
-                }
-                return arg;
-              })
-              .collect(Collectors.toList());
-
+          List<DexValue> newArgs = rewriteBootstrapArgs(callSite.bootstrapArgs, method);
           if (!newMethodProto.equals(callSite.methodProto)
               || newBootstrapMethod != callSite.bootstrapMethod
               || !newArgs.equals(callSite.bootstrapArgs)) {
@@ -225,6 +218,71 @@
     assert code.isConsistentSSA();
   }
 
+  /**
+   * Applies the graph lense to the bootstrap arguments of a call site.
+   *
+   * <p>Method handle, method type and type arguments are rewritten; any other argument is kept
+   * unchanged. The input list is never modified.
+   *
+   * @param bootstrapArgs the bootstrap arguments to rewrite
+   * @param method the method whose code is being rewritten; passed as the context argument of
+   *     {@link #rewriteDexMethodHandle}
+   * @return {@code bootstrapArgs} itself if no argument changed, otherwise a new list holding the
+   *     (possibly rewritten) arguments in their original order
+   */
+  private List<DexValue> rewriteBootstrapArgs(
+      List<DexValue> bootstrapArgs, DexEncodedMethod method) {
+    // Copy lazily at the first changed argument so an unaffected list costs no allocation; a
+    // non-null copy therefore also means "something changed".
+    List<DexValue> newBootstrapArgs = null;
+    for (int i = 0; i < bootstrapArgs.size(); i++) {
+      DexValue argument = bootstrapArgs.get(i);
+      DexValue newArgument = rewriteBootstrapArg(argument, method);
+      if (newBootstrapArgs == null && newArgument != argument) {
+        newBootstrapArgs = new ArrayList<>(bootstrapArgs.size());
+        newBootstrapArgs.addAll(bootstrapArgs.subList(0, i));
+      }
+      if (newBootstrapArgs != null) {
+        newBootstrapArgs.add(newArgument);
+      }
+    }
+    return newBootstrapArgs != null ? newBootstrapArgs : bootstrapArgs;
+  }
+
+  /**
+   * Applies the graph lense to a single bootstrap argument.
+   *
+   * <p>A change is detected by reference comparison of the rewritten method handle, proto or type
+   * against the original one.
+   *
+   * @param argument the bootstrap argument to rewrite
+   * @param method the method whose code is being rewritten
+   * @return {@code argument} itself if it is unaffected by the lense, otherwise a new value
+   *     wrapping the rewritten method handle, proto or type
+   */
+  private DexValue rewriteBootstrapArg(DexValue argument, DexEncodedMethod method) {
+    if (argument instanceof DexValueMethodHandle) {
+      DexMethodHandle oldHandle = ((DexValueMethodHandle) argument).value;
+      DexMethodHandle newHandle = rewriteDexMethodHandle(oldHandle, method);
+      return newHandle != oldHandle ? new DexValueMethodHandle(newHandle) : argument;
+    }
+    if (argument instanceof DexValueMethodType) {
+      DexProto oldProto = ((DexValueMethodType) argument).value;
+      DexProto newProto =
+          appInfo.dexItemFactory.applyClassMappingToProto(
+              oldProto, graphLense::lookupType, protoFixupCache);
+      return newProto != oldProto ? new DexValueMethodType(newProto) : argument;
+    }
+    if (argument instanceof DexValueType) {
+      DexType oldType = ((DexValueType) argument).value;
+      DexType newType = graphLense.lookupType(oldType);
+      return newType != oldType ? new DexValueType(newType) : argument;
+    }
+    // Any other kind of argument is left as is.
+    return argument;
+  }
+
+
   private DexMethodHandle rewriteDexMethodHandle(
       DexMethodHandle methodHandle, DexEncodedMethod context) {
     if (methodHandle.isMethodHandle()) {