Desugared library keep rules generation

- The application writer now also generates
  keep rules for the desugared library used
- A temporary testing consumer is available
  to test and use the keep rules

Bug: 134732760
Change-Id: Idc3a919cc7a819ffb86af30852bd462aaccb4c1c
diff --git a/src/main/java/com/android/tools/r8/dex/CodeToKeep.java b/src/main/java/com/android/tools/r8/dex/CodeToKeep.java
new file mode 100644
index 0000000..c836e98
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/dex/CodeToKeep.java
@@ -0,0 +1,155 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.dex;
+
+import com.android.tools.r8.errors.Unreachable;
+import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexString;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.naming.NamingLens;
+import com.android.tools.r8.utils.DescriptorUtils;
+import com.android.tools.r8.utils.InternalOptions;
+import com.android.tools.r8.utils.Pair;
+import com.google.common.collect.Sets;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+
+public abstract class CodeToKeep {
+
+  static CodeToKeep createCodeToKeep(InternalOptions options, NamingLens namingLens) {
+    if (!namingLens.hasPrefixRewritingLogic() || options.coreLibraryCompilation) {
+      return new NopCodeToKeep();
+    }
+    return new DesugaredLibraryCodeToKeep(namingLens);
+  }
+
+  abstract void recordMethod(DexMethod method);
+
+  abstract void recordField(DexField field);
+
+  abstract boolean isNop();
+
+  abstract void generateKeepRules(InternalOptions options);
+
+  public static class DesugaredLibraryCodeToKeep extends CodeToKeep {
+
+    private final NamingLens namingLens;
+    private final Map<DexType, Pair<Set<DexField>, Set<DexMethod>>> toKeep =
+        new ConcurrentHashMap<>();
+
+    public DesugaredLibraryCodeToKeep(NamingLens namingLens) {
+      this.namingLens = namingLens;
+    }
+
+    private boolean shouldKeep(DexType type) {
+      return namingLens.prefixRewrittenType(type) != null;
+    }
+
+    @Override
+    void recordMethod(DexMethod method) {
+      if (shouldKeep(method.holder)) {
+        keepClass(method.holder);
+        toKeep.get(method.holder).getSecond().add(method);
+      }
+      if (shouldKeep(method.proto.returnType)) {
+        keepClass(method.proto.returnType);
+      }
+      for (DexType type : method.proto.parameters.values) {
+        if (shouldKeep(type)) {
+          keepClass(type);
+        }
+      }
+    }
+
+    @Override
+    void recordField(DexField field) {
+      if (shouldKeep(field.holder)) {
+        keepClass(field.holder);
+        toKeep.get(field.holder).getFirst().add(field);
+      }
+      if (shouldKeep(field.type)) {
+        keepClass(field.type);
+      }
+    }
+
+    private void keepClass(DexType type) {
+      toKeep.putIfAbsent(
+          type, new Pair<>(Sets.newConcurrentHashSet(), Sets.newConcurrentHashSet()));
+    }
+
+    @Override
+    boolean isNop() {
+      return false;
+    }
+
+    private String convertType(DexType type) {
+      DexString rewriteType = namingLens.prefixRewrittenType(type);
+      DexString descriptor = rewriteType != null ? rewriteType : type.descriptor;
+      return DescriptorUtils.descriptorToJavaType(descriptor.toString());
+    }
+
+    @Override
+    void generateKeepRules(InternalOptions options) {
+      // TODO(b/134734081): Stream the consumer instead of building the String.
+      StringBuilder sb = new StringBuilder();
+      String cr = System.lineSeparator();
+      for (DexType type : toKeep.keySet()) {
+        Set<DexField> fieldsToKeep = toKeep.get(type).getFirst();
+        Set<DexMethod> methodsToKeep = toKeep.get(type).getSecond();
+        sb.append("-keep class ").append(convertType(type));
+        if (fieldsToKeep.isEmpty() && methodsToKeep.isEmpty()) {
+          sb.append(cr);
+          continue;
+        }
+        sb.append(" {").append(cr);
+        for (DexField field : fieldsToKeep) {
+          sb.append("    ")
+              .append(convertType(type))
+              .append(" ")
+              .append(field.name)
+              .append(";")
+              .append(cr);
+        }
+        for (DexMethod method : methodsToKeep) {
+          sb.append("    ")
+              .append(convertType(method.proto.returnType))
+              .append(" ")
+              .append(method.name)
+              .append("(");
+          for (int i = 0; i < method.getArity(); i++) {
+            if (i != 0) {
+              sb.append(", ");
+            }
+            sb.append(convertType(method.proto.parameters.values[i]));
+          }
+          sb.append(");").append(cr);
+        }
+        sb.append("}").append(cr);
+      }
+      options.testing.desugaredLibraryKeepRuleConsumer.accept(sb.toString(), options.reporter);
+    }
+  }
+
+  public static class NopCodeToKeep extends CodeToKeep {
+
+    @Override
+    void recordMethod(DexMethod method) {}
+
+    @Override
+    void recordField(DexField field) {}
+
+    @Override
+    boolean isNop() {
+      return true;
+    }
+
+    @Override
+    void generateKeepRules(InternalOptions options) {
+      throw new Unreachable("Has no keep rules to generate");
+    }
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/dex/DexOutputBuffer.java b/src/main/java/com/android/tools/r8/dex/DexOutputBuffer.java
index 2d1da51..3075b16 100644
--- a/src/main/java/com/android/tools/r8/dex/DexOutputBuffer.java
+++ b/src/main/java/com/android/tools/r8/dex/DexOutputBuffer.java
@@ -6,6 +6,8 @@
 import com.android.tools.r8.ByteBufferProvider;
 import com.android.tools.r8.code.Instruction;
 import com.android.tools.r8.errors.CompilationError;
+import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.ObjectToOffsetMapping;
 import com.android.tools.r8.utils.EncodedValueUtils;
 import com.android.tools.r8.utils.LebUtils;
@@ -90,7 +92,8 @@
     return EncodedValueUtils.putUnsigned(this, value, expectedSize);
   }
 
-  public void putInstructions(Instruction[] insns, ObjectToOffsetMapping mapping) {
+  public void putInstructions(
+      Instruction[] insns, ObjectToOffsetMapping mapping, CodeToKeep desugaredLibraryCodeToKeep) {
     int size = 0;
     for (Instruction insn : insns) {
       size += insn.getSize();
@@ -99,7 +102,16 @@
     assert byteBuffer.position() % 2 == 0;
     ShortBuffer shortBuffer = byteBuffer.asShortBuffer();
     for (int i = 0; i < insns.length; i++) {
-      insns[i].write(shortBuffer, mapping);
+      Instruction insn = insns[i];
+      DexMethod method = insn.getMethod();
+      DexField field = insn.getField();
+      if (field != null) {
+        assert method == null;
+        desugaredLibraryCodeToKeep.recordField(field);
+      } else if (method != null) {
+        desugaredLibraryCodeToKeep.recordMethod(method);
+      }
+      insn.write(shortBuffer, mapping);
     }
     byteBuffer.position(byteBuffer.position() + shortBuffer.position() * Short.BYTES);
   }
diff --git a/src/main/java/com/android/tools/r8/dex/FileWriter.java b/src/main/java/com/android/tools/r8/dex/FileWriter.java
index 8e0593d..34b12ea 100644
--- a/src/main/java/com/android/tools/r8/dex/FileWriter.java
+++ b/src/main/java/com/android/tools/r8/dex/FileWriter.java
@@ -75,6 +75,7 @@
 
   /** Simple pair of a byte buffer and its written length. */
   public static class ByteBufferResult {
+
     // Ownership of the buffer is transferred to the receiver of this result structure.
     public final CompatByteBuffer buffer;
     public final int length;
@@ -92,6 +93,7 @@
   private final NamingLens namingLens;
   private final DexOutputBuffer dest;
   private final MixedSectionOffsets mixedSectionOffsets;
+  private final CodeToKeep desugaredLibraryCodeToKeep;
 
   public FileWriter(
       ByteBufferProvider provider,
@@ -107,6 +109,7 @@
     this.namingLens = namingLens;
     this.dest = new DexOutputBuffer(provider);
     this.mixedSectionOffsets = new MixedSectionOffsets(options, codeMapping);
+    this.desugaredLibraryCodeToKeep = CodeToKeep.createCodeToKeep(options, namingLens);
   }
 
   public static void writeEncodedAnnotation(
@@ -217,6 +220,13 @@
     writeSignature(layout);
     writeChecksum(layout);
 
+    // A consumer can manage the generated keep rules (testing only).
+    if (options.testing.desugaredLibraryKeepRuleConsumer != null
+        && !desugaredLibraryCodeToKeep.isNop()) {
+      assert !options.coreLibraryCompilation;
+      desugaredLibraryCodeToKeep.generateKeepRules(options);
+    }
+
     // Wrap backing buffer with actual length.
     return new ByteBufferResult(dest.stealByteBuffer(), layout.getEndOfFile());
   }
@@ -435,7 +445,8 @@
     dest.putInt(
         clazz.sourceFile == null ? Constants.NO_INDEX : mapping.getOffsetFor(clazz.sourceFile));
     dest.putInt(mixedSectionOffsets.getOffsetForAnnotationsDirectory(clazz));
-    dest.putInt(clazz.hasMethodsOrFields() ? mixedSectionOffsets.getOffsetFor(clazz) : Constants.NO_OFFSET);
+    dest.putInt(
+        clazz.hasMethodsOrFields() ? mixedSectionOffsets.getOffsetFor(clazz) : Constants.NO_OFFSET);
     dest.putInt(mixedSectionOffsets.getOffsetFor(clazz.getStaticValues()));
   }
 
@@ -456,7 +467,7 @@
     int insnSizeOffset = dest.position();
     dest.forward(4);
     // Write instruction stream.
-    dest.putInstructions(code.instructions, mapping);
+    dest.putInstructions(code.instructions, mapping, desugaredLibraryCodeToKeep);
     // Compute size and do the backward/forward dance to write the size at the beginning.
     int insnSize = dest.position() - insnSizeOffset - 4;
     dest.rewind(insnSize + 4);
@@ -582,6 +593,7 @@
       dest.putUleb128(nextOffset - currentOffset);
       currentOffset = nextOffset;
       dest.putUleb128(field.accessFlags.getAsDexAccessFlags());
+      desugaredLibraryCodeToKeep.recordField(field.field);
     }
   }
 
@@ -595,6 +607,7 @@
       currentOffset = nextOffset;
       dest.putUleb128(method.accessFlags.getAsDexAccessFlags());
       DexCode code = codeMapping.getCode(method);
+      desugaredLibraryCodeToKeep.recordMethod(method.method);
       if (code == null) {
         assert method.shouldNotHaveCode();
         dest.putUleb128(0);
@@ -1000,9 +1013,8 @@
 
   /**
    * Encapsulates information on the offsets of items in the sections of the mixed data part of the
-   * DEX file.
-   * Initially, items are collected using the {@link MixedSectionCollection} traversal and all
-   * offsets are unset. When writing a section, the offsets of the written items are stored.
+   * DEX file. Initially, items are collected using the {@link MixedSectionCollection} traversal and
+   * all offsets are unset. When writing a section, the offsets of the written items are stored.
    * These offsets are then used to resolve cross-references between items from different sections
    * into a file offset.
    */
diff --git a/src/main/java/com/android/tools/r8/naming/NamingLens.java b/src/main/java/com/android/tools/r8/naming/NamingLens.java
index 54a7e31..2dbf703 100644
--- a/src/main/java/com/android/tools/r8/naming/NamingLens.java
+++ b/src/main/java/com/android/tools/r8/naming/NamingLens.java
@@ -99,6 +99,14 @@
     return dexItemFactory.createType(lookupDescriptor(type));
   }
 
+  public boolean hasPrefixRewritingLogic() {
+    return false;
+  }
+
+  public DexString prefixRewrittenType(DexType type) {
+    return null;
+  }
+
   public static NamingLens getIdentityLens() {
     return new IdentityLens();
   }
diff --git a/src/main/java/com/android/tools/r8/naming/PrefixRewritingNamingLens.java b/src/main/java/com/android/tools/r8/naming/PrefixRewritingNamingLens.java
index f8e2e4e..ed5c1bc 100644
--- a/src/main/java/com/android/tools/r8/naming/PrefixRewritingNamingLens.java
+++ b/src/main/java/com/android/tools/r8/naming/PrefixRewritingNamingLens.java
@@ -78,6 +78,16 @@
   }
 
   @Override
