[Metadata] Model context receiver types in metadata

Bug: b/266396725
Change-Id: I66e1fdb5dcec6aac2d9620138c2333b921b4d2c8
diff --git a/src/main/java/com/android/tools/r8/kotlin/KotlinClassInfo.java b/src/main/java/com/android/tools/r8/kotlin/KotlinClassInfo.java
index 55314eb..81d408f 100644
--- a/src/main/java/com/android/tools/r8/kotlin/KotlinClassInfo.java
+++ b/src/main/java/com/android/tools/r8/kotlin/KotlinClassInfo.java
@@ -19,6 +19,7 @@
 import com.android.tools.r8.graph.DexString;
 import com.android.tools.r8.utils.Box;
 import com.android.tools.r8.utils.DescriptorUtils;
+import com.android.tools.r8.utils.ListUtils;
 import com.android.tools.r8.utils.Pair;
 import com.android.tools.r8.utils.Reporter;
 import com.google.common.collect.ImmutableList;
@@ -59,6 +60,8 @@
   private final KotlinTypeInfo inlineClassUnderlyingType;
   private final int jvmFlags;
   private final String companionObjectName;
+  // Collection of context receiver types
+  private final List<KotlinTypeInfo> contextReceiverTypes;
 
   // List of tracked assignments of kotlin metadata.
   private final KotlinMetadataMembersTracker originalMembersWithKotlinInfo;
@@ -84,7 +87,8 @@
       KotlinTypeInfo inlineClassUnderlyingType,
       KotlinMetadataMembersTracker originalMembersWithKotlinInfo,
       int jvmFlags,
-      String companionObjectName) {
+      String companionObjectName,
+      List<KotlinTypeInfo> contextReceiverTypes) {
     this.flags = flags;
     this.name = name;
     this.nameCanBeSynthesizedFromClassOrAnonymousObjectOrigin =
@@ -107,6 +111,7 @@
     this.originalMembersWithKotlinInfo = originalMembersWithKotlinInfo;
     this.jvmFlags = jvmFlags;
     this.companionObjectName = companionObjectName;
+    this.contextReceiverTypes = contextReceiverTypes;
   }
 
   public static KotlinClassInfo create(
@@ -190,7 +195,10 @@
         KotlinTypeInfo.create(kmClass.getInlineClassUnderlyingType(), factory, reporter),
         originalMembersWithKotlinInfo,
         JvmExtensionsKt.getJvmFlags(kmClass),
-        setCompanionObject(kmClass, hostClass, reporter));
+        setCompanionObject(kmClass, hostClass, reporter),
+        ListUtils.map(
+            kmClass.getContextReceiverTypes(),
+            contextRecieverType -> KotlinTypeInfo.create(contextRecieverType, factory, reporter)));
   }
 
   private static KotlinTypeReference getAnonymousObjectOrigin(
@@ -416,6 +424,9 @@
       rewritten |=
           inlineClassUnderlyingType.rewrite(kmClass::visitInlineClassUnderlyingType, appView);
     }
+    for (KotlinTypeInfo contextReceiverType : contextReceiverTypes) {
+      rewritten |= contextReceiverType.rewrite(kmClass::visitContextReceiverType, appView);
+    }
     JvmClassExtensionVisitor extensionVisitor =
         (JvmClassExtensionVisitor) kmClass.visitExtensions(JvmClassExtensionVisitor.TYPE);
     extensionVisitor.visitJvmFlags(jvmFlags);
@@ -457,6 +468,7 @@
     forEachApply(superTypes, type -> type::trace, definitionSupplier);
     forEachApply(sealedSubClasses, sealedClass -> sealedClass::trace, definitionSupplier);
     forEachApply(nestedClasses, nested -> nested::trace, definitionSupplier);
