LambdaRewriter: Should use invocationContext

Bug: 147578480
Change-Id: I5d9173ec88ad092338978350dfd8e085a4050a34
diff --git a/src/main/java/com/android/tools/r8/graph/AppView.java b/src/main/java/com/android/tools/r8/graph/AppView.java
index 1581a5e..1388051 100644
--- a/src/main/java/com/android/tools/r8/graph/AppView.java
+++ b/src/main/java/com/android/tools/r8/graph/AppView.java
@@ -351,6 +351,13 @@
   }
 
   @SuppressWarnings("unchecked")
+  public AppView<AppInfoWithClassHierarchy> withClassHierarchy() {
+    return appInfo.hasClassHierarchy()
+        ? (AppView<AppInfoWithClassHierarchy>) this
+        : null;
+  }
+
+  @SuppressWarnings("unchecked")
   public AppView<AppInfoWithSubtyping> withSubtyping() {
     return appInfo.hasSubtyping()
         ? (AppView<AppInfoWithSubtyping>) this
diff --git a/src/main/java/com/android/tools/r8/ir/code/InvokeCustom.java b/src/main/java/com/android/tools/r8/ir/code/InvokeCustom.java
index 6f79a82..2c7dcfb 100644
--- a/src/main/java/com/android/tools/r8/ir/code/InvokeCustom.java
+++ b/src/main/java/com/android/tools/r8/ir/code/InvokeCustom.java
@@ -67,8 +67,8 @@
     if (!appView.appInfo().hasSubtyping()) {
       return returnTypeLattice;
     }
-
-    List<DexType> lambdaInterfaces = LambdaDescriptor.getInterfaces(callSite, appView.appInfo());
+    List<DexType> lambdaInterfaces =
+        LambdaDescriptor.getInterfaces(callSite, appView.appInfo().withSubtyping());
     if (lambdaInterfaces == null || lambdaInterfaces.isEmpty()) {
       return returnTypeLattice;
     }
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/LambdaDescriptor.java b/src/main/java/com/android/tools/r8/ir/desugar/LambdaDescriptor.java
index ead20ac..2eb7b03 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/LambdaDescriptor.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/LambdaDescriptor.java
@@ -5,12 +5,13 @@
 package com.android.tools.r8.ir.desugar;
 
 import com.android.tools.r8.errors.Unreachable;
-import com.android.tools.r8.graph.AppInfo;
+import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
 import com.android.tools.r8.graph.DexCallSite;
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexMethodHandle;
+import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexProto;
 import com.android.tools.r8.graph.DexString;
 import com.android.tools.r8.graph.DexType;
@@ -58,9 +59,16 @@
     targetHolder = null;
   }
 
