Desugared lib API callback support

Bug:134732760
Change-Id: I11438c2ac9583ec7e3dc6f25ae4534e44250d1ac
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/DesugaredLibraryAPIConverter.java b/src/main/java/com/android/tools/r8/ir/desugar/DesugaredLibraryAPIConverter.java
index a5f0a56..b9a758a 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/DesugaredLibraryAPIConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/DesugaredLibraryAPIConverter.java
@@ -5,8 +5,10 @@
 package com.android.tools.r8.ir.desugar;
 
 import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.CfCode;
 import com.android.tools.r8.graph.DexApplication;
 import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProto;
@@ -21,12 +23,17 @@
 import com.android.tools.r8.ir.code.InvokeStatic;
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.conversion.IRConverter;
+import com.android.tools.r8.ir.synthetic.DesugaredLibraryAPIConversionCfCodeProvider.APIConverterWrapperCfCodeProvider;
 import com.android.tools.r8.utils.BooleanUtils;
 import com.android.tools.r8.utils.DescriptorUtils;
 import com.android.tools.r8.utils.StringDiagnostic;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 
@@ -52,6 +59,7 @@
   private final AppView<?> appView;
   private final DexItemFactory factory;
   private final DesugaredLibraryWrapperSynthesizer wrapperSynthesizor;
+  private final Map<DexClass, List<DexEncodedMethod>> callBackMethods = new HashMap<>();
 
   public DesugaredLibraryAPIConverter(AppView<?> appView) {
     this.appView = appView;
@@ -65,9 +73,7 @@
       return;
     }
 
-    // TODO(b/134732760): The current code does not catch library calls into a program override
-    //  which gets rewritten. If method signature has rewritten types and method overrides library,
-    //  I should convert back.
+    generateCallBackIfNeeded(code);
 
     InstructionListIterator iterator = code.instructionListIterator();
     while (iterator.hasNext()) {
@@ -88,24 +94,115 @@
       }
       // Library methods do not understand desugared types, hence desugared types have to be
       // converted around non desugared library calls for the invoke to resolve.
-      if (appView.rewritePrefix.hasRewrittenType(invokedMethod.proto.returnType)) {
+      if (appView.rewritePrefix.hasRewrittenTypeInSignature(invokedMethod.proto)) {
         rewriteLibraryInvoke(code, invokeMethod, iterator);
-        continue;
-      }
-      for (int i = 0; i < invokedMethod.proto.parameters.values.length; i++) {
-        DexType argType = invokedMethod.proto.parameters.values[i];
-        if (appView.rewritePrefix.hasRewrittenType(argType)) {
-          rewriteLibraryInvoke(code, invokeMethod, iterator);
-          continue;
-        }
       }
     }
   }
 