+  public boolean hasPrefixRewritingLogic() {
+    return true;
+  }
+
+  @Override
+  public DexString prefixRewrittenType(DexType type) {
+    return classRenaming.get(type);
+  }
+
+  @Override
   public DexString lookupDescriptor(DexType type) {
     return classRenaming.getOrDefault(type, type.descriptor);
   }
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index e55b769..c7f8199 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -945,7 +945,7 @@
     public boolean disallowLoadStoreOptimization = false;
     public boolean enableNarrowingChecksInD8 = false;
     public Consumer<IRCode> irModifier = null;
-
+    public StringConsumer desugaredLibraryKeepRuleConsumer = null;
     // TODO(b/129458850) When fixed, remove this and change all usages to "true".
     public boolean enableStatefulLambdaCreateInstanceMethod = false;
 
diff --git a/src/test/java/com/android/tools/r8/desugar/corelib/EmulateLibraryInterfaceTest.java b/src/test/java/com/android/tools/r8/desugar/corelib/EmulateLibraryInterfaceTest.java
index 0281323..ee5a8d0 100644
--- a/src/test/java/com/android/tools/r8/desugar/corelib/EmulateLibraryInterfaceTest.java
+++ b/src/test/java/com/android/tools/r8/desugar/corelib/EmulateLibraryInterfaceTest.java
@@ -20,6 +20,7 @@
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.ir.desugar.InterfaceMethodRewriter;
 import com.android.tools.r8.utils.AndroidApiLevel;
