Fix flatMap conversions issues

Bug: b/243636261
Bug: b/238179854
Change-Id: I82d99993f7e9699e939a7d7fa4c496a2c90bf5bc
diff --git a/buildSrc/src/main/java/desugaredlibrary/CustomConversionAsmRewriteDescription.java b/buildSrc/src/main/java/desugaredlibrary/CustomConversionAsmRewriteDescription.java
index d412e52..52d878b 100644
--- a/buildSrc/src/main/java/desugaredlibrary/CustomConversionAsmRewriteDescription.java
+++ b/buildSrc/src/main/java/desugaredlibrary/CustomConversionAsmRewriteDescription.java
@@ -21,6 +21,9 @@
           "j$/util/stream/Collector$Characteristics");
   private static final Set<String> WRAP_CONVERT_OWNER =
       ImmutableSet.of(
+          "j$/util/stream/DoubleStream",
+          "j$/util/stream/IntStream",
+          "j$/util/stream/LongStream",
           "j$/util/stream/Stream",
           "j$/nio/file/spi/FileSystemProvider",
           "j$/nio/file/spi/FileTypeDetector",
diff --git a/src/library_desugar/java/j$/util/stream/DoubleStream.java b/src/library_desugar/java/j$/util/stream/DoubleStream.java
new file mode 100644
index 0000000..1fdab0d
--- /dev/null
+++ b/src/library_desugar/java/j$/util/stream/DoubleStream.java
@@ -0,0 +1,16 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package j$.util.stream;
+
+public class DoubleStream {
+
+  public static java.util.stream.DoubleStream wrap_convert(j$.util.stream.DoubleStream stream) {
+    return null;
+  }
+
+  public static j$.util.stream.DoubleStream wrap_convert(java.util.stream.DoubleStream stream) {
+    return null;
+  }
+}
diff --git a/src/library_desugar/java/j$/util/stream/IntStream.java b/src/library_desugar/java/j$/util/stream/IntStream.java
new file mode 100644
index 0000000..75165a8
--- /dev/null
+++ b/src/library_desugar/java/j$/util/stream/IntStream.java
@@ -0,0 +1,16 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package j$.util.stream;
+
+public class IntStream {
+
+  public static java.util.stream.IntStream wrap_convert(j$.util.stream.IntStream stream) {
+    return null;
+  }
+
+  public static j$.util.stream.IntStream wrap_convert(java.util.stream.IntStream stream) {
+    return null;
+  }
+}
diff --git a/src/library_desugar/java/j$/util/stream/LongStream.java b/src/library_desugar/java/j$/util/stream/LongStream.java
new file mode 100644
index 0000000..d078146
--- /dev/null
+++ b/src/library_desugar/java/j$/util/stream/LongStream.java
@@ -0,0 +1,16 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package j$.util.stream;
+
+public class LongStream {
+
+  public static java.util.stream.LongStream wrap_convert(j$.util.stream.LongStream stream) {
+    return null;
+  }
+
+  public static j$.util.stream.LongStream wrap_convert(java.util.stream.LongStream stream) {
+    return null;
+  }
+}
diff --git a/src/library_desugar/java/j$/util/stream/Stream.java b/src/library_desugar/java/j$/util/stream/Stream.java
index 70c761a..dc5861d 100644
--- a/src/library_desugar/java/j$/util/stream/Stream.java
+++ b/src/library_desugar/java/j$/util/stream/Stream.java
@@ -13,4 +13,12 @@
   public static j$.util.stream.Stream<?> inverted_wrap_convert(java.util.stream.Stream<?> stream) {
     return null;
   }
+
+  public static java.util.stream.Stream<?> wrap_convert(j$.util.stream.Stream<?> stream) {
+    return null;
+  }
+
+  public static j$.util.stream.Stream<?> wrap_convert(java.util.stream.Stream<?> stream) {
+    return null;
+  }
 }