+  private void generateCallBackIfNeeded(IRCode code) {
+    // Any override of a library method can be called by the library.
+    // We duplicate the method to have a vivified type version callable by the library and
+    // a type version callable by the program. We need to add the vivified version to the rootset
+    // as it is actually overriding a library method (after changing the vivified type to the core
+    // library type), but the enqueuer cannot see that.
+    // To avoid too much computation we first look if the method would need to be rewritten if
+    // it would override a library method, then check if it overrides a library method.
+    if (code.method.isPrivateMethod() || code.method.isStatic()) {
+      return;
+    }
+    DexMethod method = code.method.method;
+    if (appView.rewritePrefix.hasRewrittenType(method.holder) || method.holder.isArrayType()) {
+      return;
+    }
+    DexClass dexClass = appView.definitionFor(method.holder);
+    if (dexClass == null) {
+      return;
+    }
+    if (!appView.rewritePrefix.hasRewrittenTypeInSignature(method.proto)) {
+      return;
+    }
+    if (overridesLibraryMethod(dexClass, method)) {
+      generateCallBack(dexClass, code.method);
+    }
+  }
+
+  private boolean overridesLibraryMethod(DexClass theClass, DexMethod method) {
+    // We look up everywhere to see if there is a supertype/interface implementing the method...
+    LinkedList<DexType> workList = new LinkedList<>();
+    Collections.addAll(workList, theClass.interfaces.values);
+    // There is no methods with desugared types on Object.
+    if (theClass.superType != factory.objectType) {
+      workList.add(theClass.superType);
+    }
+    while (!workList.isEmpty()) {
+      DexType current = workList.removeFirst();
+      DexClass dexClass = appView.definitionFor(current);
+      if (dexClass == null) {
+        continue;
+      }
+      workList.addAll(Arrays.asList(dexClass.interfaces.values));
+      if (dexClass.superType != factory.objectType) {
+        workList.add(dexClass.superType);
+      }
+      if (!dexClass.isLibraryClass()) {
+        continue;
+      }
+      DexEncodedMethod dexEncodedMethod = dexClass.lookupVirtualMethod(method);
+      if (dexEncodedMethod != null) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  private synchronized void generateCallBack(DexClass dexClass, DexEncodedMethod originalMethod) {
+    DexMethod methodToInstall =
+        methodWithVivifiedTypeInSignature(originalMethod.method, dexClass.type);
+    CfCode cfCode =
+        new APIConverterWrapperCfCodeProvider(
+                appView, originalMethod.method, null, this, dexClass.isInterface())
+            .generateCfCode();
+    DexEncodedMethod newDexEncodedMethod =
+        wrapperSynthesizor.newSynthesizedMethod(methodToInstall, originalMethod, cfCode);
+    newDexEncodedMethod.setCode(cfCode, appView);
+    addCallBackSignature(dexClass, newDexEncodedMethod);
+  }
+
+  private synchronized void addCallBackSignature(DexClass dexClass, DexEncodedMethod method) {
+    callBackMethods.putIfAbsent(dexClass, new ArrayList<>());
+    List<DexEncodedMethod> dexEncodedMethods = callBackMethods.get(dexClass);
+    dexEncodedMethods.add(method);
+  }
+
+  DexMethod methodWithVivifiedTypeInSignature(DexMethod originalMethod, DexType holder) {
+    DexType[] newParameters = originalMethod.proto.parameters.values.clone();
+    int index = 0;
+    for (DexType param : originalMethod.proto.parameters.values) {
+      if (appView.rewritePrefix.hasRewrittenType(param)) {
+        newParameters[index] = this.vivifiedTypeFor(param);
+      }
+      index++;
+    }
+    DexType returnType = originalMethod.proto.returnType;
+    DexType newReturnType =
+        appView.rewritePrefix.hasRewrittenType(returnType)
+            ? this.vivifiedTypeFor(returnType)
+            : returnType;
+    DexProto newProto = factory.createProto(newReturnType, newParameters);
+    return factory.createMethod(holder, newProto, originalMethod.name);
+  }
+
   public void generateWrappers(
       DexApplication.Builder<?> builder, IRConverter irConverter, ExecutorService executorService)
       throws ExecutionException {
     wrapperSynthesizor.finalizeWrappers(builder, irConverter, executorService);
+    for (DexClass dexClass : callBackMethods.keySet()) {
+      // TODO(b/134732760): add the methods in the root set.
+      List<DexEncodedMethod> dexEncodedMethods = callBackMethods.get(dexClass);
+      dexClass.appendVirtualMethods(dexEncodedMethods);
+      irConverter.optimizeSynthesizedMethodsConcurrently(dexEncodedMethods, executorService);
+    }
   }
 
   private void warnInvalidInvoke(DexType type, DexMethod invokedMethod, String debugString) {
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/DesugaredLibraryWrapperSynthesizer.java b/src/main/java/com/android/tools/r8/ir/desugar/DesugaredLibraryWrapperSynthesizer.java
index dbdc087..036e8b4 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/DesugaredLibraryWrapperSynthesizer.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/DesugaredLibraryWrapperSynthesizer.java
@@ -19,7 +19,6 @@
 import com.android.tools.r8.graph.DexLibraryClass;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
-import com.android.tools.r8.graph.DexProto;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.DexTypeList;
 import com.android.tools.r8.graph.FieldAccessFlags;
@@ -174,7 +173,7 @@
 
   public DexProgramClass generateTypeWrapper(DexClass dexClass, DexType typeWrapperType) {
     DexType type = dexClass.type;
-    DexEncodedField wrapperField = synthetizeWrappedValueField(typeWrapperType, type);
+    DexEncodedField wrapperField = synthesizeWrappedValueField(typeWrapperType, type);
     return synthesizeWrapper(
         converter.vivifiedTypeFor(type),
         dexClass,
@@ -186,7 +185,7 @@
       DexClass dexClass, DexType vivifiedTypeWrapperType) {
     DexType type = dexClass.type;
     DexEncodedField wrapperField =
-        synthetizeWrappedValueField(vivifiedTypeWrapperType, converter.vivifiedTypeFor(type));
+        synthesizeWrappedValueField(vivifiedTypeWrapperType, converter.vivifiedTypeFor(type));
     return synthesizeWrapper(
         type,
         dexClass,
@@ -220,7 +219,7 @@
         DexEncodedField.EMPTY_ARRAY, // No static fields.
         new DexEncodedField[] {wrapperField},
         new DexEncodedMethod[] {
-          synthetizeConstructor(wrapperField.field)
+          synthesizeConstructor(wrapperField.field)
         }, // Conversions methods will be added later.
         virtualMethods,
         factory.getSkipNameValidationForTesting(),
@@ -281,7 +280,8 @@
       DexClass holderClass = appView.definitionFor(dexEncodedMethod.method.holder);
       assert holderClass != null;
       DexMethod methodToInstall =
-          methodWithVivifiedTypeInSignature(dexEncodedMethod.method, wrapperField.field.holder);
+          converter.methodWithVivifiedTypeInSignature(
+              dexEncodedMethod.method, wrapperField.field.holder);
       CfCode cfCode =
           new APIConverterWrapperCfCodeProvider(
                   appView,
@@ -290,6 +290,7 @@
                   converter,
                   holderClass.isInterface())
               .generateCfCode();
+
       DexEncodedMethod newDexEncodedMethod =
           newSynthesizedMethod(methodToInstall, dexEncodedMethod, cfCode);
       generatedMethods.add(newDexEncodedMethod);
@@ -297,25 +298,7 @@
     return generatedMethods.toArray(DexEncodedMethod.EMPTY_ARRAY);
   }
 
-  private DexMethod methodWithVivifiedTypeInSignature(DexMethod originalMethod, DexType holder) {
-    DexType[] newParameters = originalMethod.proto.parameters.values.clone();
-    int index = 0;
-    for (DexType param : originalMethod.proto.parameters.values) {
-      if (appView.rewritePrefix.hasRewrittenType(param)) {
-        newParameters[index] = converter.vivifiedTypeFor(param);
-      }
-      index++;
-    }
-    DexType returnType = originalMethod.proto.returnType;
-    DexType newReturnType =
-        appView.rewritePrefix.hasRewrittenType(returnType)
-            ? converter.vivifiedTypeFor(returnType)
-            : returnType;
-    DexProto newProto = factory.createProto(newReturnType, newParameters);
-    return factory.createMethod(holder, newProto, originalMethod.name);
-  }
-
-  private DexEncodedMethod newSynthesizedMethod(
+  DexEncodedMethod newSynthesizedMethod(
       DexMethod methodToInstall, DexEncodedMethod template, Code code) {
     MethodAccessFlags newFlags = template.accessFlags.copy();
     assert newFlags.isPublic();
@@ -366,7 +349,7 @@
     return implementedMethods;
   }
 
-  private DexEncodedField synthetizeWrappedValueField(DexType holder, DexType fieldType) {
+  private DexEncodedField synthesizeWrappedValueField(DexType holder, DexType fieldType) {
     DexField field = factory.createField(holder, fieldType, factory.wrapperFieldName);
     // Field is package private to be accessible from convert methods without a getter.
     FieldAccessFlags fieldAccessFlags =
@@ -374,7 +357,7 @@
     return new DexEncodedField(field, fieldAccessFlags, DexAnnotationSet.empty(), null);
   }
 
-  private DexEncodedMethod synthetizeConstructor(DexField field) {
+  private DexEncodedMethod synthesizeConstructor(DexField field) {
     DexMethod method =
         factory.createMethod(
             field.holder,
@@ -453,7 +436,7 @@
     Pair<DexType, DexProgramClass> reverse = vivifiedTypeWrappers.get(type);
     assert reverse == null || reverse.getSecond() != null;
     synthesizedClass.addDirectMethod(
-        synthetizeConversionMethod(
+        synthesizeConversionMethod(
             synthesizedClass.type,
             type,
             converter.vivifiedTypeFor(type),
@@ -463,14 +446,14 @@
   private void generateVivifiedTypeConversions(DexType type, DexProgramClass synthesizedClass) {
     Pair<DexType, DexProgramClass> reverse = typeWrappers.get(type);
     synthesizedClass.addDirectMethod(
-        synthetizeConversionMethod(
+        synthesizeConversionMethod(
             synthesizedClass.type,
             converter.vivifiedTypeFor(type),
             type,
             reverse == null ? null : reverse.getSecond()));
   }
 
-  private DexEncodedMethod synthetizeConversionMethod(
+  private DexEncodedMethod synthesizeConversionMethod(
       DexType holder, DexType argType, DexType returnType, DexClass reverseWrapperClassOrNull) {
     DexMethod method =
         factory.createMethod(
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/PrefixRewritingMapper.java b/src/main/java/com/android/tools/r8/ir/desugar/PrefixRewritingMapper.java
index 244dfaf..1549b58 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/PrefixRewritingMapper.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/PrefixRewritingMapper.java
@@ -6,6 +6,7 @@
 
 import com.android.tools.r8.errors.CompilationError;
 import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexProto;
 import com.android.tools.r8.graph.DexString;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.utils.DescriptorUtils;
@@ -29,6 +30,18 @@
     return rewrittenType(type) != null;
   }
 
+  public boolean hasRewrittenTypeInSignature(DexProto proto) {
+    if (hasRewrittenType(proto.returnType)) {
+      return true;
+    }
+    for (DexType paramType : proto.parameters.values) {
+      if (hasRewrittenType(paramType)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
   public abstract boolean isRewriting();
 
   public static class DesugarPrefixRewritingMapper extends PrefixRewritingMapper {
diff --git a/src/main/java/com/android/tools/r8/ir/synthetic/DesugaredLibraryAPIConversionCfCodeProvider.java b/src/main/java/com/android/tools/r8/ir/synthetic/DesugaredLibraryAPIConversionCfCodeProvider.java
index d61ee3e..2a1d196 100644
--- a/src/main/java/com/android/tools/r8/ir/synthetic/DesugaredLibraryAPIConversionCfCodeProvider.java
+++ b/src/main/java/com/android/tools/r8/ir/synthetic/DesugaredLibraryAPIConversionCfCodeProvider.java
@@ -154,7 +154,8 @@
         DexField wrapperField,
         DesugaredLibraryAPIConverter converter,
         boolean itfCall) {
-      super(appView, wrapperField.holder);
+      //  Var wrapperField is null if should forward to receiver.
+      super(appView, wrapperField == null ? forwardMethod.holder : wrapperField.holder);
       this.forwardMethod = forwardMethod;
       this.wrapperField = wrapperField;
       this.converter = converter;
@@ -168,8 +169,13 @@
       // Wrapped value is a type. Method uses vivifiedTypes as external. Forward method should
       // use types.
 
-      instructions.add(new CfLoad(ValueType.fromDexType(wrapperField.holder), 0));
-      instructions.add(new CfFieldInstruction(Opcodes.GETFIELD, wrapperField, wrapperField));
+      // Var wrapperField is null if should forward to receiver.
+      if (wrapperField == null) {
+        instructions.add(new CfLoad(ValueType.fromDexType(forwardMethod.holder), 0));
+      } else {
+        instructions.add(new CfLoad(ValueType.fromDexType(wrapperField.holder), 0));
+        instructions.add(new CfFieldInstruction(Opcodes.GETFIELD, wrapperField, wrapperField));
+      }
       int stackIndex = 1;
       for (DexType param : forwardMethod.proto.parameters.values) {
         instructions.add(new CfLoad(ValueType.fromDexType(param), stackIndex));
diff --git a/src/test/java/com/android/tools/r8/desugar/corelib/conversionTests/CallBackConversionTest.java b/src/test/java/com/android/tools/r8/desugar/corelib/conversionTests/CallBackConversionTest.java
new file mode 100644
index 0000000..f42a03d
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/desugar/corelib/conversionTests/CallBackConversionTest.java
@@ -0,0 +1,91 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.desugar.corelib.conversionTests;
+
+import static junit.framework.TestCase.assertEquals;
+import static junit.framework.TestCase.assertTrue;
+
+import com.android.tools.r8.TestRuntime.DexRuntime;
+import com.android.tools.r8.ToolHelper.DexVm;
+import com.android.tools.r8.utils.AndroidApiLevel;
+import com.android.tools.r8.utils.StringUtils;
+import com.android.tools.r8.utils.codeinspector.FoundMethodSubject;
+import java.nio.file.Path;
+import java.util.List;
+import java.util.function.Consumer;
+import org.junit.Test;
+
+public class CallBackConversionTest extends APIConversionTestBase {
+
+  @Test
+  public void testCallBack() throws Exception {
+    Path customLib = testForD8().addProgramClasses(CustomLibClass.class).compile().writeToZip();
+    testForD8()
+        .setMinApi(AndroidApiLevel.B)
+        .addProgramClasses(Impl.class)
+        .addLibraryClasses(CustomLibClass.class)
+        .enableCoreLibraryDesugaring(AndroidApiLevel.B)
+        .compile()
+        .inspect(
+            i -> {
+              // foo(j$) and foo(java)
+              List<FoundMethodSubject> virtualMethods = i.clazz(Impl.class).virtualMethods();
+              assertEquals(2, virtualMethods.size());
+              assertTrue(
+                  virtualMethods.stream()
+                      .anyMatch(
+                          m ->
+                              m.getMethod()
+                                  .method
+                                  .proto
+                                  .parameters
+                                  .values[0]
+                                  .toString()
+                                  .equals("j$.util.function.Consumer")));
+              assertTrue(
+                  virtualMethods.stream()
+                      .anyMatch(
+                          m ->
+                              m.getMethod()
+                                  .method
+                                  .proto
+                                  .parameters
+                                  .values[0]
+                                  .toString()
+                                  .equals("java.util.function.Consumer")));
+            })
+        .addDesugaredCoreLibraryRunClassPath(
+            this::buildDesugaredLibraryWithConversionExtension, AndroidApiLevel.B)
+        .addRunClasspathFiles(customLib)
+        .run(new DexRuntime(DexVm.ART_9_0_0_HOST), Impl.class)
+        .assertSuccessWithOutput(StringUtils.lines("0", "1", "0", "1"));
+  }
+
+  static class Impl extends CustomLibClass {
+
+    public int foo(Consumer<Object> o) {
+      o.accept(0);
+      return 1;
+    }
+
+    public static void main(String[] args) {
+      Impl impl = new Impl();
+      // Call foo through java parameter.
+      System.out.println(CustomLibClass.callFoo(impl, System.out::println));
+      // Call foo through j$ parameter.
+      System.out.println(impl.foo(System.out::println));
+    }
+  }
+
+  abstract static class CustomLibClass {
+
+    public abstract int foo(Consumer<Object> consumer);
+
+    @SuppressWarnings({"UnusedReturnValue", "WeakerAccess"})
+    public static int callFoo(CustomLibClass object, Consumer<Object> consumer) {
+      return object.foo(consumer);
+    }
+  }
+}