-  private LambdaDescriptor(AppInfo appInfo, DexCallSite callSite,
-      DexString name, DexProto erasedProto, DexProto enforcedProto,
-      DexMethodHandle implHandle, DexType mainInterface, DexTypeList captures) {
+  private LambdaDescriptor(
+      AppInfoWithClassHierarchy appInfo,
+      DexType invocationContext,
+      DexCallSite callSite,
+      DexString name,
+      DexProto erasedProto,
+      DexProto enforcedProto,
+      DexMethodHandle implHandle,
+      DexType mainInterface,
+      DexTypeList captures) {
     assert appInfo != null;
     assert callSite != null;
     assert name != null;
@@ -78,8 +86,8 @@
     this.captures = captures;
 
     this.interfaces.add(mainInterface);
-
-    DexEncodedMethod targetMethod = lookupTargetMethod(appInfo);
+    DexEncodedMethod targetMethod =
+        invocationContext == null ? null : lookupTargetMethod(appInfo, invocationContext);
     if (targetMethod != null) {
       targetAccessFlags = targetMethod.accessFlags.copy();
       targetHolder = targetMethod.method.holder;
@@ -98,15 +106,18 @@
     return captures.length > 0 ? captures[0] : params[0];
   }
 
-  private DexEncodedMethod lookupTargetMethod(AppInfo appInfo) {
+  private DexEncodedMethod lookupTargetMethod(
+      AppInfoWithClassHierarchy appInfo, DexType invocationContext) {
+    assert invocationContext != null;
     // Find the lambda's impl-method target.
     DexMethod method = implHandle.asMethod();
     switch (implHandle.type) {
       case INVOKE_DIRECT:
       case INVOKE_INSTANCE: {
-        DexEncodedMethod target = appInfo.lookupVirtualTarget(getImplReceiverType(), method);
+          DexEncodedMethod target =
+              appInfo.resolveMethod(getImplReceiverType(), method).getSingleTarget();
         if (target == null) {
-          target = appInfo.lookupDirectTarget(method);
+            target = appInfo.lookupDirectTarget(method, invocationContext);
         }
         assert target == null
             || (implHandle.type.isInvokeInstance() && isInstanceMethod(target))
@@ -116,19 +127,20 @@
       }
 
       case INVOKE_STATIC: {
-        DexEncodedMethod target = appInfo.lookupStaticTarget(method);
+          DexEncodedMethod target = appInfo.lookupStaticTarget(method, invocationContext);
         assert target == null || target.accessFlags.isStatic();
         return target;
       }
 
       case INVOKE_CONSTRUCTOR: {
-        DexEncodedMethod target = appInfo.lookupDirectTarget(method);
+          DexEncodedMethod target = appInfo.lookupDirectTarget(method, invocationContext);
         assert target == null || target.accessFlags.isConstructor();
         return target;
       }
 
       case INVOKE_INTERFACE: {
-        DexEncodedMethod target = appInfo.lookupVirtualTarget(getImplReceiverType(), method);
+          DexEncodedMethod target =
+              appInfo.resolveMethod(getImplReceiverType(), method).getSingleTarget();
         assert target == null || isInstanceMethod(target);
         return target;
       }
@@ -226,19 +238,21 @@
   }
 
   /**
-   * Matches call site for lambda metafactory invocation pattern and
-   * returns extracted match information, or null if match failed.
+   * Matches call site for lambda metafactory invocation pattern and returns extracted match
+   * information, or null if match failed.
    */
-  public static LambdaDescriptor tryInfer(DexCallSite callSite, AppInfo appInfo) {
-    LambdaDescriptor descriptor = infer(callSite, appInfo);
+  public static LambdaDescriptor tryInfer(
+      DexCallSite callSite, AppInfoWithClassHierarchy appInfo, DexProgramClass invocationContext) {
+    LambdaDescriptor descriptor = infer(callSite, appInfo, invocationContext.type);
     return descriptor == MATCH_FAILED ? null : descriptor;
   }
 
   /**
-   * Matches call site for lambda metafactory invocation pattern and
-   * returns extracted match information, or MATCH_FAILED if match failed.
+   * Matches call site for lambda metafactory invocation pattern and returns extracted match
+   * information, or MATCH_FAILED if match failed.
    */
-  static LambdaDescriptor infer(DexCallSite callSite, AppInfo appInfo) {
+  static LambdaDescriptor infer(
+      DexCallSite callSite, AppInfoWithClassHierarchy appInfo, DexType invocationContext) {
     // We expect bootstrap method to be either `metafactory` or `altMetafactory` method
     // of `java.lang.invoke.LambdaMetafactory` class. Both methods are static.
     if (!callSite.bootstrapMethod.type.isInvokeStatic()) {
@@ -285,9 +299,17 @@
     DexTypeList captures = lambdaFactoryProto.parameters;
 
     // Create a match.
-    LambdaDescriptor match = new LambdaDescriptor(appInfo, callSite,
-        funcMethodName, funcErasedSignature.value, funcEnforcedSignature.value,
-        lambdaImplMethodHandle, mainFuncInterface, captures);
+    LambdaDescriptor match =
+        new LambdaDescriptor(
+            appInfo,
+            invocationContext,
+            callSite,
+            funcMethodName,
+            funcErasedSignature.value,
+            funcEnforcedSignature.value,
+            lambdaImplMethodHandle,
+            mainFuncInterface,
+            captures);
 
     if (bootstrapMethod == factory.metafactoryMethod) {
       if (callSite.bootstrapArgs.size() != 3) {
@@ -350,8 +372,10 @@
     }
   }
 
-  public static List<DexType> getInterfaces(DexCallSite callSite, AppInfo appInfo) {
-    LambdaDescriptor descriptor = infer(callSite, appInfo);
+  public static List<DexType> getInterfaces(
+      DexCallSite callSite, AppInfoWithClassHierarchy appInfo) {
+    // No need for the invocationContext to figure out only the interfaces.
+    LambdaDescriptor descriptor = infer(callSite, appInfo, null);
     if (descriptor == LambdaDescriptor.MATCH_FAILED) {
       return null;
     }
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java b/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java
index 2fd29e1..6ce1f5f 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java
@@ -4,7 +4,7 @@
 
 package com.android.tools.r8.ir.desugar;
 
-import com.android.tools.r8.graph.AppInfo;
+import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.Code;
 import com.android.tools.r8.graph.DefaultUseRegistry;
@@ -63,7 +63,7 @@
   static final String EXPECTED_LAMBDA_METHOD_PREFIX = "lambda$";
   private static final String LAMBDA_INSTANCE_FIELD_NAME = "INSTANCE";
 
-  private final AppView<?> appView;
+  private final AppView<? extends AppInfoWithClassHierarchy> appView;
 
   final DexString instanceFieldName;
 
@@ -87,7 +87,9 @@
   }
 
   public LambdaRewriter(AppView<?> appView) {
-    this.appView = appView;
+    assert appView.appInfo().hasClassHierarchy()
+        : "Lambda desugaring is not available without class hierarchy.";
+    this.appView = appView.withClassHierarchy();
     this.graphLens = new LambdaRewriterGraphLense(appView);
     this.instanceFieldName = getFactory().createString(LAMBDA_INSTANCE_FIELD_NAME);
   }
@@ -96,8 +98,8 @@
     return appView;
   }
 
-  public AppInfo getAppInfo() {
-    return getAppView().appInfo();
+  public AppInfoWithClassHierarchy getAppInfo() {
+    return appView.appInfo();
   }
 
   public DexItemFactory getFactory() {
@@ -214,7 +216,8 @@
           @Override
           public void registerCallSite(DexCallSite callSite) {
             LambdaDescriptor descriptor =
-                inferLambdaDescriptor(lensCodeRewriter.rewriteCallSite(callSite, method));
+                inferLambdaDescriptor(
+                    lensCodeRewriter.rewriteCallSite(callSite, method), method.method.holder);
             if (descriptor != LambdaDescriptor.MATCH_FAILED) {
               consumer.accept(getOrCreateLambdaClass(descriptor, method.method.holder));
             }
@@ -238,7 +241,8 @@
         Instruction instruction = instructions.next();
         if (instruction.isInvokeCustom()) {
           InvokeCustom invoke = instruction.asInvokeCustom();
-          LambdaDescriptor descriptor = inferLambdaDescriptor(invoke.getCallSite());
+          LambdaDescriptor descriptor =
+              inferLambdaDescriptor(invoke.getCallSite(), encodedMethod.method.holder);
           if (descriptor == LambdaDescriptor.MATCH_FAILED) {
             continue;
           }
@@ -333,7 +337,7 @@
   // corresponding to this lambda invocation point.
   //
   // Returns the lambda descriptor or `MATCH_FAILED`.
-  private LambdaDescriptor inferLambdaDescriptor(DexCallSite callSite) {
+  private LambdaDescriptor inferLambdaDescriptor(DexCallSite callSite, DexType invocationContext) {
     // We check the map before and after inferring lambda descriptor to minimize time
     // spent in synchronized block. As a result we may throw away calculated descriptor
     // in rare case when another thread has same call site processed concurrently,
@@ -341,7 +345,10 @@
     LambdaDescriptor descriptor = getKnown(knownCallSites, callSite);
     return descriptor != null
         ? descriptor
-        : putIfAbsent(knownCallSites, callSite, LambdaDescriptor.infer(callSite, getAppInfo()));
+        : putIfAbsent(
+            knownCallSites,
+            callSite,
+            LambdaDescriptor.infer(callSite, getAppInfo(), invocationContext));
   }
 
   private boolean isInMainDexList(DexType type) {
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index 2884fcd..4a8cdf3 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -612,7 +612,7 @@
       bootstrapMethods.add(callSite.bootstrapMethod.asMethod());
     }
 
-    LambdaDescriptor descriptor = LambdaDescriptor.tryInfer(callSite, appInfo);
+    LambdaDescriptor descriptor = LambdaDescriptor.tryInfer(callSite, appInfo, context.holder);
     if (descriptor == null) {
       return;
     }