diff --git a/src/library_desugar/java/java/util/stream/FlatMapApiFlips.java b/src/library_desugar/java/java/util/stream/FlatMapApiFlips.java
new file mode 100644
index 0000000..b9e1b9e
--- /dev/null
+++ b/src/library_desugar/java/java/util/stream/FlatMapApiFlips.java
@@ -0,0 +1,164 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package java.util.stream;
+
+import static java.util.ConversionRuntimeException.exception;
+
+import java.util.function.DoubleFunction;
+import java.util.function.Function;
+import java.util.function.IntFunction;
+import java.util.function.LongFunction;
+
+public class FlatMapApiFlips {
+
+  public static Function<?, ?> flipFunctionReturningStream(Function<?, ?> function) {
+    return new FunctionStreamWrapper<>(function);
+  }
+
+  public static IntFunction<?> flipFunctionReturningStream(IntFunction<?> function) {
+    return new IntFunctionStreamWrapper<>(function);
+  }
+
+  public static DoubleFunction<?> flipFunctionReturningStream(DoubleFunction<?> function) {
+    return new DoubleFunctionStreamWrapper<>(function);
+  }
+
+  public static LongFunction<?> flipFunctionReturningStream(LongFunction<?> function) {
+    return new LongFunctionStreamWrapper<>(function);
+  }
+
+  public static class FunctionStreamWrapper<T, R> implements Function<T, R> {
+
+    public Function<T, R> function;
+
+    public FunctionStreamWrapper(Function<T, R> function) {
+      this.function = function;
+    }
+
+    private R flipStream(R maybeStream) {
+      if (maybeStream == null) {
+        return null;
+      }
+
+      if (maybeStream instanceof java.util.stream.Stream<?>) {
+        return (R) j$.util.stream.Stream.wrap_convert((java.util.stream.Stream<?>) maybeStream);
+      }
+      if (maybeStream instanceof j$.util.stream.Stream<?>) {
+        return (R) j$.util.stream.Stream.wrap_convert((j$.util.stream.Stream<?>) maybeStream);
+      }
+
+      if (maybeStream instanceof java.util.stream.IntStream) {
+        return (R) j$.util.stream.IntStream.wrap_convert((java.util.stream.IntStream) maybeStream);
+      }
+      if (maybeStream instanceof j$.util.stream.IntStream) {
+        return (R) j$.util.stream.IntStream.wrap_convert((j$.util.stream.IntStream) maybeStream);
+      }
+
+      if (maybeStream instanceof java.util.stream.DoubleStream) {
+        return (R)
+            j$.util.stream.DoubleStream.wrap_convert((java.util.stream.DoubleStream) maybeStream);
+      }
+      if (maybeStream instanceof j$.util.stream.DoubleStream) {
+        return (R)
+            j$.util.stream.DoubleStream.wrap_convert((j$.util.stream.DoubleStream) maybeStream);
+      }
+
+      if (maybeStream instanceof java.util.stream.LongStream) {
+        return (R)
+            j$.util.stream.LongStream.wrap_convert((java.util.stream.LongStream) maybeStream);
+      }
+      if (maybeStream instanceof j$.util.stream.LongStream) {
+        return (R) j$.util.stream.LongStream.wrap_convert((j$.util.stream.LongStream) maybeStream);
+      }
+
+      throw exception("java.util.stream.*Stream", maybeStream.getClass());
+    }
+
+    public R apply(T arg) {
+      return flipStream(function.apply(arg));
+    }
+  }
+
+  public static class IntFunctionStreamWrapper<R> implements IntFunction<R> {
+
+    public IntFunction<R> function;
+
+    public IntFunctionStreamWrapper(IntFunction<R> function) {
+      this.function = function;
+    }
+
+    private R flipStream(R maybeStream) {
+      if (maybeStream == null) {
+        return null;
+      }
+      if (maybeStream instanceof java.util.stream.IntStream) {
+        return (R) j$.util.stream.IntStream.wrap_convert((java.util.stream.IntStream) maybeStream);
+      }
+      if (maybeStream instanceof j$.util.stream.IntStream) {
+        return (R) j$.util.stream.IntStream.wrap_convert((j$.util.stream.IntStream) maybeStream);
+      }
+      throw exception("java.util.stream.IntStream", maybeStream.getClass());
+    }
+
+    public R apply(int arg) {
+      return flipStream(function.apply(arg));
+    }
+  }
+
+  public static class DoubleFunctionStreamWrapper<R> implements DoubleFunction<R> {
+
+    public DoubleFunction<R> function;
+
+    public DoubleFunctionStreamWrapper(DoubleFunction<R> function) {
+      this.function = function;
+    }
+
+    private R flipStream(R maybeStream) {
+      if (maybeStream == null) {
+        return null;
+      }
+      if (maybeStream instanceof java.util.stream.DoubleStream) {
+        return (R)
+            j$.util.stream.DoubleStream.wrap_convert((java.util.stream.DoubleStream) maybeStream);
+      }
+      if (maybeStream instanceof j$.util.stream.DoubleStream) {
+        return (R)
+            j$.util.stream.DoubleStream.wrap_convert((j$.util.stream.DoubleStream) maybeStream);
+      }
+      throw exception("java.util.stream.DoubleStream", maybeStream.getClass());
+    }
+
+    public R apply(double arg) {
+      return flipStream(function.apply(arg));
+    }
+  }
+
+  public static class LongFunctionStreamWrapper<R> implements LongFunction<R> {
+
+    public LongFunction<R> function;
+
+    public LongFunctionStreamWrapper(LongFunction<R> function) {
+      this.function = function;
+    }
+
+    private R flipStream(R maybeStream) {
+      if (maybeStream == null) {
+        return null;
+      }
+      if (maybeStream instanceof java.util.stream.LongStream) {
+        return (R)
+            j$.util.stream.LongStream.wrap_convert((java.util.stream.LongStream) maybeStream);
+      }
+      if (maybeStream instanceof j$.util.stream.LongStream) {
+        return (R) j$.util.stream.LongStream.wrap_convert((j$.util.stream.LongStream) maybeStream);
+      }
+      throw exception("java.util.stream.LongStream", maybeStream.getClass());
+    }
+
+    public R apply(long arg) {
+      return flipStream(function.apply(arg));
+    }
+  }
+}
diff --git a/src/library_desugar/jdk11/desugar_jdk_libs.json b/src/library_desugar/jdk11/desugar_jdk_libs.json
index 103035f..609f920 100644
--- a/src/library_desugar/jdk11/desugar_jdk_libs.json
+++ b/src/library_desugar/jdk11/desugar_jdk_libs.json
@@ -100,6 +100,13 @@
       },
       "api_generic_types_conversion": {
         "java.util.Set java.util.stream.Collector#characteristics()" : [-1, "java.util.Set java.util.stream.StreamApiFlips#flipCharacteristicSet(java.util.Set)"],
+        "java.util.stream.Stream java.util.stream.Stream#flatMap(java.util.function.Function)": [0, "java.util.function.Function java.util.stream.FlatMapApiFlips#flipFunctionReturningStream(java.util.function.Function)"],
+        "java.util.stream.DoubleStream java.util.stream.DoubleStream#flatMap(java.util.function.DoubleFunction)": [0, "java.util.function.DoubleFunction java.util.stream.FlatMapApiFlips#flipFunctionReturningStream(java.util.function.DoubleFunction)"],
+        "java.util.stream.DoubleStream java.util.stream.Stream#flatMapToDouble(java.util.function.Function)": [0, "java.util.function.Function java.util.stream.FlatMapApiFlips#flipFunctionReturningStream(java.util.function.Function)"],
+        "java.util.stream.IntStream java.util.stream.Stream#flatMapToInt(java.util.function.Function)": [0, "java.util.function.Function java.util.stream.FlatMapApiFlips#flipFunctionReturningStream(java.util.function.Function)"],
+        "java.util.stream.IntStream java.util.stream.IntStream#flatMap(java.util.function.IntFunction)": [0, "java.util.function.IntFunction java.util.stream.FlatMapApiFlips#flipFunctionReturningStream(java.util.function.IntFunction)"],
+        "java.util.stream.LongStream java.util.stream.Stream#flatMapToLong(java.util.function.Function)": [0, "java.util.function.Function java.util.stream.FlatMapApiFlips#flipFunctionReturningStream(java.util.function.Function)"],
+        "java.util.stream.LongStream java.util.stream.LongStream#flatMap(java.util.function.LongFunction)": [0, "java.util.function.LongFunction java.util.stream.FlatMapApiFlips#flipFunctionReturningStream(java.util.function.LongFunction)"],
         "java.lang.Object java.lang.StackWalker#walk(java.util.function.Function)": [0, "java.util.function.Function java.util.stream.StackWalkerApiFlips#flipFunctionStream(java.util.function.Function)"]
       },
       "never_outline_api": [
@@ -234,6 +241,15 @@
         "java.util.Optional": {
           "j$.util.Optional": "java.util.Optional"
         },
+        "java.util.stream.DoubleStream": {
+          "j$.util.stream.DoubleStream": "java.util.stream.DoubleStream"
+        },
+        "java.util.stream.IntStream": {
+          "j$.util.stream.IntStream": "java.util.stream.IntStream"
+        },
+        "java.util.stream.LongStream": {
+          "j$.util.stream.LongStream": "java.util.stream.LongStream"
+        },
         "java.util.stream.Stream": {
           "j$.util.stream.Stream": "java.util.stream.Stream"
         }
diff --git a/src/library_desugar/jdk11/desugar_jdk_libs_nio.json b/src/library_desugar/jdk11/desugar_jdk_libs_nio.json
index a53fd3b..e2f74e8 100644
--- a/src/library_desugar/jdk11/desugar_jdk_libs_nio.json
+++ b/src/library_desugar/jdk11/desugar_jdk_libs_nio.json
@@ -237,6 +237,13 @@
       },
       "api_generic_types_conversion": {
         "java.util.Set java.util.stream.Collector#characteristics()" : [-1, "java.util.Set java.util.stream.StreamApiFlips#flipCharacteristicSet(java.util.Set)"],
+        "java.util.stream.Stream java.util.stream.Stream#flatMap(java.util.function.Function)": [0, "java.util.function.Function java.util.stream.FlatMapApiFlips#flipFunctionReturningStream(java.util.function.Function)"],
+        "java.util.stream.DoubleStream java.util.stream.DoubleStream#flatMap(java.util.function.DoubleFunction)": [0, "java.util.function.DoubleFunction java.util.stream.FlatMapApiFlips#flipFunctionReturningStream(java.util.function.DoubleFunction)"],
+        "java.util.stream.DoubleStream java.util.stream.Stream#flatMapToDouble(java.util.function.Function)": [0, "java.util.function.Function java.util.stream.FlatMapApiFlips#flipFunctionReturningStream(java.util.function.Function)"],
+        "java.util.stream.IntStream java.util.stream.Stream#flatMapToInt(java.util.function.Function)": [0, "java.util.function.Function java.util.stream.FlatMapApiFlips#flipFunctionReturningStream(java.util.function.Function)"],
+        "java.util.stream.IntStream java.util.stream.IntStream#flatMap(java.util.function.IntFunction)": [0, "java.util.function.IntFunction java.util.stream.FlatMapApiFlips#flipFunctionReturningStream(java.util.function.IntFunction)"],
+        "java.util.stream.LongStream java.util.stream.Stream#flatMapToLong(java.util.function.Function)": [0, "java.util.function.Function java.util.stream.FlatMapApiFlips#flipFunctionReturningStream(java.util.function.Function)"],
+        "java.util.stream.LongStream java.util.stream.LongStream#flatMap(java.util.function.LongFunction)": [0, "java.util.function.LongFunction java.util.stream.FlatMapApiFlips#flipFunctionReturningStream(java.util.function.LongFunction)"],
         "java.lang.Object java.lang.StackWalker#walk(java.util.function.Function)": [0, "java.util.function.Function java.util.stream.StackWalkerApiFlips#flipFunctionStream(java.util.function.Function)"]
       },
       "never_outline_api": [
@@ -454,6 +461,15 @@
         "java.util.Optional": {
           "j$.util.Optional": "java.util.Optional"
         },
+        "java.util.stream.DoubleStream": {
+          "j$.util.stream.DoubleStream": "java.util.stream.DoubleStream"
+        },
+        "java.util.stream.IntStream": {
+          "j$.util.stream.IntStream": "java.util.stream.IntStream"
+        },
+        "java.util.stream.LongStream": {
+          "j$.util.stream.LongStream": "java.util.stream.LongStream"
+        },
         "java.util.stream.Stream": {
           "j$.util.stream.Stream": "java.util.stream.Stream"
         }
diff --git a/src/test/java/com/android/tools/r8/desugar/desugaredlibrary/ExtractWrapperTypesTest.java b/src/test/java/com/android/tools/r8/desugar/desugaredlibrary/ExtractWrapperTypesTest.java
index c572169..bb51f2f 100644
--- a/src/test/java/com/android/tools/r8/desugar/desugaredlibrary/ExtractWrapperTypesTest.java
+++ b/src/test/java/com/android/tools/r8/desugar/desugaredlibrary/ExtractWrapperTypesTest.java
@@ -100,9 +100,12 @@
           "java.util.Locale$FilteringMode",
           "java.util.SplittableRandom");
 