+    forEachApply(contextReceiverTypes, nested -> nested::trace, definitionSupplier);
     localDelegatedProperties.trace(definitionSupplier);
     // TODO(b/154347404): trace enum entries.
     if (anonymousObjectOrigin != null) {
diff --git a/src/main/java/com/android/tools/r8/kotlin/KotlinFunctionInfo.java b/src/main/java/com/android/tools/r8/kotlin/KotlinFunctionInfo.java
index f5b40c2..69e4004 100644
--- a/src/main/java/com/android/tools/r8/kotlin/KotlinFunctionInfo.java
+++ b/src/main/java/com/android/tools/r8/kotlin/KotlinFunctionInfo.java
@@ -10,6 +10,7 @@
 import com.android.tools.r8.graph.DexDefinitionSupplier;
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.utils.ListUtils;
 import com.android.tools.r8.utils.Reporter;
 import java.util.List;
 import kotlinx.metadata.KmFunction;
@@ -41,6 +42,8 @@
   private final KotlinContractInfo contract;
   // A value describing if any of the parameters are crossinline.
   private final boolean crossInlineParameter;
+  // Collection of context receiver types
+  private final List<KotlinTypeInfo> contextReceiverTypes;
 
   private KotlinFunctionInfo(
       int flags,
@@ -53,7 +56,8 @@
       KotlinTypeReference lambdaClassOrigin,
       KotlinVersionRequirementInfo versionRequirements,
       KotlinContractInfo contract,
-      boolean crossInlineParameter) {
+      boolean crossInlineParameter,
+      List<KotlinTypeInfo> contextReceiverTypes) {
     this.flags = flags;
     this.name = name;
     this.returnType = returnType;
@@ -65,6 +69,7 @@
     this.versionRequirements = versionRequirements;
     this.contract = contract;
     this.crossInlineParameter = crossInlineParameter;
+    this.contextReceiverTypes = contextReceiverTypes;
   }
 
   public boolean hasCrossInlineParameter() {
@@ -98,7 +103,10 @@
         getlambdaClassOrigin(kmFunction, factory),
         KotlinVersionRequirementInfo.create(kmFunction.getVersionRequirements()),
         KotlinContractInfo.create(kmFunction.getContract(), factory, reporter),
-        isCrossInline);
+        isCrossInline,
+        ListUtils.map(
+            kmFunction.getContextReceiverTypes(),
+            contextRecieverType -> KotlinTypeInfo.create(contextRecieverType, factory, reporter)));
   }
 
   private static KotlinTypeReference getlambdaClassOrigin(
@@ -141,6 +149,9 @@
     for (KotlinTypeParameterInfo typeParameterInfo : typeParameters) {
       rewritten |= typeParameterInfo.rewrite(kmFunction::visitTypeParameter, appView);
     }
+    for (KotlinTypeInfo contextReceiverType : contextReceiverTypes) {
+      rewritten |= contextReceiverType.rewrite(kmFunction::visitContextReceiverType, appView);
+    }
     if (receiverParameterType != null) {
       rewritten |= receiverParameterType.rewrite(kmFunction::visitReceiverParameterType, appView);
     }
@@ -187,6 +198,7 @@
       receiverParameterType.trace(definitionSupplier);
     }
     forEachApply(typeParameters, param -> param::trace, definitionSupplier);
+    forEachApply(contextReceiverTypes, type -> type::trace, definitionSupplier);
     if (signature != null) {
       signature.trace(definitionSupplier);
     }
diff --git a/src/main/java/com/android/tools/r8/kotlin/KotlinMetadataWriter.java b/src/main/java/com/android/tools/r8/kotlin/KotlinMetadataWriter.java
index 1e4d633..dfb5d76 100644
--- a/src/main/java/com/android/tools/r8/kotlin/KotlinMetadataWriter.java
+++ b/src/main/java/com/android/tools/r8/kotlin/KotlinMetadataWriter.java
@@ -379,20 +379,31 @@
         indent,
         "constructors",
         sb,
