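VarHandle desugar: Better handling of polymorphic return types

When the return type at the call site differs from the return type of
the desugared method, pass the expected class (the boxed class for
primitives) as an extra argument. The new get(Object, Class) overload
uses it to convert int and long fields to the requested Long, Float or
Double box, or to throw the desugared WrongMethodTypeException.
Reading a long field through getInt now throws instead of truncating.
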
Bug: b/247076137
Change-Id: If14311abf0e2fe40a28f09dd0168a37d8523b6ba
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/varhandle/VarHandleDesugaring.java b/src/main/java/com/android/tools/r8/ir/desugar/varhandle/VarHandleDesugaring.java
index 380d5d4..b83c526 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/varhandle/VarHandleDesugaring.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/varhandle/VarHandleDesugaring.java
@@ -4,6 +4,7 @@
 package com.android.tools.r8.ir.desugar.varhandle;
 
 import com.android.tools.r8.cf.code.CfCheckCast;
+import com.android.tools.r8.cf.code.CfConstClass;
 import com.android.tools.r8.cf.code.CfInstruction;
 import com.android.tools.r8.cf.code.CfInvoke;
 import com.android.tools.r8.cf.code.CfLoad;
@@ -404,16 +405,22 @@
     }
     assert newParameters.size() == proto.parameters.size();
     DexString name = invoke.getMethod().getName();
-    DexProto newProto =
-        factory.createProto(
-            factory.polymorphicMethods.varHandleCompareAndSetMethodNames.contains(name)
-                ? proto.returnType
-                : objectOrPrimitiveReturnType(proto.returnType),
-            newParameters);
+    DexType returnType =
+        factory.polymorphicMethods.varHandleCompareAndSetMethodNames.contains(name)
+            ? proto.returnType
+            : objectOrPrimitiveReturnType(proto.returnType);
+    if (proto.returnType != returnType) {
+      if (proto.returnType.isPrimitiveType()) {
+        builder.add(new CfConstClass(factory.getBoxedForPrimitiveType(proto.returnType)));
+      } else {
+        builder.add(new CfConstClass(proto.returnType));
+      }
+      newParameters.add(factory.classType);
+    }
+    DexProto newProto = factory.createProto(returnType, newParameters);
     DexMethod newMethod = factory.createMethod(factory.varHandleType, newProto, name);
     builder.add(new CfInvoke(Opcodes.INVOKEVIRTUAL, newMethod, false));
     if (proto.returnType.isPrimitiveType() && !newProto.returnType.isPrimitiveType()) {
-      assert proto.returnType.isPrimitiveType();
       assert newProto.returnType == factory.objectType;
       builder.add(new CfCheckCast(factory.getBoxedForPrimitiveType(proto.returnType)));
       builder.add(
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/varhandle/VarHandleDesugaringMethods.java b/src/main/java/com/android/tools/r8/ir/desugar/varhandle/VarHandleDesugaringMethods.java
index 90c8eb5..9f1ba61 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/varhandle/VarHandleDesugaringMethods.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/varhandle/VarHandleDesugaringMethods.java
@@ -52,6 +52,8 @@
   public static void registerSynthesizedCodeReferences(DexItemFactory factory) {
     factory.createSynthesizedType("Ljava/lang/Byte;");
     factory.createSynthesizedType("Ljava/lang/ClassCastException;");
+    factory.createSynthesizedType("Ljava/lang/Double;");
+    factory.createSynthesizedType("Ljava/lang/Float;");
     factory.createSynthesizedType("Ljava/lang/Integer;");
     factory.createSynthesizedType("Ljava/lang/Long;");
     factory.createSynthesizedType("Ljava/lang/RuntimeException;");
@@ -179,6 +181,14 @@
             builder.getType(),
             factory.createProto(factory.longType, factory.objectType),
             factory.createString("get"));
+    DexMethod getInBox =
+        factory.createMethod(
+            builder.getType(),
+            factory.createProto(
+                factory.objectType,
+                factory.objectType,
+                factory.createType(factory.createString("Ljava/lang/Class;"))),
+            factory.createString("get"));
     builder.setDirectMethods(
         ImmutableList.of(
             DexEncodedMethod.syntheticBuilder()
@@ -296,6 +306,14 @@
                         Constants.ACC_PUBLIC | Constants.ACC_SYNTHETIC, false))
                 .setCode(DesugarVarHandle_getLong(factory, getLong))
                 .disableAndroidApiLevelCheck()
+                .build(),
+            DexEncodedMethod.syntheticBuilder()
+                .setMethod(getInBox)
+                .setAccessFlags(
+                    MethodAccessFlags.fromSharedAccessFlags(
+                        Constants.ACC_PUBLIC | Constants.ACC_SYNTHETIC, false))
+                .setCode(DesugarVarHandle_getInBox(factory, getInBox))
+                .disableAndroidApiLevelCheck()
                 .build()));
   }
 
@@ -1366,6 +1384,294 @@
         ImmutableList.of());
   }
 
+  public static CfCode DesugarVarHandle_getInBox(DexItemFactory factory, DexMethod method) {
+    CfLabel label0 = new CfLabel();
+    CfLabel label1 = new CfLabel();
+    CfLabel label2 = new CfLabel();
+    CfLabel label3 = new CfLabel();
+    CfLabel label4 = new CfLabel();
+    CfLabel label5 = new CfLabel();
+    CfLabel label6 = new CfLabel();
+    CfLabel label7 = new CfLabel();
+    CfLabel label8 = new CfLabel();
+    CfLabel label9 = new CfLabel();
+    CfLabel label10 = new CfLabel();
+    CfLabel label11 = new CfLabel();
+    CfLabel label12 = new CfLabel();
+    CfLabel label13 = new CfLabel();
+    CfLabel label14 = new CfLabel();
+    CfLabel label15 = new CfLabel();
+    CfLabel label16 = new CfLabel();
+    CfLabel label17 = new CfLabel();
+    return new CfCode(
+        method.holder,
+        4,
+        5,
+        ImmutableList.of(
+            label0,
+            new CfLoad(ValueType.OBJECT, 0),
+            new CfInstanceFieldRead(
+                factory.createField(
+                    factory.createType("Ljava/lang/invoke/VarHandle;"),
+                    factory.classType,
+                    factory.createString("type"))),
+            new CfStaticFieldRead(
+                factory.createField(
+                    factory.createType("Ljava/lang/Integer;"),
+                    factory.classType,
+                    factory.createString("TYPE"))),
+            new CfIfCmp(If.Type.NE, ValueType.OBJECT, label9),
+            label1,
+            new CfLoad(ValueType.OBJECT, 0),
+            new CfInstanceFieldRead(
+                factory.createField(
+                    factory.createType("Ljava/lang/invoke/VarHandle;"),
+                    factory.createType("Lsun/misc/Unsafe;"),
+                    factory.createString("U"))),
+            new CfLoad(ValueType.OBJECT, 1),
+            new CfLoad(ValueType.OBJECT, 0),
+            new CfInstanceFieldRead(
+                factory.createField(
+                    factory.createType("Ljava/lang/invoke/VarHandle;"),
+                    factory.longType,
+                    factory.createString("offset"))),
+            new CfInvoke(
+                182,
+                factory.createMethod(
+                    factory.createType("Lsun/misc/Unsafe;"),
+                    factory.createProto(factory.intType, factory.objectType, factory.longType),
+                    factory.createString("getInt")),
+                false),
+            new CfStore(ValueType.INT, 3),
+            label2,
+            new CfLoad(ValueType.OBJECT, 2),
+            new CfConstClass(factory.createType("Ljava/lang/Long;")),
+            new CfIfCmp(If.Type.NE, ValueType.OBJECT, label4),
+            label3,
+            new CfLoad(ValueType.INT, 3),
+            new CfNumberConversion(NumericType.INT, NumericType.LONG),
+            new CfInvoke(
+                184,
+                factory.createMethod(
+                    factory.createType("Ljava/lang/Long;"),
+                    factory.createProto(factory.createType("Ljava/lang/Long;"), factory.longType),
+                    factory.createString("valueOf")),
+                false),
+            new CfReturn(ValueType.OBJECT),
+            label4,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1, 2, 3},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(
+                          factory.createType("Ljava/lang/invoke/VarHandle;")),
+                      FrameType.initializedNonNullReference(factory.objectType),
+                      FrameType.initializedNonNullReference(factory.classType),
+                      FrameType.intType()
+                    })),
+            new CfLoad(ValueType.OBJECT, 2),
+            new CfConstClass(factory.createType("Ljava/lang/Float;")),
+            new CfIfCmp(If.Type.NE, ValueType.OBJECT, label6),
+            label5,
+            new CfLoad(ValueType.INT, 3),
+            new CfNumberConversion(NumericType.INT, NumericType.FLOAT),
+            new CfInvoke(
+                184,
+                factory.createMethod(
+                    factory.createType("Ljava/lang/Float;"),
+                    factory.createProto(factory.createType("Ljava/lang/Float;"), factory.floatType),
+                    factory.createString("valueOf")),
+                false),
+            new CfReturn(ValueType.OBJECT),
+            label6,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1, 2, 3},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(
+                          factory.createType("Ljava/lang/invoke/VarHandle;")),
+                      FrameType.initializedNonNullReference(factory.objectType),
+                      FrameType.initializedNonNullReference(factory.classType),
+                      FrameType.intType()
+                    })),
+            new CfLoad(ValueType.OBJECT, 2),
+            new CfConstClass(factory.createType("Ljava/lang/Double;")),
+            new CfIfCmp(If.Type.NE, ValueType.OBJECT, label8),
+            label7,
+            new CfLoad(ValueType.INT, 3),
+            new CfNumberConversion(NumericType.INT, NumericType.DOUBLE),
+            new CfInvoke(
+                184,
+                factory.createMethod(
+                    factory.createType("Ljava/lang/Double;"),
+                    factory.createProto(
+                        factory.createType("Ljava/lang/Double;"), factory.doubleType),
+                    factory.createString("valueOf")),
+                false),
+            new CfReturn(ValueType.OBJECT),
+            label8,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1, 2, 3},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(
+                          factory.createType("Ljava/lang/invoke/VarHandle;")),
+                      FrameType.initializedNonNullReference(factory.objectType),
+                      FrameType.initializedNonNullReference(factory.classType),
+                      FrameType.intType()
+                    })),
+            new CfLoad(ValueType.OBJECT, 0),
+            new CfInvoke(
+                182,
+                factory.createMethod(
+                    factory.createType("Ljava/lang/invoke/VarHandle;"),
+                    factory.createProto(factory.createType("Ljava/lang/RuntimeException;")),
+                    factory.createString("desugarWrongMethodTypeException")),
+                false),
+            new CfThrow(),
+            label9,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1, 2},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(
+                          factory.createType("Ljava/lang/invoke/VarHandle;")),
+                      FrameType.initializedNonNullReference(factory.objectType),
+                      FrameType.initializedNonNullReference(factory.classType)
+                    })),
+            new CfLoad(ValueType.OBJECT, 0),
+            new CfInstanceFieldRead(
+                factory.createField(
+                    factory.createType("Ljava/lang/invoke/VarHandle;"),
+                    factory.classType,
+                    factory.createString("type"))),
+            new CfStaticFieldRead(
+                factory.createField(
+                    factory.createType("Ljava/lang/Long;"),
+                    factory.classType,
+                    factory.createString("TYPE"))),
+            new CfIfCmp(If.Type.NE, ValueType.OBJECT, label16),
+            label10,
+            new CfLoad(ValueType.OBJECT, 0),
+            new CfInstanceFieldRead(
+                factory.createField(
+                    factory.createType("Ljava/lang/invoke/VarHandle;"),
+                    factory.createType("Lsun/misc/Unsafe;"),
+                    factory.createString("U"))),
+            new CfLoad(ValueType.OBJECT, 1),
+            new CfLoad(ValueType.OBJECT, 0),
+            new CfInstanceFieldRead(
+                factory.createField(
+                    factory.createType("Ljava/lang/invoke/VarHandle;"),
+                    factory.longType,
+                    factory.createString("offset"))),
+            new CfInvoke(
+                182,
+                factory.createMethod(
+                    factory.createType("Lsun/misc/Unsafe;"),
+                    factory.createProto(factory.longType, factory.objectType, factory.longType),
+                    factory.createString("getLong")),
+                false),
+            new CfStore(ValueType.LONG, 3),
+            label11,
+            new CfLoad(ValueType.OBJECT, 2),
+            new CfConstClass(factory.createType("Ljava/lang/Float;")),
+            new CfIfCmp(If.Type.NE, ValueType.OBJECT, label13),
+            label12,
+            new CfLoad(ValueType.LONG, 3),
+            new CfNumberConversion(NumericType.LONG, NumericType.FLOAT),
+            new CfInvoke(
+                184,
+                factory.createMethod(
+                    factory.createType("Ljava/lang/Float;"),
+                    factory.createProto(factory.createType("Ljava/lang/Float;"), factory.floatType),
+                    factory.createString("valueOf")),
+                false),
+            new CfReturn(ValueType.OBJECT),
+            label13,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1, 2, 3, 4},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(
+                          factory.createType("Ljava/lang/invoke/VarHandle;")),
+                      FrameType.initializedNonNullReference(factory.objectType),
+                      FrameType.initializedNonNullReference(factory.classType),
+                      FrameType.longType(),
+                      FrameType.longHighType()
+                    })),
+            new CfLoad(ValueType.OBJECT, 2),
+            new CfConstClass(factory.createType("Ljava/lang/Double;")),
+            new CfIfCmp(If.Type.NE, ValueType.OBJECT, label15),
+            label14,
+            new CfLoad(ValueType.LONG, 3),
+            new CfNumberConversion(NumericType.LONG, NumericType.DOUBLE),
+            new CfInvoke(
+                184,
+                factory.createMethod(
+                    factory.createType("Ljava/lang/Double;"),
+                    factory.createProto(
+                        factory.createType("Ljava/lang/Double;"), factory.doubleType),
+                    factory.createString("valueOf")),
+                false),
+            new CfReturn(ValueType.OBJECT),
+            label15,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1, 2, 3, 4},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(
+                          factory.createType("Ljava/lang/invoke/VarHandle;")),
+                      FrameType.initializedNonNullReference(factory.objectType),
+                      FrameType.initializedNonNullReference(factory.classType),
+                      FrameType.longType(),
+                      FrameType.longHighType()
+                    })),
+            new CfLoad(ValueType.OBJECT, 0),
+            new CfInvoke(
+                182,
+                factory.createMethod(
+                    factory.createType("Ljava/lang/invoke/VarHandle;"),
+                    factory.createProto(factory.createType("Ljava/lang/RuntimeException;")),
+                    factory.createString("desugarWrongMethodTypeException")),
+                false),
+            new CfThrow(),
+            label16,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1, 2},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(
+                          factory.createType("Ljava/lang/invoke/VarHandle;")),
+                      FrameType.initializedNonNullReference(factory.objectType),
+                      FrameType.initializedNonNullReference(factory.classType)
+                    })),
+            new CfLoad(ValueType.OBJECT, 0),
+            new CfInstanceFieldRead(
+                factory.createField(
+                    factory.createType("Ljava/lang/invoke/VarHandle;"),
+                    factory.createType("Lsun/misc/Unsafe;"),
+                    factory.createString("U"))),
+            new CfLoad(ValueType.OBJECT, 1),
+            new CfLoad(ValueType.OBJECT, 0),
+            new CfInstanceFieldRead(
+                factory.createField(
+                    factory.createType("Ljava/lang/invoke/VarHandle;"),
+                    factory.longType,
+                    factory.createString("offset"))),
+            new CfInvoke(
+                182,
+                factory.createMethod(
+                    factory.createType("Lsun/misc/Unsafe;"),
+                    factory.createProto(factory.objectType, factory.objectType, factory.longType),
+                    factory.createString("getObject")),
+                false),
+            new CfReturn(ValueType.OBJECT),
+            label17),
+        ImmutableList.of(),
+        ImmutableList.of());
+  }
+
   public static CfCode DesugarVarHandle_getInt(DexItemFactory factory, DexMethod method) {
     CfLabel label0 = new CfLabel();
     CfLabel label1 = new CfLabel();
@@ -1436,27 +1742,14 @@
             new CfIfCmp(If.Type.NE, ValueType.OBJECT, label4),
             label3,
             new CfLoad(ValueType.OBJECT, 0),
-            new CfInstanceFieldRead(
-                factory.createField(
-                    factory.createType("Ljava/lang/invoke/VarHandle;"),
-                    factory.createType("Lsun/misc/Unsafe;"),
-                    factory.createString("U"))),
-            new CfLoad(ValueType.OBJECT, 1),
-            new CfLoad(ValueType.OBJECT, 0),
-            new CfInstanceFieldRead(
-                factory.createField(
-                    factory.createType("Ljava/lang/invoke/VarHandle;"),
-                    factory.longType,
-                    factory.createString("offset"))),
             new CfInvoke(
                 182,
                 factory.createMethod(
-                    factory.createType("Lsun/misc/Unsafe;"),
-                    factory.createProto(factory.longType, factory.objectType, factory.longType),
-                    factory.createString("getLong")),
+                    factory.createType("Ljava/lang/invoke/VarHandle;"),
+                    factory.createProto(factory.createType("Ljava/lang/RuntimeException;")),
+                    factory.createString("desugarWrongMethodTypeException")),
                 false),
-            new CfNumberConversion(NumericType.LONG, NumericType.INT),
-            new CfReturn(ValueType.INT),
+            new CfThrow(),
             label4,
             new CfFrame(
                 new Int2ObjectAVLTreeMap<>(
diff --git a/src/test/examplesJava9/varhandle/InstanceIntField.java b/src/test/examplesJava9/varhandle/InstanceIntField.java
index 0320030..aeadbcc 100644
--- a/src/test/examplesJava9/varhandle/InstanceIntField.java
+++ b/src/test/examplesJava9/varhandle/InstanceIntField.java
@@ -18,10 +18,54 @@
     throw e;
   }
 
-  public static void testSet(VarHandle varHandle) {
+  public static void testGet(VarHandle varHandle) { // Exercises get() under each return type.
     System.out.println("testGet");
 
     InstanceIntField instance = new InstanceIntField();
+    varHandle.set(instance, 1); // Seed the field; the successful reads below print 1 or 1.0.
+
+    System.out.println(varHandle.get(instance));
+    System.out.println((Object) varHandle.get(instance));
+    System.out.println((int) varHandle.get(instance));
+    System.out.println((long) varHandle.get(instance)); // Wider primitives are accepted.
+    System.out.println((float) varHandle.get(instance));
+    System.out.println((double) varHandle.get(instance));
+    try { // boolean, byte, short, char and String reads are each expected to throw.
+      System.out.println((boolean) varHandle.get(instance));
+      System.out.println("Unexpected success");
+    } catch (RuntimeException e) {
+      checkJavaLangInvokeWrongMethodTypeException(e);
+    }
+    try {
+      System.out.println((byte) varHandle.get(instance));
+      System.out.println("Unexpected success");
+    } catch (RuntimeException e) {
+      checkJavaLangInvokeWrongMethodTypeException(e);
+    }
+    try {
+      System.out.println((short) varHandle.get(instance));
+      System.out.println("Unexpected success");
+    } catch (RuntimeException e) {
+      checkJavaLangInvokeWrongMethodTypeException(e);
+    }
+    try {
+      System.out.println((char) varHandle.get(instance));
+      System.out.println("Unexpected success");
+    } catch (RuntimeException e) {
+      checkJavaLangInvokeWrongMethodTypeException(e);
+    }
+    try {
+      System.out.println((String) varHandle.get(instance));
+      System.out.println("Unexpected success");
+    } catch (RuntimeException e) {
+      checkJavaLangInvokeWrongMethodTypeException(e);
+    }
+  }
+
+  public static void testSet(VarHandle varHandle) {
+    System.out.println("testSet");
+
+    InstanceIntField instance = new InstanceIntField();
     System.out.println((int) varHandle.get(instance));
 
     // int and Integer values.
@@ -217,6 +261,7 @@
   public static void main(String[] args) throws NoSuchFieldException, IllegalAccessException {
     VarHandle varHandle =
         MethodHandles.lookup().findVarHandle(InstanceIntField.class, "field", int.class);
+    testGet(varHandle);
     testSet(varHandle);
     testCompareAndSet(varHandle);
   }
diff --git a/src/test/examplesJava9/varhandle/InstanceLongField.java b/src/test/examplesJava9/varhandle/InstanceLongField.java
index fac0ef3..b23e692 100644
--- a/src/test/examplesJava9/varhandle/InstanceLongField.java
+++ b/src/test/examplesJava9/varhandle/InstanceLongField.java
@@ -18,6 +18,55 @@
     throw e;
   }
 
+  public static void testGet(VarHandle varHandle) { // Exercises get() under each return type.
+    System.out.println("testGet");
+
+    InstanceLongField instance = new InstanceLongField();
+    varHandle.set(instance, 1L); // Seed the field; the successful reads below print 1 or 1.0.
+
+    System.out.println(varHandle.get(instance));
+    System.out.println((Object) varHandle.get(instance));
+    System.out.println((long) varHandle.get(instance));
+    System.out.println((float) varHandle.get(instance)); // float and double are accepted.
+    System.out.println((double) varHandle.get(instance));
+    try { // boolean, byte, short, char, int and String reads are each expected to throw.
+      System.out.println((boolean) varHandle.get(instance));
+      System.out.println("Unexpected success");
+    } catch (RuntimeException e) {
+      checkJavaLangInvokeWrongMethodTypeException(e);
+    }
+    try {
+      System.out.println((byte) varHandle.get(instance));
+      System.out.println("Unexpected success");
+    } catch (RuntimeException e) {
+      checkJavaLangInvokeWrongMethodTypeException(e);
+    }
+    try {
+      System.out.println((short) varHandle.get(instance));
+      System.out.println("Unexpected success");
+    } catch (RuntimeException e) {
+      checkJavaLangInvokeWrongMethodTypeException(e);
+    }
+    try {
+      System.out.println((char) varHandle.get(instance));
+      System.out.println("Unexpected success");
+    } catch (RuntimeException e) {
+      checkJavaLangInvokeWrongMethodTypeException(e);
+    }
+    try {
+      System.out.println((int) varHandle.get(instance)); // Narrowing long to int must not succeed.
+      System.out.println("Unexpected success");
+    } catch (RuntimeException e) {
+      checkJavaLangInvokeWrongMethodTypeException(e);
+    }
+    try {
+      System.out.println((String) varHandle.get(instance));
+      System.out.println("Unexpected success");
+    } catch (RuntimeException e) {
+      checkJavaLangInvokeWrongMethodTypeException(e);
+    }
+  }
+
   public static void testSet(VarHandle varHandle) {
     System.out.println("testSet");
 
@@ -187,6 +236,7 @@
   public static void main(String[] args) throws NoSuchFieldException, IllegalAccessException {
     VarHandle varHandle =
         MethodHandles.lookup().findVarHandle(InstanceLongField.class, "field", long.class);
+    testGet(varHandle);
     testSet(varHandle);
     testCompareAndSet(varHandle);
   }
diff --git a/src/test/examplesJava9/varhandle/InstanceStringField.java b/src/test/examplesJava9/varhandle/InstanceStringField.java
index 1be6f67..231ee6a 100644
--- a/src/test/examplesJava9/varhandle/InstanceStringField.java
+++ b/src/test/examplesJava9/varhandle/InstanceStringField.java
@@ -14,12 +14,31 @@
     System.out.println(s);
   }
 
-  public static void testSet(VarHandle varHandle) {
+  public static void testGet(VarHandle varHandle) { // Exercises get() under several casts.
     System.out.println("testGet");
 
     InstanceStringField instance = new InstanceStringField();
 
     // Then polymorphic invoke will remove the cast and make that as the return type of the get.
+    System.out.println((String) varHandle.get(instance)); // Read before any set; prints null.
+    varHandle.set(instance, "1");
+    System.out.println(varHandle.get(instance));
+    System.out.println((Object) varHandle.get(instance));
+    System.out.println((String) varHandle.get(instance));
+    System.out.println((CharSequence) varHandle.get(instance));
+    try {
+      System.out.println((Byte) varHandle.get(instance));
+      System.out.println("Unexpected success");
+    } catch (ClassCastException e) { // Expected: the stored String cannot be cast to Byte.
+    }
+  }
+
+  public static void testSet(VarHandle varHandle) {
+    System.out.println("testSet");
+
+    InstanceStringField instance = new InstanceStringField();
+
+    // The polymorphic invoke will remove the cast and use the cast type as the get's return type.
     println((String) varHandle.get(instance));
     varHandle.set(instance, "1");
     println((String) varHandle.get(instance));
@@ -39,6 +58,7 @@
   public static void main(String[] args) throws NoSuchFieldException, IllegalAccessException {
     VarHandle varHandle =
         MethodHandles.lookup().findVarHandle(InstanceStringField.class, "field", Object.class);
+    testGet(varHandle);
     testSet(varHandle);
     testCompareAndSet(varHandle);
   }
diff --git a/src/test/java/com/android/tools/r8/cf/varhandle/VarHandleDesugaringInstanceIntFieldTest.java b/src/test/java/com/android/tools/r8/cf/varhandle/VarHandleDesugaringInstanceIntFieldTest.java
index 7cfc0e5..f676891 100644
--- a/src/test/java/com/android/tools/r8/cf/varhandle/VarHandleDesugaringInstanceIntFieldTest.java
+++ b/src/test/java/com/android/tools/r8/cf/varhandle/VarHandleDesugaringInstanceIntFieldTest.java
@@ -17,6 +17,13 @@
   private static final String EXPECTED_OUTPUT =
       StringUtils.lines(
           "testGet",
+          "1",
+          "1",
+          "1",
+          "1",
+          "1.0",
+          "1.0",
+          "testSet",
           "0",
           "1",
           "2",
diff --git a/src/test/java/com/android/tools/r8/cf/varhandle/VarHandleDesugaringInstanceLongFieldTest.java b/src/test/java/com/android/tools/r8/cf/varhandle/VarHandleDesugaringInstanceLongFieldTest.java
index 227cc76..7303ec5 100644
--- a/src/test/java/com/android/tools/r8/cf/varhandle/VarHandleDesugaringInstanceLongFieldTest.java
+++ b/src/test/java/com/android/tools/r8/cf/varhandle/VarHandleDesugaringInstanceLongFieldTest.java
@@ -16,6 +16,12 @@
 
   private static final String EXPECTED_OUTPUT =
       StringUtils.lines(
+          "testGet",
+          "1",
+          "1",
+          "1",
+          "1.0",
+          "1.0",
           "testSet",
           "0",
           "1",
diff --git a/src/test/java/com/android/tools/r8/cf/varhandle/VarHandleDesugaringInstanceStringFieldTest.java b/src/test/java/com/android/tools/r8/cf/varhandle/VarHandleDesugaringInstanceStringFieldTest.java
index f33f52c..f1440c1 100644
--- a/src/test/java/com/android/tools/r8/cf/varhandle/VarHandleDesugaringInstanceStringFieldTest.java
+++ b/src/test/java/com/android/tools/r8/cf/varhandle/VarHandleDesugaringInstanceStringFieldTest.java
@@ -15,7 +15,19 @@
 public class VarHandleDesugaringInstanceStringFieldTest extends VarHandleDesugaringTestBase {
 
   private static final String EXPECTED_OUTPUT =
-      StringUtils.lines("testGet", "null", "1", "testCompareAndSet", "null", "1");
+      StringUtils.lines(
+          "testGet",
+          "null",
+          "1",
+          "1",
+          "1",
+          "1",
+          "testSet",
+          "null",
+          "1",
+          "testCompareAndSet",
+          "null",
+          "1");
   private static final String MAIN_CLASS = VarHandle.InstanceStringField.typeName();
   private static final String JAR_ENTRY = "varhandle/InstanceStringField.class";
 
@@ -38,4 +50,9 @@
   protected String getExpectedOutputForReferenceImplementation() {
     return EXPECTED_OUTPUT;
   }
+
+  @Override
+  protected boolean getTestWithDesugaring() {
+    return true; // NOTE(review): presumably enables the desugared run; confirm in base class.
+  }
 }
diff --git a/src/test/java/com/android/tools/r8/ir/desugar/varhandle/DesugarVarHandle.java b/src/test/java/com/android/tools/r8/ir/desugar/varhandle/DesugarVarHandle.java
index 5780485..2ba4cf4 100644
--- a/src/test/java/com/android/tools/r8/ir/desugar/varhandle/DesugarVarHandle.java
+++ b/src/test/java/com/android/tools/r8/ir/desugar/varhandle/DesugarVarHandle.java
@@ -142,11 +142,38 @@
     return U.getObject(ct1, offset);
   }
 
+  Object getInBox(Object ct1, Class<?> expectedBox) { // Reads ct1's field; expectedBox picks box.
+    if (type == int.class) { // int field: convert to a Long, Float or Double box.
+      int value = U.getInt(ct1, offset);
+      if (expectedBox == Long.class) {
+        return Long.valueOf(value);
+      }
+      if (expectedBox == Float.class) {
+        return Float.valueOf(value);
+      }
+      if (expectedBox == Double.class) {
+        return Double.valueOf(value);
+      }
+      throw desugarWrongMethodTypeException(); // Any other box, including Integer, is rejected.
+    }
+    if (type == long.class) { // long field: convert to a Float or Double box.
+      long value = U.getLong(ct1, offset);
+      if (expectedBox == Float.class) {
+        return Float.valueOf(value);
+      }
+      if (expectedBox == Double.class) {
+        return Double.valueOf(value);
+      }
+      throw desugarWrongMethodTypeException(); // Any other box, including Long, is rejected.
+    }
+    return U.getObject(ct1, offset); // Other field types: expectedBox is not consulted.
+  }
+
   int getInt(Object ct1) {
     if (type == int.class) {
       return U.getInt(ct1, offset);
     } else if (type == long.class) {
-      return (int) U.getLong(ct1, offset);
+      throw desugarWrongMethodTypeException();
     } else {
       return toIntIfPossible(U.getObject(ct1, offset), true);
     }
diff --git a/src/test/java/com/android/tools/r8/ir/desugar/varhandle/GenerateVarHandleMethods.java b/src/test/java/com/android/tools/r8/ir/desugar/varhandle/GenerateVarHandleMethods.java
index 369bc51..7052773 100644
--- a/src/test/java/com/android/tools/r8/ir/desugar/varhandle/GenerateVarHandleMethods.java
+++ b/src/test/java/com/android/tools/r8/ir/desugar/varhandle/GenerateVarHandleMethods.java
@@ -223,7 +223,8 @@
     for (String prefix : ImmutableList.of("get", "set", "compareAndSet")) {
       if (name.startsWith(prefix)
           && (name.substring(prefix.length()).equals("Int")
-              || name.substring(prefix.length()).equals("Long"))) {
+              || name.substring(prefix.length()).equals("Long")
+              || name.substring(prefix.length()).equals("InBox"))) {
         return prefix;
       }
     }