-  // TODO(b/238179854): Investigate how to fix these.
-  private static final Set<String> MISSING_GENERIC_TYPE_CONVERSION =
+  private static final Set<String> MISSING_GENERIC_TYPE_CONVERSION = ImmutableSet.of();
+
+  // Missing conversions in JDK8 and JDK11_LEGACY desugared library that are fixed in JDK11.
+  private static final Set<String> MISSING_GENERIC_TYPE_CONVERSION_8 =
       ImmutableSet.of(
+          "java.util.Set java.util.stream.Collector.characteristics()",
           "java.util.stream.Stream java.util.stream.Stream.flatMap(java.util.function.Function)",
           "java.util.stream.DoubleStream"
               + " java.util.stream.DoubleStream.flatMap(java.util.function.DoubleFunction)",
@@ -115,12 +118,7 @@
           "java.util.stream.LongStream"
               + " java.util.stream.Stream.flatMapToLong(java.util.function.Function)",
           "java.util.stream.LongStream"
-              + " java.util.stream.LongStream.flatMap(java.util.function.LongFunction)");
-
-  // Missing conversions in JDK8 desugared library that are fixed in JDK11 desugared library.
-  private static final Set<String> MISSING_GENERIC_TYPE_CONVERSION_8 =
-      ImmutableSet.of(
-          "java.util.Set java.util.stream.Collector.characteristics()",
+              + " java.util.stream.LongStream.flatMap(java.util.function.LongFunction)",
           "java.lang.Object java.lang.StackWalker.walk(java.util.function.Function)");
 
   // TODO(b/238179854): Investigate how to fix these.