-        newIndent -> {
-          appendKmList(
-              newIndent,
-              "KmConstructor",
-              sb,
-              kmClass.getConstructors().stream()
-                  .sorted(
-                      Comparator.comparing(
-                          kmConstructor -> JvmExtensionsKt.getSignature(kmConstructor).asString()))
-                  .collect(Collectors.toList()),
-              (nextIndent, constructor) -> {
-                appendKmConstructor(nextIndent, sb, constructor);
-              });
-        });
+        newIndent ->
+            appendKmList(
+                newIndent,
+                "KmConstructor",
+                sb,
+                kmClass.getConstructors().stream()
+                    .sorted(
+                        Comparator.comparing(
+                            kmConstructor ->
+                                JvmExtensionsKt.getSignature(kmConstructor).asString()))
+                    .collect(Collectors.toList()),
+                (nextIndent, constructor) -> {
+                  appendKmConstructor(nextIndent, sb, constructor);
+                }));
+    appendKeyValue(
+        indent,
+        "contextReceiverTypes",
+        sb,
+        newIndent ->
+            appendKmList(
+                newIndent,
+                "KmType",
+                sb,
+                kmClass.getContextReceiverTypes(),
+                (nextIndent, kmType) -> appendKmType(nextIndent, sb, kmType)));
     appendKmDeclarationContainer(indent, sb, kmClass);
   }
 
@@ -458,6 +469,17 @@
                   appendKmContract(nextIndent, sb, contract);
                 });
           }
+          appendKeyValue(
+              newIndent,
+              "contextReceiverTypes",
+              sb,
+              newNewIndent ->
+                  appendKmList(
+                      newNewIndent,
+                      "KmType",
+                      sb,
+                      function.getContextReceiverTypes(),
+                      (nextIndent, kmType) -> appendKmType(nextIndent, sb, kmType)));
           JvmMethodSignature signature = JvmExtensionsKt.getSignature(function);
           appendKeyValue(
               newIndent, "signature", sb, signature != null ? signature.asString() : "null");
@@ -500,6 +522,17 @@
               sb,
               nextIndent -> appendValueParameter(nextIndent, sb, kmProperty.getSetterParameter()));
           appendKmVersionRequirement(newIndent, sb, kmProperty.getVersionRequirements());
+          appendKeyValue(
+              newIndent,
+              "contextReceiverTypes",
+              sb,
+              newNewIndent ->
+                  appendKmList(
+                      newNewIndent,
+                      "KmType",
+                      sb,
+                      kmProperty.getContextReceiverTypes(),
+                      (nextIndent, kmType) -> appendKmType(nextIndent, sb, kmType)));
           appendKeyValue(newIndent, "jvmFlags", sb, JvmExtensionsKt.getJvmFlags(kmProperty) + "");
           JvmFieldSignature fieldSignature = JvmExtensionsKt.getFieldSignature(kmProperty);
           appendKeyValue(
diff --git a/src/main/java/com/android/tools/r8/kotlin/KotlinPropertyInfo.java b/src/main/java/com/android/tools/r8/kotlin/KotlinPropertyInfo.java
index 6d10139..a11b9a9 100644
--- a/src/main/java/com/android/tools/r8/kotlin/KotlinPropertyInfo.java
+++ b/src/main/java/com/android/tools/r8/kotlin/KotlinPropertyInfo.java
@@ -12,6 +12,7 @@
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.utils.Box;
+import com.android.tools.r8.utils.ListUtils;
 import com.android.tools.r8.utils.Reporter;
 import java.util.List;
 import kotlinx.metadata.KmProperty;
@@ -58,6 +59,8 @@
   private final KotlinJvmMethodSignatureInfo syntheticMethodForAnnotations;
 
   private final KotlinJvmMethodSignatureInfo syntheticMethodForDelegate;
+  // Collection of context receiver types
+  private final List<KotlinTypeInfo> contextReceiverTypes;
 
   private KotlinPropertyInfo(
       int flags,
@@ -74,7 +77,8 @@
       KotlinJvmMethodSignatureInfo getterSignature,
       KotlinJvmMethodSignatureInfo setterSignature,
       KotlinJvmMethodSignatureInfo syntheticMethodForAnnotations,
-      KotlinJvmMethodSignatureInfo syntheticMethodForDelegate) {
+      KotlinJvmMethodSignatureInfo syntheticMethodForDelegate,
+      List<KotlinTypeInfo> contextReceiverTypes) {
     this.flags = flags;
     this.getterFlags = getterFlags;
     this.setterFlags = setterFlags;
@@ -90,6 +94,7 @@
     this.setterSignature = setterSignature;
     this.syntheticMethodForAnnotations = syntheticMethodForAnnotations;
     this.syntheticMethodForDelegate = syntheticMethodForDelegate;
+    this.contextReceiverTypes = contextReceiverTypes;
   }
 
   public static KotlinPropertyInfo create(
@@ -113,7 +118,10 @@
         KotlinJvmMethodSignatureInfo.create(
             JvmExtensionsKt.getSyntheticMethodForAnnotations(kmProperty), factory),
         KotlinJvmMethodSignatureInfo.create(
-            JvmExtensionsKt.getSyntheticMethodForDelegate(kmProperty), factory));
+            JvmExtensionsKt.getSyntheticMethodForDelegate(kmProperty), factory),
+        ListUtils.map(
+            kmProperty.getContextReceiverTypes(),
+            contextRecieverType -> KotlinTypeInfo.create(contextRecieverType, factory, reporter)));
   }
 
   @Override