+import com.android.tools.r8.utils.StringUtils;
 import com.android.tools.r8.utils.codeinspector.ClassSubject;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
 import com.android.tools.r8.utils.codeinspector.FoundClassSubject;
@@ -100,55 +101,52 @@
             .filter(instr -> instr.isInvokeInterface() || instr.isInvokeStatic())
             .collect(Collectors.toList());
     assertEquals(23, invokes.size());
-    assertTrue(invokes.get(0).isInvokeStatic());
-    assertTrue(invokes.get(0).toString().contains("Set$-EL;->spliterator"));
-    assertTrue(invokes.get(1).isInvokeStatic());
-    assertTrue(invokes.get(1).toString().contains("List$-EL;->spliterator"));
-    assertTrue(invokes.get(2).isInvokeStatic());
-    assertTrue(invokes.get(2).toString().contains("Collection$-EL;->stream"));
-    assertTrue(invokes.get(3).isInvokeInterface());
-    assertTrue(invokes.get(3).toString().contains("Set;->iterator"));
-    assertTrue(invokes.get(4).isInvokeStatic());
-    assertTrue(invokes.get(4).toString().contains("Collection$-EL;->stream"));
-    assertTrue(invokes.get(5).isInvokeStatic());
-    assertTrue(invokes.get(5).toString().contains("DesugarLinkedHashSet;->spliterator"));
-    assertTrue(invokes.get(9).isInvokeInterface());
-    assertTrue(invokes.get(9).toString().contains("Iterator;->remove"));
-    assertTrue(invokes.get(10).isInvokeStatic());
-    assertTrue(invokes.get(10).toString().contains("DesugarArrays;->spliterator"));
-    assertTrue(invokes.get(11).isInvokeStatic());
-    assertTrue(invokes.get(11).toString().contains("DesugarArrays;->spliterator"));
-    assertTrue(invokes.get(12).isInvokeStatic());
-    assertTrue(invokes.get(12).toString().contains("DesugarArrays;->stream"));
-    assertTrue(invokes.get(13).isInvokeStatic());
-    assertTrue(invokes.get(13).toString().contains("DesugarArrays;->stream"));
-    assertTrue(invokes.get(14).isInvokeStatic());
-    assertTrue(invokes.get(14).toString().contains("Collection$-EL;->stream"));
-    assertTrue(invokes.get(15).isInvokeStatic());
-    assertTrue(invokes.get(15).toString().contains("IntStream$-CC;->range"));
-    assertTrue(invokes.get(17).isInvokeStatic());
-    assertTrue(invokes.get(17).toString().contains("Comparator$-CC;->comparingInt"));
-    assertTrue(invokes.get(18).isInvokeStatic());
-    assertTrue(invokes.get(18).toString().contains("List$-EL;->sort"));
-    assertTrue(invokes.get(20).isInvokeStatic());
-    assertTrue(invokes.get(20).toString().contains("Comparator$-CC;->comparingInt"));
-    assertTrue(invokes.get(21).isInvokeStatic());
-    assertTrue(invokes.get(21).toString().contains("List$-EL;->sort"));
-    assertTrue(invokes.get(22).isInvokeStatic());
-    assertTrue(invokes.get(22).toString().contains("Collection$-EL;->stream"));
+    assertInvokeStaticMatching(invokes, 0, "Set$-EL;->spliterator");
+    assertInvokeStaticMatching(invokes, 1, "List$-EL;->spliterator");
+    assertInvokeStaticMatching(invokes, 2, "Collection$-EL;->stream");
+    assertInvokeInterfaceMatching(invokes, 3, "Set;->iterator");
+    assertInvokeStaticMatching(invokes, 4, "Collection$-EL;->stream");
+    assertInvokeStaticMatching(invokes, 5, "DesugarLinkedHashSet;->spliterator");
+    assertInvokeInterfaceMatching(invokes, 9, "Iterator;->remove");
+    assertInvokeStaticMatching(invokes, 10, "DesugarArrays;->spliterator");
+    assertInvokeStaticMatching(invokes, 11, "DesugarArrays;->spliterator");
+    assertInvokeStaticMatching(invokes, 12, "DesugarArrays;->stream");
+    assertInvokeStaticMatching(invokes, 13, "DesugarArrays;->stream");
+    assertInvokeStaticMatching(invokes, 14, "Collection$-EL;->stream");
+    assertInvokeStaticMatching(invokes, 15, "IntStream$-CC;->range");
+    assertInvokeStaticMatching(invokes, 17, "Comparator$-CC;->comparingInt");
+    assertInvokeStaticMatching(invokes, 18, "List$-EL;->sort");
+    assertInvokeStaticMatching(invokes, 20, "Comparator$-CC;->comparingInt");
+    assertInvokeStaticMatching(invokes, 21, "List$-EL;->sort");
+    assertInvokeStaticMatching(invokes, 22, "Collection$-EL;->stream");
     // TODO (b/134732760): Support Java 9 Stream APIs
     // assertTrue(invokes.get(17).isInvokeStatic());
     // assertTrue(invokes.get(17).toString().contains("Stream$-CC;->iterate"));
   }
 