diff --git a/src/test/java/com/android/tools/r8/desugar/desugaredlibrary/conversiontests/FlatMapConversionTest.java b/src/test/java/com/android/tools/r8/desugar/desugaredlibrary/conversiontests/FlatMapConversionTest.java
index b9a3dba..62561e1 100644
--- a/src/test/java/com/android/tools/r8/desugar/desugaredlibrary/conversiontests/FlatMapConversionTest.java
+++ b/src/test/java/com/android/tools/r8/desugar/desugaredlibrary/conversiontests/FlatMapConversionTest.java
@@ -4,8 +4,13 @@
 
 package com.android.tools.r8.desugar.desugaredlibrary.conversiontests;
 
+import static com.android.tools.r8.desugar.desugaredlibrary.test.CompilationSpecification.D8_L8DEBUG;
 import static com.android.tools.r8.desugar.desugaredlibrary.test.CompilationSpecification.DEFAULT_SPECIFICATIONS;
-import static com.android.tools.r8.desugar.desugaredlibrary.test.LibraryDesugaringSpecification.getJdk8Jdk11;
+import static com.android.tools.r8.desugar.desugaredlibrary.test.LibraryDesugaringSpecification.JDK11;
+import static com.android.tools.r8.desugar.desugaredlibrary.test.LibraryDesugaringSpecification.JDK11_LEGACY;
+import static com.android.tools.r8.desugar.desugaredlibrary.test.LibraryDesugaringSpecification.JDK11_MINIMAL;
+import static com.android.tools.r8.desugar.desugaredlibrary.test.LibraryDesugaringSpecification.JDK11_PATH;
+import static com.android.tools.r8.desugar.desugaredlibrary.test.LibraryDesugaringSpecification.JDK8;
 import static org.hamcrest.CoreMatchers.containsString;
 
 import com.android.tools.r8.TestParameters;
