Rewrite java.nio.Buffer jdk11/13 covariant return

- MappedByteBuffer covariant return types has been introduced in jdk 13
while the others in jdk11.
Bug: 184087908
Change-Id: I3f16eec966d37714e3cc60b2194b8db9942d4283
diff --git a/src/main/java/com/android/tools/r8/cf/code/CfInstruction.java b/src/main/java/com/android/tools/r8/cf/code/CfInstruction.java
index 1009080..775a909 100644
--- a/src/main/java/com/android/tools/r8/cf/code/CfInstruction.java
+++ b/src/main/java/com/android/tools/r8/cf/code/CfInstruction.java
@@ -157,6 +157,10 @@
     return false;
   }
 
+  public boolean isInvokeInterface() {
+    return false;
+  }
+
   public CfLabel asLabel() {
     return null;
   }
diff --git a/src/main/java/com/android/tools/r8/cf/code/CfInvoke.java b/src/main/java/com/android/tools/r8/cf/code/CfInvoke.java
index 87cae25..4153c0b 100644
--- a/src/main/java/com/android/tools/r8/cf/code/CfInvoke.java
+++ b/src/main/java/com/android/tools/r8/cf/code/CfInvoke.java
@@ -189,6 +189,10 @@
     return opcode == Opcodes.INVOKEVIRTUAL;
   }
 
+  public boolean isInvokeInterface() {
+    return opcode == Opcodes.INVOKEINTERFACE;
+  }
+
   @Override
   public boolean canThrow() {
     return true;
diff --git a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
index 86b96e8..412a66d 100644
--- a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
+++ b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
@@ -239,6 +239,16 @@
   public final DexString closeableDescriptor = createString("Ljava/io/Closeable;");
   public final DexString zipFileDescriptor = createString("Ljava/util/zip/ZipFile;");
 
+  public final DexString bufferDescriptor = createString("Ljava/nio/Buffer;");
+  public final DexString byteBufferDescriptor = createString("Ljava/nio/ByteBuffer;");
+  public final DexString mappedByteBufferDescriptor = createString("Ljava/nio/MappedByteBuffer;");
+  public final DexString charBufferDescriptor = createString("Ljava/nio/CharBuffer;");
+  public final DexString shortBufferDescriptor = createString("Ljava/nio/ShortBuffer;");
+  public final DexString intBufferDescriptor = createString("Ljava/nio/IntBuffer;");
+  public final DexString longBufferDescriptor = createString("Ljava/nio/LongBuffer;");
+  public final DexString floatBufferDescriptor = createString("Ljava/nio/FloatBuffer;");
+  public final DexString doubleBufferDescriptor = createString("Ljava/nio/DoubleBuffer;");
+
   public final DexString stringBuilderDescriptor = createString("Ljava/lang/StringBuilder;");
   public final DexString stringBufferDescriptor = createString("Ljava/lang/StringBuffer;");
 
@@ -410,6 +420,26 @@
   public final DexType optionalLongType = createStaticallyKnownType(optionalLongDescriptor);
   public final DexType streamType = createStaticallyKnownType(streamDescriptor);
 
+  public final DexType bufferType = createStaticallyKnownType(bufferDescriptor);
+  public final DexType byteBufferType = createStaticallyKnownType(byteBufferDescriptor);
+  public final DexType mappedByteBufferType = createStaticallyKnownType(mappedByteBufferDescriptor);
+  public final DexType charBufferType = createStaticallyKnownType(charBufferDescriptor);
+  public final DexType shortBufferType = createStaticallyKnownType(shortBufferDescriptor);
+  public final DexType intBufferType = createStaticallyKnownType(intBufferDescriptor);
+  public final DexType longBufferType = createStaticallyKnownType(longBufferDescriptor);
+  public final DexType floatBufferType = createStaticallyKnownType(floatBufferDescriptor);
+  public final DexType doubleBufferType = createStaticallyKnownType(doubleBufferDescriptor);
+  public final List<DexType> typeSpecificBuffers =
+      ImmutableList.of(
+          byteBufferType,
+          mappedByteBufferType,
+          charBufferType,
+          shortBufferType,
+          intBufferType,
+          longBufferType,
+          floatBufferType,
+          doubleBufferType);
+
   public final DexType doubleConsumer =
       createStaticallyKnownType("Ljava/util/function/DoubleConsumer;");
   public final DexType longConsumer =
@@ -517,6 +547,7 @@
   public final LongMembers longMembers = new LongMembers();
   public final ObjectsMethods objectsMethods = new ObjectsMethods();
   public final ObjectMembers objectMembers = new ObjectMembers();
+  public final BufferMembers bufferMembers = new BufferMembers();
   public final RecordMembers recordMembers = new RecordMembers();
   public final ShortMembers shortMembers = new ShortMembers();
   public final StringMembers stringMembers = new StringMembers();
@@ -1278,6 +1309,20 @@
     }
   }
 