+  private void assertInvokeInterfaceMatching(List<InstructionSubject> invokes, int i, String s) {
+    assertTrue(invokes.get(i).isInvokeInterface());
+    assertTrue(invokes.get(i).toString().contains(s));
+  }
+
+  private void assertInvokeStaticMatching(List<InstructionSubject> invokes, int i, String s) {
+    assertTrue(invokes.get(i).isInvokeStatic());
+    assertTrue(invokes.get(i).toString().contains(s));
+  }
+
   @Test
   public void testProgram() throws Exception {
     Assume.assumeTrue("No desugaring for high API levels", requiresCoreLibDesugaring(parameters));
+    String[] keepRulesHolder = new String[] {""};
     D8TestRunResult d8TestRunResult =
         testForD8()
             .addProgramFiles(Paths.get(ToolHelper.EXAMPLES_JAVA9_BUILD_DIR + "stream.jar"))
             .addLibraryFiles(ToolHelper.getAndroidJar(AndroidApiLevel.P))
             .setMinApi(parameters.getApiLevel())
+            .addOptionsModification(
+                options ->
+                    options.testing.desugaredLibraryKeepRuleConsumer =
+                        (string, handler) -> keepRulesHolder[0] += string)
             .enableCoreLibraryDesugaring()
             .compile()
             .inspect(this::checkRewrittenInvokes)
@@ -156,6 +154,7 @@
             .run(parameters.getRuntime(), "stream.TestClass")
             .assertSuccess();
     assertLines2By2Correct(d8TestRunResult.getStdOut());
+    assertGeneratedKeepRulesAreCorrect(keepRulesHolder[0]);
     String stdErr = d8TestRunResult.getStdErr();
     if (parameters.getRuntime().asDex().getVm().isOlderThanOrEqual(DexVm.ART_4_4_4_HOST)) {
       // Flaky: There might be a missing method on lambda deserialization.
@@ -166,4 +165,39 @@
       assertFalse(stdErr.contains("Could not find method"));
     }
   }
