VerticalClassMerger: use invocationContext in lookup

Bug: 147578480
Change-Id: I312f2db87f8e235472a0ee933a5ed1cf9f05a805
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/InliningConstraints.java b/src/main/java/com/android/tools/r8/ir/optimize/InliningConstraints.java
index fd4b0ca..918231c 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/InliningConstraints.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/InliningConstraints.java
@@ -11,15 +11,18 @@
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexReference;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.GraphLense;
 import com.android.tools.r8.graph.ResolutionResult;
+import com.android.tools.r8.graph.ResolutionResult.SingleResolutionResult;
 import com.android.tools.r8.ir.code.Invoke.Type;
 import com.android.tools.r8.ir.optimize.Inliner.Constraint;
 import com.android.tools.r8.ir.optimize.Inliner.ConstraintWithTarget;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import java.util.Collection;
+import java.util.function.BiFunction;
 
 // Computes the inlining constraint for a given instruction.
 public class InliningConstraints {
@@ -51,6 +54,10 @@
     allowStaticInterfaceMethodCalls = false;
   }
 
+  private boolean isVerticalClassMerging() {
+    return !graphLense.isIdentityLense();
+  }
+
   public ConstraintWithTarget forAlwaysMaterializingUser() {
     return ConstraintWithTarget.ALWAYS;
   }
@@ -154,10 +161,47 @@
     return ConstraintWithTarget.NEVER;
   }
 
+  private DexEncodedMethod lookupWhileVerticalClassMerging(
+      DexMethod method,
+      DexType invocationContext,
+      BiFunction<SingleResolutionResult, DexProgramClass, DexEncodedMethod> lookupFunction) {
+    SingleResolutionResult singleResolutionResult =
+        appView.appInfo().resolveMethod(method.holder, method).asSingleResolution();
+    if (singleResolutionResult == null) {
+      return null;
+    }
+    DexProgramClass context = appView.definitionForProgramType(invocationContext);
+    if (context == null) {
+      return null;
+    }
+    DexEncodedMethod dexEncodedMethod = lookupFunction.apply(singleResolutionResult, context);
+    if (dexEncodedMethod != null) {
+      return dexEncodedMethod;
+    }
+    assert graphLense.lookupType(context.superType) == context.type;
+    DexProgramClass superContext = appView.definitionForProgramType(context.superType);
+    if (superContext == null) {
+      return null;
+    }
+    DexEncodedMethod alternativeDexEncodedMethod =
+        lookupFunction.apply(singleResolutionResult, superContext);
+    if (alternativeDexEncodedMethod != null
+        && alternativeDexEncodedMethod.method.holder == superContext.type) {
+      return alternativeDexEncodedMethod;
+    }
+    return null;
+  }
+
   public ConstraintWithTarget forInvokeDirect(DexMethod method, DexType invocationContext) {
     DexMethod lookup = graphLense.lookupMethod(method);
-    return forSingleTargetInvoke(
-        lookup, appView.appInfo().lookupDirectTarget(lookup), invocationContext);
+    DexEncodedMethod target =
+        isVerticalClassMerging()
+            ? lookupWhileVerticalClassMerging(
+                lookup,
+                invocationContext,
+                (res, ctxt) -> res.lookupInvokeDirectTarget(ctxt, appView.appInfo()))
+            : appView.appInfo().lookupDirectTarget(lookup, invocationContext);
+    return forSingleTargetInvoke(lookup, target, invocationContext);
   }
 
   public ConstraintWithTarget forInvokeInterface(DexMethod method, DexType invocationContext) {
@@ -182,8 +226,14 @@
 
   public ConstraintWithTarget forInvokeStatic(DexMethod method, DexType invocationContext) {
     DexMethod lookup = graphLense.lookupMethod(method);
-    return forSingleTargetInvoke(
-        lookup, appView.appInfo().lookupStaticTarget(lookup, invocationContext), invocationContext);
+    DexEncodedMethod target =
+        isVerticalClassMerging()
+            ? lookupWhileVerticalClassMerging(
+                lookup,
+                invocationContext,
+                (res, ctxt) -> res.lookupInvokeStaticTarget(ctxt, appView.appInfo()))
+            : appView.appInfo().lookupStaticTarget(lookup, invocationContext);
+    return forSingleTargetInvoke(lookup, target, invocationContext);
   }
 
   public ConstraintWithTarget forInvokeSuper(DexMethod method, DexType invocationContext) {