@@ -15,12 +20,14 @@
 import com.android.tools.r8.desugar.desugaredlibrary.test.LibraryDesugaringSpecification;
 import com.android.tools.r8.utils.AndroidApiLevel;
 import com.android.tools.r8.utils.StringUtils;
+import com.google.common.collect.ImmutableList;
 import java.util.Arrays;
 import java.util.List;
 import java.util.stream.DoubleStream;
 import java.util.stream.IntStream;
 import java.util.stream.LongStream;
 import java.util.stream.Stream;
+import org.junit.Assume;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -48,7 +55,7 @@
   public static List<Object[]> data() {
     return buildParameters(
         getConversionParametersUpToExcluding(MIN_SUPPORTED),
-        getJdk8Jdk11(),
+        ImmutableList.of(JDK8, JDK11_LEGACY, JDK11_MINIMAL, JDK11, JDK11_PATH),
         DEFAULT_SPECIFICATIONS);
   }
 
@@ -62,6 +69,18 @@
   }
 
   @Test
+  public void testReference() throws Throwable {
+    Assume.assumeTrue(
+        "Run only once",
+        libraryDesugaringSpecification == JDK11 && compilationSpecification == D8_L8DEBUG);
+    testForD8()
+        .setMinApi(parameters.getApiLevel())
+        .addProgramClasses(Executor.class, CustomLibClass.class)
+        .run(parameters.getRuntime(), Executor.class)
+        .assertSuccessWithOutput(EXPECTED_RESULT);
+  }
+
+  @Test
   public void testConvert() throws Throwable {
     testForDesugaredLibrary(parameters, libraryDesugaringSpecification, compilationSpecification)
         .addProgramClasses(Executor.class)
@@ -69,7 +88,12 @@
             new CustomLibrarySpecification(CustomLibClass.class, MIN_SUPPORTED))
         .addKeepMainRule(Executor.class)
         .run(parameters.getRuntime(), Executor.class)
-        .assertFailureWithErrorThatMatches(containsString("java.lang.ClassCastException"));
+        .applyIf(
+            libraryDesugaringSpecification == JDK8
+                || libraryDesugaringSpecification == JDK11_LEGACY,
+            r ->
+                r.assertFailureWithErrorThatMatches(containsString("java.lang.ClassCastException")),
+            r -> r.assertSuccessWithOutput(EXPECTED_RESULT));
   }
 
   static class Executor {