+  public class BufferMembers {
+    public final DexMethod positionArg =
+        createMethod(bufferType, createProto(bufferType, intType), "position");
+    public final DexMethod limitArg =
+        createMethod(bufferType, createProto(bufferType, intType), "limit");
+    public final DexMethod mark = createMethod(bufferType, createProto(bufferType), "mark");
+    public final DexMethod reset = createMethod(bufferType, createProto(bufferType), "reset");
+    public final DexMethod clear = createMethod(bufferType, createProto(bufferType), "clear");
+    public final DexMethod flip = createMethod(bufferType, createProto(bufferType), "flip");
+    public final DexMethod rewind = createMethod(bufferType, createProto(bufferType), "rewind");
+    public final List<DexMethod> bufferCovariantMethods =
+        ImmutableList.of(positionArg, limitArg, mark, reset, clear, flip, rewind);
+  }
+
   public class ObjectsMethods {
 
     public final DexMethod requireNonNull;
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/BufferCovariantReturnTypeRewriter.java b/src/main/java/com/android/tools/r8/ir/desugar/BufferCovariantReturnTypeRewriter.java
new file mode 100644
index 0000000..8678041
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/desugar/BufferCovariantReturnTypeRewriter.java
@@ -0,0 +1,93 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.desugar;
+
+import com.android.tools.r8.cf.code.CfCheckCast;
+import com.android.tools.r8.cf.code.CfInstruction;
+import com.android.tools.r8.cf.code.CfInvoke;
+import com.android.tools.r8.contexts.CompilationContext.MethodProcessingContext;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexProto;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.google.common.collect.ImmutableList;
+import java.util.Collection;
+
+/**
+ * BufferCovariantReturnTypeRewriter rewrites the return type of invoked methods matching
+ * factory.bufferMembers.bufferCovariantMethods to return Buffer instead of the subtype.
+ */
+public class BufferCovariantReturnTypeRewriter implements CfInstructionDesugaring {
+
+  private final AppView<?> appView;
+  private final DexItemFactory factory;
+
+  public BufferCovariantReturnTypeRewriter(AppView<?> appView) {
+    this.appView = appView;
+    this.factory = appView.dexItemFactory();
+  }
+
+  public Collection<CfInstruction> desugarInstruction(
+      CfInstruction instruction,
+      FreshLocalProvider freshLocalProvider,
+      LocalStackAllocator localStackAllocator,
+      CfInstructionDesugaringEventConsumer eventConsumer,
+      ProgramMethod context,
+      MethodProcessingContext methodProcessingContext) {
+    if (!isInvokeCandidate(instruction)) {
+      return null;
+    }
+    CfInvoke cfInvoke = instruction.asInvoke();
+    DexMethod invokedMethod = cfInvoke.getMethod();
+    DexMethod covariantMethod = matchingBufferCovariantMethod(invokedMethod);
+    if (covariantMethod == null) {
+      return null;
+    }
+    DexProto proto =
+        factory.createProto(factory.bufferType, invokedMethod.getProto().parameters.values);
+    CfInvoke newInvoke =
+        new CfInvoke(
+            cfInvoke.getOpcode(), invokedMethod.withProto(proto, factory), cfInvoke.isInterface());
+    return ImmutableList.of(newInvoke, new CfCheckCast(invokedMethod.getReturnType()));
+  }
+
+  private DexMethod matchingBufferCovariantMethod(DexMethod invokedMethod) {
+    if (invokedMethod.getArity() > 1
+        || (invokedMethod.getArity() == 1 && !invokedMethod.getParameter(0).isIntType())
+        || invokedMethod.getReturnType() == factory.bufferType
+        || !factory.typeSpecificBuffers.contains(invokedMethod.holder)
+        // The return type can differ from the holder with for example
+        // holder: MappedByteBuffer, return type: ByteBuffer, covariant return type: Buffer.
+        || !factory.typeSpecificBuffers.contains(invokedMethod.getReturnType())) {
+      return null;
+    }
+    // This rewrites the methods only for the java library buffers, but it is not normally possible
+    // to create user-defined buffers which suffer from the issue since all constructors in buffers
+    // are package-private.
+    for (DexMethod covariantMethod : factory.bufferMembers.bufferCovariantMethods) {
+      if (covariantMethod.name == invokedMethod.name
+          && covariantMethod.getParameters().equals(invokedMethod.getParameters())) {
+        return covariantMethod;
+      }
+    }
+    return null;
+  }
+
+  private boolean isInvokeCandidate(CfInstruction instruction) {
+    return instruction.isInvoke()
+        || instruction.isInvokeStatic()
+        || instruction.isInvokeInterface();
+  }
+
+  @Override
+  public boolean needsDesugaring(CfInstruction instruction, ProgramMethod context) {
+    if (!isInvokeCandidate(instruction)) {
+      return false;
+    }
+    DexMethod invokedMethod = instruction.asInvoke().getMethod();
+    return matchingBufferCovariantMethod(invokedMethod) != null;
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/NonEmptyCfInstructionDesugaringCollection.java b/src/main/java/com/android/tools/r8/ir/desugar/NonEmptyCfInstructionDesugaringCollection.java
index 433e023..04b40b3 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/NonEmptyCfInstructionDesugaringCollection.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/NonEmptyCfInstructionDesugaringCollection.java
@@ -46,6 +46,7 @@
     desugarings.add(new LambdaInstructionDesugaring(appView));
     desugarings.add(new InvokeSpecialToSelfDesugaring(appView));
     desugarings.add(new StringConcatInstructionDesugaring(appView));
+    desugarings.add(new BufferCovariantReturnTypeRewriter(appView));
     if (appView.options().enableBackportedMethodRewriting()) {
       BackportedMethodRewriter backportedMethodRewriter = new BackportedMethodRewriter(appView);
       if (backportedMethodRewriter.hasBackports()) {
diff --git a/src/test/examplesJava11/buffercovariantreturntype/BufferCovariantReturnTypeMain.java b/src/test/examplesJava11/buffercovariantreturntype/BufferCovariantReturnTypeMain.java
new file mode 100644
index 0000000..67079a5
--- /dev/null
+++ b/src/test/examplesJava11/buffercovariantreturntype/BufferCovariantReturnTypeMain.java
@@ -0,0 +1,377 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package buffercovariantreturntype;
+
+import java.nio.Buffer;
+import java.nio.ByteBuffer;
+import java.nio.CharBuffer;
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
+import java.nio.LongBuffer;
+import java.nio.ShortBuffer;
+
+public class BufferCovariantReturnTypeMain {
+
+  public static void main(String[] args) {
+    byteBufferTest();
+    charBufferTest();
+    shortBufferTest();
+    intBufferTest();
+    longBufferTest();
+    floatBufferTest();
+    doubleBufferTest();
+  }
+
+  static void byteBufferTest() {
+    byte[] data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+    byte putValue = 55;
+
+    ByteBuffer directBuffer = ByteBuffer.wrap(data);
+    Buffer indirectBuffer = ByteBuffer.wrap(data);
+    ByteBuffer castedIndirectBuffer = (ByteBuffer) indirectBuffer;
+
+    directBuffer = directBuffer.position(5);
+    System.out.println(directBuffer.position()); // 5
+    directBuffer = directBuffer.limit(7);
+    System.out.println(directBuffer.remaining()); // 2
+    directBuffer = directBuffer.mark().put(putValue).put(putValue).reset();
+    System.out.println(directBuffer.position()); // 5
+    directBuffer = directBuffer.clear();
+    System.out.println(directBuffer.position()); // 0
+    System.out.println(directBuffer.remaining()); // 16
+    directBuffer.put(putValue);
+    directBuffer.put(putValue);
+    directBuffer = directBuffer.rewind();
+    System.out.println(directBuffer.position()); // 0
+    directBuffer.put(putValue);
+    directBuffer.put(putValue);
+    directBuffer = directBuffer.flip();
+    System.out.println(directBuffer.position()); // 0
+    System.out.println(directBuffer.remaining()); // 2
+
+    indirectBuffer = indirectBuffer.position(5);
+    System.out.println(indirectBuffer.position()); // 5
+    indirectBuffer = indirectBuffer.limit(7);
+    System.out.println(indirectBuffer.remaining()); // 2
+    indirectBuffer = indirectBuffer.mark();
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.reset();
+    System.out.println(indirectBuffer.position()); // 5
+    indirectBuffer = indirectBuffer.clear();
+    System.out.println(indirectBuffer.position()); // 0
+    System.out.println(indirectBuffer.remaining()); // 16
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.rewind();
+    System.out.println(indirectBuffer.position()); // 0
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.flip();
+    System.out.println(indirectBuffer.position()); // 0
+    System.out.println(indirectBuffer.remaining()); // 2
+  }
+
+  static void charBufferTest() {
+    char[] data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+    char putValue = 55;
+
+    CharBuffer directBuffer = CharBuffer.wrap(data);
+    Buffer indirectBuffer = CharBuffer.wrap(data);
+    CharBuffer castedIndirectBuffer = (CharBuffer) indirectBuffer;
+
+    directBuffer = directBuffer.position(5);
+    System.out.println(directBuffer.position()); // 5
+    directBuffer = directBuffer.limit(7);
+    System.out.println(directBuffer.remaining()); // 2
+    directBuffer = directBuffer.mark().put(putValue).put(putValue).reset();
+    System.out.println(directBuffer.position()); // 5
+    directBuffer = directBuffer.clear();
+    System.out.println(directBuffer.position()); // 0
+    System.out.println(directBuffer.remaining()); // 16
+    directBuffer.put(putValue);
+    directBuffer.put(putValue);
+    directBuffer = directBuffer.rewind();
+    System.out.println(directBuffer.position()); // 0
+    directBuffer.put(putValue);
+    directBuffer.put(putValue);
+    directBuffer = directBuffer.flip();
+    System.out.println(directBuffer.position()); // 0
+    System.out.println(directBuffer.remaining()); // 2
+
+    indirectBuffer = indirectBuffer.position(5);
+    System.out.println(indirectBuffer.position()); // 5
+    indirectBuffer = indirectBuffer.limit(7);
+    System.out.println(indirectBuffer.remaining()); // 2
+    indirectBuffer = indirectBuffer.mark();
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.reset();
+    System.out.println(indirectBuffer.position()); // 5
+    indirectBuffer = indirectBuffer.clear();
+    System.out.println(indirectBuffer.position()); // 0
+    System.out.println(indirectBuffer.remaining()); // 16
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.rewind();
+    System.out.println(indirectBuffer.position()); // 0
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.flip();
+    System.out.println(indirectBuffer.position()); // 0
+    System.out.println(indirectBuffer.remaining()); // 2
+  }
+
+  static void shortBufferTest() {
+    short[] data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+    short putValue = 55;
+
+    ShortBuffer directBuffer = ShortBuffer.wrap(data);
+    Buffer indirectBuffer = ShortBuffer.wrap(data);
+    ShortBuffer castedIndirectBuffer = (ShortBuffer) indirectBuffer;
+
+    directBuffer = directBuffer.position(5);
+    System.out.println(directBuffer.position()); // 5
+    directBuffer = directBuffer.limit(7);
+    System.out.println(directBuffer.remaining()); // 2
+    directBuffer = directBuffer.mark().put(putValue).put(putValue).reset();
+    System.out.println(directBuffer.position()); // 5
+    directBuffer = directBuffer.clear();
+    System.out.println(directBuffer.position()); // 0
+    System.out.println(directBuffer.remaining()); // 16
+    directBuffer.put(putValue);
+    directBuffer.put(putValue);
+    directBuffer = directBuffer.rewind();
+    System.out.println(directBuffer.position()); // 0
+    directBuffer.put(putValue);
+    directBuffer.put(putValue);
+    directBuffer = directBuffer.flip();
+    System.out.println(directBuffer.position()); // 0
+    System.out.println(directBuffer.remaining()); // 2
+
+    indirectBuffer = indirectBuffer.position(5);
+    System.out.println(indirectBuffer.position()); // 5
+    indirectBuffer = indirectBuffer.limit(7);
+    System.out.println(indirectBuffer.remaining()); // 2
+    indirectBuffer = indirectBuffer.mark();
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.reset();
+    System.out.println(indirectBuffer.position()); // 5
+    indirectBuffer = indirectBuffer.clear();
+    System.out.println(indirectBuffer.position()); // 0
+    System.out.println(indirectBuffer.remaining()); // 16
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.rewind();
+    System.out.println(indirectBuffer.position()); // 0
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.flip();
+    System.out.println(indirectBuffer.position()); // 0
+    System.out.println(indirectBuffer.remaining()); // 2
+  }
+
+  static void intBufferTest() {
+    int[] data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+    int putValue = 55;
+
+    IntBuffer directBuffer = IntBuffer.wrap(data);
+    Buffer indirectBuffer = IntBuffer.wrap(data);
+    IntBuffer castedIndirectBuffer = (IntBuffer) indirectBuffer;
+
+    directBuffer = directBuffer.position(5);
+    System.out.println(directBuffer.position()); // 5
+    directBuffer = directBuffer.limit(7);
+    System.out.println(directBuffer.remaining()); // 2
+    directBuffer = directBuffer.mark().put(putValue).put(putValue).reset();
+    System.out.println(directBuffer.position()); // 5
+    directBuffer = directBuffer.clear();
+    System.out.println(directBuffer.position()); // 0
+    System.out.println(directBuffer.remaining()); // 16
+    directBuffer.put(putValue);
+    directBuffer.put(putValue);
+    directBuffer = directBuffer.rewind();
+    System.out.println(directBuffer.position()); // 0
+    directBuffer.put(putValue);
+    directBuffer.put(putValue);
+    directBuffer = directBuffer.flip();
+    System.out.println(directBuffer.position()); // 0
+    System.out.println(directBuffer.remaining()); // 2
+
+    indirectBuffer = indirectBuffer.position(5);
+    System.out.println(indirectBuffer.position()); // 5
+    indirectBuffer = indirectBuffer.limit(7);
+    System.out.println(indirectBuffer.remaining()); // 2
+    indirectBuffer = indirectBuffer.mark();
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.reset();
+    System.out.println(indirectBuffer.position()); // 5
+    indirectBuffer = indirectBuffer.clear();
+    System.out.println(indirectBuffer.position()); // 0
+    System.out.println(indirectBuffer.remaining()); // 16
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.rewind();
+    System.out.println(indirectBuffer.position()); // 0
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.flip();
+    System.out.println(indirectBuffer.position()); // 0
+    System.out.println(indirectBuffer.remaining()); // 2
+  }
+
+  static void longBufferTest() {
+    long[] data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+    long putValue = 55;
+
+    LongBuffer directBuffer = LongBuffer.wrap(data);
+    Buffer indirectBuffer = LongBuffer.wrap(data);
+    LongBuffer castedIndirectBuffer = (LongBuffer) indirectBuffer;
+
+    directBuffer = directBuffer.position(5);
+    System.out.println(directBuffer.position()); // 5
+    directBuffer = directBuffer.limit(7);
+    System.out.println(directBuffer.remaining()); // 2
+    directBuffer = directBuffer.mark().put(putValue).put(putValue).reset();
+    System.out.println(directBuffer.position()); // 5
+    directBuffer = directBuffer.clear();
+    System.out.println(directBuffer.position()); // 0
+    System.out.println(directBuffer.remaining()); // 16
+    directBuffer.put(putValue);
+    directBuffer.put(putValue);
+    directBuffer = directBuffer.rewind();
+    System.out.println(directBuffer.position()); // 0
+    directBuffer.put(putValue);
+    directBuffer.put(putValue);
+    directBuffer = directBuffer.flip();
+    System.out.println(directBuffer.position()); // 0
+    System.out.println(directBuffer.remaining()); // 2
+
+    indirectBuffer = indirectBuffer.position(5);
+    System.out.println(indirectBuffer.position()); // 5
+    indirectBuffer = indirectBuffer.limit(7);
+    System.out.println(indirectBuffer.remaining()); // 2
+    indirectBuffer = indirectBuffer.mark();
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.reset();
+    System.out.println(indirectBuffer.position()); // 5
+    indirectBuffer = indirectBuffer.clear();
+    System.out.println(indirectBuffer.position()); // 0
+    System.out.println(indirectBuffer.remaining()); // 16
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.rewind();
+    System.out.println(indirectBuffer.position()); // 0
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.flip();
+    System.out.println(indirectBuffer.position()); // 0
+    System.out.println(indirectBuffer.remaining()); // 2
+  }
+
+  static void floatBufferTest() {
+    float[] data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+    float putValue = 55;
+
+    FloatBuffer directBuffer = FloatBuffer.wrap(data);
+    Buffer indirectBuffer = FloatBuffer.wrap(data);
+    FloatBuffer castedIndirectBuffer = (FloatBuffer) indirectBuffer;
+
+    directBuffer = directBuffer.position(5);
+    System.out.println(directBuffer.position()); // 5
+    directBuffer = directBuffer.limit(7);
+    System.out.println(directBuffer.remaining()); // 2
+    directBuffer = directBuffer.mark().put(putValue).put(putValue).reset();
+    System.out.println(directBuffer.position()); // 5
+    directBuffer = directBuffer.clear();
+    System.out.println(directBuffer.position()); // 0
+    System.out.println(directBuffer.remaining()); // 16
+    directBuffer.put(putValue);
+    directBuffer.put(putValue);
+    directBuffer = directBuffer.rewind();
+    System.out.println(directBuffer.position()); // 0
+    directBuffer.put(putValue);
+    directBuffer.put(putValue);
+    directBuffer = directBuffer.flip();
+    System.out.println(directBuffer.position()); // 0
+    System.out.println(directBuffer.remaining()); // 2
+
+    indirectBuffer = indirectBuffer.position(5);
+    System.out.println(indirectBuffer.position()); // 5
+    indirectBuffer = indirectBuffer.limit(7);
+    System.out.println(indirectBuffer.remaining()); // 2
+    indirectBuffer = indirectBuffer.mark();
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.reset();
+    System.out.println(indirectBuffer.position()); // 5
+    indirectBuffer = indirectBuffer.clear();
+    System.out.println(indirectBuffer.position()); // 0
+    System.out.println(indirectBuffer.remaining()); // 16
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.rewind();
+    System.out.println(indirectBuffer.position()); // 0
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.flip();
+    System.out.println(indirectBuffer.position()); // 0
+    System.out.println(indirectBuffer.remaining()); // 2
+  }
+
+  static void doubleBufferTest() {
+    double[] data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+    double putValue = 55;
+
+    DoubleBuffer directBuffer = DoubleBuffer.wrap(data);
+    Buffer indirectBuffer = DoubleBuffer.wrap(data);
+    DoubleBuffer castedIndirectBuffer = (DoubleBuffer) indirectBuffer;
+
+    directBuffer = directBuffer.position(5);
+    System.out.println(directBuffer.position()); // 5
+    directBuffer = directBuffer.limit(7);
+    System.out.println(directBuffer.remaining()); // 2
+    directBuffer = directBuffer.mark().put(putValue).put(putValue).reset();
+    System.out.println(directBuffer.position()); // 5
+    directBuffer = directBuffer.clear();
+    System.out.println(directBuffer.position()); // 0
+    System.out.println(directBuffer.remaining()); // 16
+    directBuffer.put(putValue);
+    directBuffer.put(putValue);
+    directBuffer = directBuffer.rewind();
+    System.out.println(directBuffer.position()); // 0
+    directBuffer.put(putValue);
+    directBuffer.put(putValue);
+    directBuffer = directBuffer.flip();
+    System.out.println(directBuffer.position()); // 0
+    System.out.println(directBuffer.remaining()); // 2
+
+    indirectBuffer = indirectBuffer.position(5);
+    System.out.println(indirectBuffer.position()); // 5
+    indirectBuffer = indirectBuffer.limit(7);
+    System.out.println(indirectBuffer.remaining()); // 2
+    indirectBuffer = indirectBuffer.mark();
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.reset();
+    System.out.println(indirectBuffer.position()); // 5
+    indirectBuffer = indirectBuffer.clear();
+    System.out.println(indirectBuffer.position()); // 0
+    System.out.println(indirectBuffer.remaining()); // 16
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.rewind();
+    System.out.println(indirectBuffer.position()); // 0
+    castedIndirectBuffer.put(putValue);
+    castedIndirectBuffer.put(putValue);
+    indirectBuffer = indirectBuffer.flip();
+    System.out.println(indirectBuffer.position()); // 0
+    System.out.println(indirectBuffer.remaining()); // 2
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/desugar/buffercovariantreturntype/BufferCovariantReturnTypeTest.java b/src/test/java/com/android/tools/r8/desugar/buffercovariantreturntype/BufferCovariantReturnTypeTest.java
new file mode 100644
index 0000000..2cd006d
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/desugar/buffercovariantreturntype/BufferCovariantReturnTypeTest.java
@@ -0,0 +1,75 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.desugar.buffercovariantreturntype;
+
+import static com.android.tools.r8.TestRuntime.CfVm.JDK11;
+import static com.android.tools.r8.utils.FileUtils.JAR_EXTENSION;
+
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.ToolHelper;
+import com.android.tools.r8.utils.StringUtils;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import org.junit.Assume;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class BufferCovariantReturnTypeTest extends TestBase {
+
+  private static final Path JAR =
+      Paths.get(ToolHelper.EXAMPLES_JAVA11_JAR_DIR)
+          .resolve("buffercovariantreturntype" + JAR_EXTENSION);
+  private static final String EXPECTED_RESULT_PER_BUFFER =
+      StringUtils.lines("5", "2", "5", "0", "16", "0", "0", "2");
+  private static final String EXPECTED_RESULT =
+      new String(new char[14]).replace("\0", EXPECTED_RESULT_PER_BUFFER);
+
+  private final TestParameters parameters;
+
+  @Parameterized.Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters()
+        .withCfRuntimesStartingFromIncluding(JDK11)
+        .withDexRuntimes()
+        .withAllApiLevelsAlsoForCf()
+        .build();
+  }
+
+  public BufferCovariantReturnTypeTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void testJVM() throws Exception {
+    Assume.assumeTrue(parameters.isCfRuntime());
+    testForJvm()
+        .addProgramFiles(JAR)
+        .run(parameters.getRuntime(), "buffercovariantreturntype.BufferCovariantReturnTypeMain")
+        .assertSuccessWithOutput(EXPECTED_RESULT);
+  }
+
+  @Test
+  public void testD8() throws Exception {
+    testForD8(parameters.getBackend())
+        .addProgramFiles(JAR)
+        .setMinApi(parameters.getApiLevel())
+        .run(parameters.getRuntime(), "buffercovariantreturntype.BufferCovariantReturnTypeMain")
+        .assertSuccessWithOutput(EXPECTED_RESULT);
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    testForR8(parameters.getBackend())
+        .addProgramFiles(JAR)
+        .addKeepMainRule("buffercovariantreturntype.BufferCovariantReturnTypeMain")
+        .setMinApi(parameters.getApiLevel())
+        .run(parameters.getRuntime(), "buffercovariantreturntype.BufferCovariantReturnTypeMain")
+        .assertSuccessWithOutput(EXPECTED_RESULT);
+  }
+}