@@ -160,6 +168,9 @@
     for (KotlinTypeParameterInfo typeParameter : typeParameters) {
       rewritten |= typeParameter.rewrite(kmProperty::visitTypeParameter, appView);
     }
+    for (KotlinTypeInfo contextReceiverType : contextReceiverTypes) {
+      rewritten |= contextReceiverType.rewrite(kmProperty::visitContextReceiverType, appView);
+    }
     rewritten |= versionRequirements.rewrite(kmProperty::visitVersionRequirement);
     JvmPropertyExtensionVisitor extensionVisitor =
         (JvmPropertyExtensionVisitor) kmProperty.visitExtensions(JvmPropertyExtensionVisitor.TYPE);
@@ -207,6 +218,7 @@
       setterParameter.trace(definitionSupplier);
     }
     forEachApply(typeParameters, param -> param::trace, definitionSupplier);
+    forEachApply(contextReceiverTypes, type -> type::trace, definitionSupplier);
     if (fieldSignature != null) {
       fieldSignature.trace(definitionSupplier);
     }
diff --git a/src/test/java/com/android/tools/r8/kotlin/metadata/MetadataRewriteContextReceiverTest.java b/src/test/java/com/android/tools/r8/kotlin/metadata/MetadataRewriteContextReceiverTest.java
index 206b7c9..102a056 100644
--- a/src/test/java/com/android/tools/r8/kotlin/metadata/MetadataRewriteContextReceiverTest.java
+++ b/src/test/java/com/android/tools/r8/kotlin/metadata/MetadataRewriteContextReceiverTest.java
@@ -109,7 +109,13 @@
             .addOptionsModification(
                 options -> options.testing.keepMetadataInR8IfNotRewritten = false)
             .compile()
-            // TODO(b/266396725): Assert equivalence of metadata.
+            // Since this has a keep-all classes rule assert that the meta-data is equal to the
+            // original one.
+            .inspect(
+                inspector ->
+                    assertEqualDeserializedMetadata(
+                        inspector,
+                        new CodeInspector(libJars.getForConfiguration(kotlinc, targetVersion))))
             .writeToZip();
     Path main =
         kotlinc(parameters.getRuntime().asCf(), kotlinc, targetVersion)
@@ -122,8 +128,7 @@
         .addRunClasspathFiles(kotlinc.getKotlinStdlibJar(), kotlinc.getKotlinReflectJar(), libJar)
         .addClasspath(main)
         .run(parameters.getRuntime(), MAIN)
-        // TODO(b/266396725): Rewrite context receivers in metadata.
-        .assertFailureWithErrorThatThrows(NoSuchMethodError.class);
+        .assertSuccessWithOutput(EXPECTED);
   }
 
   @Test
@@ -183,7 +188,6 @@
         .addRunClasspathFiles(kotlinc.getKotlinStdlibJar(), libJar)
         .addClasspath(output)
         .run(parameters.getRuntime(), MAIN)
-        // TODO(b/266396725): Rewrite context receivers in metadata.
-        .assertFailureWithErrorThatThrows(NoSuchMethodError.class);
+        .assertSuccessWithOutput(EXPECTED);
   }
 }