Remove parameter annotations in uninstantiated type optimization

Bug: 131718819, 131735725
Change-Id: I4f51e999e1332afd99a96a49534aa62ccec39790
diff --git a/src/main/java/com/android/tools/r8/graph/GraphLense.java b/src/main/java/com/android/tools/r8/graph/GraphLense.java
index 95ba121..5de4ae5 100644
--- a/src/main/java/com/android/tools/r8/graph/GraphLense.java
+++ b/src/main/java/com/android/tools/r8/graph/GraphLense.java
@@ -7,6 +7,7 @@
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.Invoke.Type;
 import com.android.tools.r8.ir.code.Position;
+import com.android.tools.r8.utils.BooleanUtils;
 import com.android.tools.r8.utils.IteratorUtils;
 import com.google.common.collect.BiMap;
 import com.google.common.collect.HashBiMap;
@@ -29,6 +30,7 @@
 import java.util.Set;
 import java.util.SortedSet;
 import java.util.TreeSet;
+import java.util.function.Consumer;
 import java.util.function.Function;
 
 /**
@@ -217,6 +219,18 @@
         }
         return new RemovedArgumentsInfo(newRemovedArguments);
       }
+
+      public Consumer<DexEncodedMethod.Builder> createParameterAnnotationsRemover(
+          DexEncodedMethod method) {
+        if (numberOfRemovedArguments() > 0 && !method.parameterAnnotationsList.isEmpty()) {
+          return builder -> {
+            int firstArgumentIndex = BooleanUtils.intValue(!method.isStatic());
+            builder.removeParameterAnnotations(
+                oldIndex -> isArgumentRemoved(oldIndex + firstArgumentIndex));
+          };
+        }
+        return null;
+      }
     }
 
     private static final RewrittenPrototypeDescription none = new RewrittenPrototypeDescription();
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/UninstantiatedTypeOptimization.java b/src/main/java/com/android/tools/r8/ir/optimize/UninstantiatedTypeOptimization.java
index 8f2ef40..0dc5eae 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/UninstantiatedTypeOptimization.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/UninstantiatedTypeOptimization.java
@@ -207,7 +207,11 @@
         // TODO(b/110806787): Can be extended to handle collisions by renaming the given
         // method.
         if (usedSignatures.add(wrapper)) {
-          clazz.setDirectMethod(i, encodedMethod.toTypeSubstitutedMethod(newMethod));
+          clazz.setDirectMethod(
+              i,
+              encodedMethod.toTypeSubstitutedMethod(
+                  newMethod,
+                  removedArgumentsInfo.createParameterAnnotationsRemover(encodedMethod)));
           methodMapping.put(method, newMethod);
           if (removedArgumentsInfo.hasRemovedArguments()) {
             removedArgumentsInfoPerMethod.put(newMethod, removedArgumentsInfo);
@@ -228,6 +232,7 @@
       DexMethod method = encodedMethod.method;
       RewrittenPrototypeDescription prototypeChanges =
           getPrototypeChanges(encodedMethod, DISALLOW_ARGUMENT_REMOVAL);
+      RemovedArgumentsInfo removedArgumentsInfo = prototypeChanges.getRemovedArgumentsInfo();
       DexMethod newMethod = getNewMethodSignature(encodedMethod, prototypeChanges);
       if (newMethod != method) {
         Wrapper<DexMethod> wrapper = equivalence.wrap(newMethod);
@@ -241,7 +246,11 @@
           boolean signatureIsAvailable = usedSignatures.add(wrapper);
           assert signatureIsAvailable;
 
-          clazz.setVirtualMethod(i, encodedMethod.toTypeSubstitutedMethod(newMethod));
+          clazz.setVirtualMethod(
+              i,
+              encodedMethod.toTypeSubstitutedMethod(
+                  newMethod,
+                  removedArgumentsInfo.createParameterAnnotationsRemover(encodedMethod)));
           methodMapping.put(method, newMethod);
         }
       }
@@ -251,6 +260,7 @@
       DexMethod method = encodedMethod.method;
       RewrittenPrototypeDescription prototypeChanges =
           getPrototypeChanges(encodedMethod, DISALLOW_ARGUMENT_REMOVAL);
+      RemovedArgumentsInfo removedArgumentsInfo = prototypeChanges.getRemovedArgumentsInfo();
       DexMethod newMethod = getNewMethodSignature(encodedMethod, prototypeChanges);
       if (newMethod != method) {
         Wrapper<DexMethod> wrapper = equivalence.wrap(newMethod);
@@ -261,7 +271,11 @@
         if (!methodPool.hasSeen(wrapper) && usedSignatures.add(wrapper)) {
           methodPool.seen(wrapper);
 
-          clazz.setVirtualMethod(i, encodedMethod.toTypeSubstitutedMethod(newMethod));
+          clazz.setVirtualMethod(
+              i,
+              encodedMethod.toTypeSubstitutedMethod(
+                  newMethod,
+                  removedArgumentsInfo.createParameterAnnotationsRemover(encodedMethod)));
           methodMapping.put(method, newMethod);
 
           boolean added =
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/UnusedArgumentsCollector.java b/src/main/java/com/android/tools/r8/ir/optimize/UnusedArgumentsCollector.java
index 670d35a..deeaf72 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/UnusedArgumentsCollector.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/UnusedArgumentsCollector.java
@@ -20,7 +20,6 @@
 import com.android.tools.r8.graph.GraphLense.RewrittenPrototypeDescription.RemovedArgumentsInfo;
 import com.android.tools.r8.ir.optimize.MemberPoolCollection.MemberPool;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
-import com.android.tools.r8.utils.BooleanUtils;
 import com.android.tools.r8.utils.MethodSignatureEquivalence;
 import com.android.tools.r8.utils.StringUtils;
 import com.android.tools.r8.utils.ThreadUtils;
@@ -40,7 +39,6 @@
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
-import java.util.function.Consumer;
 import java.util.stream.Collectors;
 
 public class UnusedArgumentsCollector {
@@ -125,18 +123,6 @@
     return appView.graphLense();
   }
 
-  public static Consumer<DexEncodedMethod.Builder> createParameterAnnotationsRemover(
-      DexEncodedMethod method, RemovedArgumentsInfo unused) {
-    if (unused.numberOfRemovedArguments() > 0 && !method.parameterAnnotationsList.isEmpty()) {
-      return builder -> {
-        int firstArgumentIndex = BooleanUtils.intValue(!method.isStatic());
-        builder.removeParameterAnnotations(
-            oldIndex -> unused.isArgumentRemoved(oldIndex + firstArgumentIndex));
-      };
-    }
-    return null;
-  }
-
   private class UsedSignatures {
 
     private final MethodSignatureEquivalence equivalence = MethodSignatureEquivalence.get();
@@ -183,7 +169,7 @@
       markSignatureAsUsed(newSignature);
 
       return method.toTypeSubstitutedMethod(
-          newSignature, createParameterAnnotationsRemover(method, unused));
+          newSignature, unused.createParameterAnnotationsRemover(method));
     }
   }
 
@@ -220,7 +206,7 @@
         DexEncodedMethod method, DexMethod newSignature, RemovedArgumentsInfo unused) {
       methodPool.seen(equivalence.wrap(newSignature));
       return method.toTypeSubstitutedMethod(
-          newSignature, createParameterAnnotationsRemover(method, unused));
+          newSignature, unused.createParameterAnnotationsRemover(method));
     }
   }
 
diff --git a/src/test/java/com/android/tools/r8/R8TestBuilder.java b/src/test/java/com/android/tools/r8/R8TestBuilder.java
index 14d6d72..177bcf1 100644
--- a/src/test/java/com/android/tools/r8/R8TestBuilder.java
+++ b/src/test/java/com/android/tools/r8/R8TestBuilder.java
@@ -223,10 +223,18 @@
   }
 
   public T enableConstantArgumentAnnotations() {
-    if (!enableConstantArgumentAnnotations) {
-      enableConstantArgumentAnnotations = true;
-      addInternalKeepRules(
-          "-keepconstantarguments class * { @com.android.tools.r8.KeepConstantArguments *; }");
+    return enableConstantArgumentAnnotations(true);
+  }
+
+  public T enableConstantArgumentAnnotations(boolean value) {
+    if (value) {
+      if (!enableConstantArgumentAnnotations) {
+        enableConstantArgumentAnnotations = true;
+        addInternalKeepRules(
+            "-keepconstantarguments class * { @com.android.tools.r8.KeepConstantArguments *; }");
+      }
+    } else {
+      assert !enableConstantArgumentAnnotations;
     }
     return self();
   }
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/uninstantiatedtypes/UninstantiatedAnnotatedArgumentsTest.java b/src/test/java/com/android/tools/r8/ir/optimize/uninstantiatedtypes/UninstantiatedAnnotatedArgumentsTest.java
new file mode 100644
index 0000000..c92103e
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/uninstantiatedtypes/UninstantiatedAnnotatedArgumentsTest.java
@@ -0,0 +1,234 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.uninstantiatedtypes;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.KeepConstantArguments;
+import com.android.tools.r8.KeepUnusedArguments;
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.errors.Unreachable;
+import com.android.tools.r8.graph.DexAnnotation;
+import com.android.tools.r8.graph.DexAnnotationSet;
+import com.android.tools.r8.utils.BooleanUtils;
+import com.android.tools.r8.utils.StringUtils;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import com.google.common.collect.ImmutableList;
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class UninstantiatedAnnotatedArgumentsTest extends TestBase {
+
+  private final boolean keepUninstantiatedArguments;
+  private final TestParameters parameters;
+
+  @Parameters(name = "{1}, keep uninstantiated arguments: {0}")
+  public static List<Object[]> params() {
+    return buildParameters(BooleanUtils.values(), getTestParameters().withAllRuntimes().build());
+  }
+
+  public UninstantiatedAnnotatedArgumentsTest(
+      boolean keepUninstantiatedArguments, TestParameters parameters) {
+    this.keepUninstantiatedArguments = keepUninstantiatedArguments;
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(UninstantiatedAnnotatedArgumentsTest.class)
+        .addKeepMainRule(TestClass.class)
+        .addKeepClassRules(Instantiated.class, Uninstantiated.class)
+        .addKeepAttributes("RuntimeVisibleParameterAnnotations")
+        .enableClassInliningAnnotations()
+        .enableConstantArgumentAnnotations(keepUninstantiatedArguments)
+        .enableInliningAnnotations()
+        .enableUnusedArgumentAnnotations()
+        // TODO(b/123060011): Mapping not working in presence of argument removal.
+        .minification(keepUninstantiatedArguments)
+        .setMinApi(parameters.getRuntime())
+        .compile()
+        .inspect(this::verifyOutput)
+        .run(parameters.getRuntime(), TestClass.class)
+        .assertSuccessWithOutput(StringUtils.times(StringUtils.lines("Hello world!"), 6));
+  }
+
+  private void verifyOutput(CodeInspector inspector) {
+    ClassSubject testClassSubject = inspector.clazz(TestClass.class);
+    assertThat(testClassSubject, isPresent());
+
+    ClassSubject instantiatedClassSubject = inspector.clazz(Instantiated.class);
+    assertThat(instantiatedClassSubject, isPresent());
+
+    ClassSubject uninstantiatedClassSubject = inspector.clazz(Uninstantiated.class);
+    assertThat(uninstantiatedClassSubject, isPresent());
+
+    List<MethodSubject> methodSubjects =
+        ImmutableList.of(
+            testClassSubject.uniqueMethodWithName("testRemoveStaticFromStart"),
+            testClassSubject.uniqueMethodWithName("testRemoveStaticFromMiddle"),
+            testClassSubject.uniqueMethodWithName("testRemoveStaticFromEnd"),
+            testClassSubject.uniqueMethodWithName("testRemoveVirtualFromStart"),
+            testClassSubject.uniqueMethodWithName("testRemoveVirtualFromMiddle"),
+            testClassSubject.uniqueMethodWithName("testRemoveVirtualFromEnd"));
+
+    for (MethodSubject methodSubject : methodSubjects) {
+      assertThat(methodSubject, isPresent());
+
+      // TODO(b/131735725): Should also remove arguments from the virtual methods.
+      if (keepUninstantiatedArguments || methodSubject.getOriginalName().contains("Virtual")) {
+        assertEquals(3, methodSubject.getMethod().method.proto.parameters.size());
+        assertEquals(3, methodSubject.getMethod().parameterAnnotationsList.size());
+
+        for (int i = 0; i < 3; ++i) {
+          DexAnnotationSet annotationSet =
+              methodSubject.getMethod().parameterAnnotationsList.get(i);
+          assertEquals(1, annotationSet.annotations.length);
+
+          DexAnnotation annotation = annotationSet.annotations[0];
+          if (i == getPositionOfUnusedArgument(methodSubject)) {
+            assertEquals(
+                uninstantiatedClassSubject.getFinalName(),
+                annotation.annotation.type.toSourceString());
+          } else {
+            assertEquals(
+                instantiatedClassSubject.getFinalName(),
+                annotation.annotation.type.toSourceString());
+          }
+        }
+      } else {
+        assertEquals(2, methodSubject.getMethod().method.proto.parameters.size());
+        assertEquals(2, methodSubject.getMethod().parameterAnnotationsList.size());
+
+        for (int i = 0; i < 2; ++i) {
+          DexAnnotationSet annotationSet =
+              methodSubject.getMethod().parameterAnnotationsList.get(i);
+          assertEquals(1, annotationSet.annotations.length);
+
+          DexAnnotation annotation = annotationSet.annotations[0];
+          assertEquals(
+              instantiatedClassSubject.getFinalName(), annotation.annotation.type.toSourceString());
+        }
+      }
+    }
+  }
+
+  private static int getPositionOfUnusedArgument(MethodSubject methodSubject) {
+    switch (methodSubject.getOriginalName(false)) {
+      case "testRemoveStaticFromStart":
+      case "testRemoveVirtualFromStart":
+        return 0;
+
+      case "testRemoveStaticFromMiddle":
+      case "testRemoveVirtualFromMiddle":
+        return 1;
+
+      case "testRemoveStaticFromEnd":
+      case "testRemoveVirtualFromEnd":
+        return 2;
+
+      default:
+        throw new Unreachable();
+    }
+  }
+
+  @NeverClassInline
+  static class TestClass {
+
+    public static void main(String[] args) {
+      testRemoveStaticFromStart(null, "Hello", " world!");
+      testRemoveStaticFromMiddle("Hello", null, " world!");
+      testRemoveStaticFromEnd("Hello", " world!", null);
+      new TestClass().testRemoveVirtualFromStart(null, "Hello", " world!");
+      new TestClass().testRemoveVirtualFromMiddle("Hello", null, " world!");
+      new TestClass().testRemoveVirtualFromEnd("Hello", " world!", null);
+    }
+
+    @KeepConstantArguments
+    @KeepUnusedArguments
+    @NeverInline
+    static void testRemoveStaticFromStart(
+        @Uninstantiated Dead uninstantiated,
+        @Instantiated String instantiated,
+        @Instantiated String otherInstantiated) {
+      System.out.println(instantiated + otherInstantiated);
+    }
+
+    @KeepConstantArguments
+    @KeepUnusedArguments
+    @NeverInline
+    static void testRemoveStaticFromMiddle(
+        @Instantiated String instantiated,
+        @Uninstantiated Dead uninstantiated,
+        @Instantiated String otherInstantiated) {
+      System.out.println(instantiated + otherInstantiated);
+    }
+
+    @KeepConstantArguments
+    @KeepUnusedArguments
+    @NeverInline
+    static void testRemoveStaticFromEnd(
+        @Instantiated String instantiated,
+        @Instantiated String otherInstantiated,
+        @Uninstantiated Dead uninstantiated) {
+      System.out.println(instantiated + otherInstantiated);
+    }
+
+    @KeepConstantArguments
+    @KeepUnusedArguments
+    @NeverInline
+    void testRemoveVirtualFromStart(
+        @Uninstantiated Dead uninstantiated,
+        @Instantiated String instantiated,
+        @Instantiated String otherInstantiated) {
+      System.out.println(instantiated + otherInstantiated);
+    }
+
+    @KeepConstantArguments
+    @KeepUnusedArguments
+    @NeverInline
+    void testRemoveVirtualFromMiddle(
+        @Instantiated String instantiated,
+        @Uninstantiated Dead uninstantiated,
+        @Instantiated String otherInstantiated) {
+      System.out.println(instantiated + otherInstantiated);
+    }
+
+    @KeepConstantArguments
+    @KeepUnusedArguments
+    @NeverInline
+    void testRemoveVirtualFromEnd(
+        @Instantiated String instantiated,
+        @Instantiated String otherInstantiated,
+        @Uninstantiated Dead uninstantiated) {
+      System.out.println(instantiated + otherInstantiated);
+    }
+  }
+
+  static class Dead {}
+
+  @Retention(RetentionPolicy.RUNTIME)
+  @Target(ElementType.PARAMETER)
+  @interface Instantiated {}
+
+  @Retention(RetentionPolicy.RUNTIME)
+  @Target(ElementType.PARAMETER)
+  @interface Uninstantiated {}
+}