+
+  private void assertGeneratedKeepRulesAreCorrect(String keepRules) {
+    String expectedResult =
+        StringUtils.lines(
+            "-keep class j$.util.List$-EL {",
+            "    void sort(java.util.List, java.util.Comparator);",
+            "    j$.util.Spliterator spliterator(java.util.List);",
+            "}",
+            "-keep class j$.util.Collection$-EL {",
+            "    j$.util.stream.Stream stream(java.util.Collection);",
+            "}",
+            "-keep class j$.util.stream.IntStream$-CC {",
+            "    j$.util.stream.IntStream range(int, int);",
+            "}",
+            "-keep class j$.util.Comparator$-CC {",
+            "    java.util.Comparator comparingInt(j$.util.function.ToIntFunction);",
+            "}",
+            "-keep class j$.util.Set$-EL {",
+            "    j$.util.Spliterator spliterator(java.util.Set);",
+            "}",
+            "-keep class j$.util.DesugarArrays {",
+            "    j$.util.Spliterator spliterator(java.lang.Object[]);",
+            "    j$.util.stream.Stream stream(java.lang.Object[], int, int);",
+            "    j$.util.stream.Stream stream(java.lang.Object[]);",
+            "    j$.util.Spliterator spliterator(java.lang.Object[], int, int);",
+            "}",
+            "-keep class j$.util.stream.IntStream",
+            "-keep class j$.util.DesugarLinkedHashSet {",
+            "    j$.util.Spliterator spliterator(java.util.LinkedHashSet);",
+            "}",
+            "-keep class j$.util.stream.Stream",
+            "-keep class j$.util.Spliterator",
+            "-keep class j$.util.function.ToIntFunction");
+    assertEquals(expectedResult, keepRules);
+  }
 }