Make EnumSet and EnumMap uses cause enum values() to be pinned.

Bug: b/204939965
Change-Id: I9d7fcfcf6b39bdeca116dfaf3f59cac16b97c496
diff --git a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
index 3ac2bd6..586b968 100644
--- a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
+++ b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
@@ -625,6 +625,8 @@
   public final DexType javaUtilLoggingLoggerType =
       createStaticallyKnownType("Ljava/util/logging/Logger;");
   public final DexType javaUtilSetType = createStaticallyKnownType("Ljava/util/Set;");
+  public final DexType javaUtilEnumMapType = createStaticallyKnownType("Ljava/util/EnumMap;");
+  public final DexType javaUtilEnumSetType = createStaticallyKnownType("Ljava/util/EnumSet;");
 
   public final DexType androidAppActivity = createStaticallyKnownType("Landroid/app/Activity;");
   public final DexType androidAppFragment = createStaticallyKnownType("Landroid/app/Fragment;");
@@ -742,6 +744,8 @@
   public final JavaUtilLocaleMembers javaUtilLocaleMembers = new JavaUtilLocaleMembers();
   public final JavaUtilLoggingLevelMembers javaUtilLoggingLevelMembers =
       new JavaUtilLoggingLevelMembers();
+  public final JavaUtilEnumMapMembers javaUtilEnumMapMembers = new JavaUtilEnumMapMembers();
+  public final JavaUtilEnumSetMembers javaUtilEnumSetMembers = new JavaUtilEnumSetMembers();
 
   public final List<LibraryMembers> libraryMembersCollection =
       ImmutableList.of(
@@ -1547,6 +1551,29 @@
     }
   }
 
+  public class JavaUtilEnumMapMembers {
+    public final DexMethod constructor =
+        createMethod(javaUtilEnumMapType, createProto(voidType, classType), constructorMethodName);
+  }
+
+  public class JavaUtilEnumSetMembers {
+    private final DexString allOfString = createString("allOf");
+    private final DexString noneOfString = createString("noneOf");
+    private final DexString ofString = createString("of");
+    private final DexString rangeString = createString("range");
+
+    public boolean isFactoryMethod(DexMethod invokedMethod) {
+      if (!invokedMethod.getHolderType().equals(javaUtilEnumSetType)) {
+        return false;
+      }
+      DexString name = invokedMethod.getName();
+      return name.isIdenticalTo(allOfString)
+          || name.isIdenticalTo(noneOfString)
+          || name.isIdenticalTo(ofString)
+          || name.isIdenticalTo(rangeString);
+    }
+  }
+
   public class LongMembers extends BoxedPrimitiveMembers {
 
     public final DexField TYPE = createField(boxedLongType, classType, "TYPE");
@@ -2153,7 +2180,6 @@
           && accessFlags.isFinal();
     }
   }
-
   public class NullPointerExceptionMethods {
 
     public final DexMethod init =
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index 9cfbfb7..df58f3d 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -97,7 +97,9 @@
 import com.android.tools.r8.ir.analysis.proto.GeneratedMessageLiteBuilderShrinker;
 import com.android.tools.r8.ir.analysis.proto.ProtoEnqueuerUseRegistry;
 import com.android.tools.r8.ir.analysis.proto.schema.ProtoEnqueuerExtension;
+import com.android.tools.r8.ir.analysis.type.ClassTypeElement;
 import com.android.tools.r8.ir.code.ArrayPut;
+import com.android.tools.r8.ir.code.ConstClass;
 import com.android.tools.r8.ir.code.ConstantValueUtils;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.Instruction;
@@ -1516,6 +1518,11 @@
     MethodResolutionResult resolutionResult =
         handleInvokeOfDirectTarget(invokedMethod, context, reason);
     analyses.traceInvokeDirect(invokedMethod, resolutionResult, context);
+
+    if (invokedMethod.equals(appView.dexItemFactory().javaUtilEnumMapMembers.constructor)) {
+      // EnumMap uses reflection.
+      pendingReflectiveUses.add(context);
+    }
   }
 
   void traceInvokeInterface(
@@ -1570,16 +1577,14 @@
       identifierNameStrings.add(invokedMethod);
       // Revisit the current method to implicitly add -keep rule for items with reflective access.
       pendingReflectiveUses.add(context);
-    }
-    // See comment in handleJavaLangEnumValueOf.
-    if (invokedMethod == dexItemFactory.enumMembers.valueOf) {
+    } else if (invokedMethod == dexItemFactory.enumMembers.valueOf
+        || dexItemFactory.javaUtilEnumSetMembers.isFactoryMethod(invokedMethod)) {
+      // See comment in handleEnumValueOfOrCollectionInstantiation.
       pendingReflectiveUses.add(context);
-    }
-    // Handling of application services.
-    if (dexItemFactory.serviceLoaderMethods.isLoadMethod(invokedMethod)) {
+    } else if (invokedMethod == dexItemFactory.proxyMethods.newProxyInstance) {
       pendingReflectiveUses.add(context);
-    }
-    if (invokedMethod == dexItemFactory.proxyMethods.newProxyInstance) {
+    } else if (dexItemFactory.serviceLoaderMethods.isLoadMethod(invokedMethod)) {
+      // Handling of application services.
       pendingReflectiveUses.add(context);
     }
     markTypeAsLive(invokedMethod.getHolderType(), context);
@@ -5147,8 +5152,10 @@
       handleJavaLangReflectConstructorNewInstance(method, invoke);
       return;
     }
-    if (invokedMethod == dexItemFactory.enumMembers.valueOf) {
-      handleJavaLangEnumValueOf(method, invoke);
+    if (invokedMethod == dexItemFactory.enumMembers.valueOf
+        || invokedMethod == dexItemFactory.javaUtilEnumMapMembers.constructor
+        || dexItemFactory.javaUtilEnumSetMembers.isFactoryMethod(invokedMethod)) {
+      handleEnumValueOfOrCollectionInstantiation(method, invoke);
       return;
     }
     if (invokedMethod == dexItemFactory.proxyMethods.newProxyInstance) {
@@ -5480,17 +5487,44 @@
     }
   }
 
-  private void handleJavaLangEnumValueOf(ProgramMethod method, InvokeMethod invoke) {
+  private void handleEnumValueOfOrCollectionInstantiation(
+      ProgramMethod context, InvokeMethod invoke) {
+    if (invoke.inValues().isEmpty()) {
+      // Should never happen.
+      return;
+    }
+
     // The use of java.lang.Enum.valueOf(java.lang.Class, java.lang.String) will indirectly
     // access the values() method of the enum class passed as the first argument. The method
     // SomeEnumClass.valueOf(java.lang.String) which is generated by javac for all enums will
     // call this method.
-    if (invoke.inValues().get(0).isConstClass()) {
-      DexType type = invoke.inValues().get(0).definition.asConstClass().getType();
-      DexProgramClass clazz = getProgramClassOrNull(type, method);
-      if (clazz != null && clazz.isEnum()) {
-        markEnumValuesAsReachable(clazz, KeepReason.invokedFrom(method));
+    // Likewise, EnumSet and EnumMap call values() on the passed in Class.
+    Value firstArg = invoke.getFirstNonReceiverArgument();
+    if (firstArg.isPhi()) {
+      return;
+    }
+    DexType type;
+    if (invoke
+        .getInvokedMethod()
+        .getParameter(0)
+        .isIdenticalTo(appView.dexItemFactory().classType)) {
+      // EnumMap.<init>(), EnumSet.noneOf(), EnumSet.allOf(), Enum.valueOf().
+      ConstClass constClass = firstArg.definition.asConstClass();
+      if (constClass == null || !constClass.getType().isClassType()) {
+        return;
       }
+      type = constClass.getType();
+    } else {
+      // EnumSet.of(), EnumSet.range()
+      ClassTypeElement typeElement = firstArg.getType().asClassType();
+      if (typeElement == null) {
+        return;
+      }
+      type = typeElement.getClassType();
+    }
+    DexProgramClass clazz = getProgramClassOrNull(type, context);
+    if (clazz != null && clazz.isEnum()) {
+      markEnumValuesAsReachable(clazz, KeepReason.invokedFrom(context));
     }
   }
 
diff --git a/src/test/java/com/android/tools/r8/shaking/enums/EnumCollectionsTest.java b/src/test/java/com/android/tools/r8/shaking/enums/EnumCollectionsTest.java
new file mode 100644
index 0000000..2a9872c
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/shaking/enums/EnumCollectionsTest.java
@@ -0,0 +1,208 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.shaking.enums;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import java.util.Arrays;
+import java.util.EnumMap;
+import java.util.EnumSet;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class EnumCollectionsTest extends TestBase {
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withDefaultRuntimes().withMaximumApiLevel().build();
+  }
+
+  private static final List<String> EXPECTED_OUTPUT =
+      Arrays.asList(
+          "none: [A, B]",
+          "all: [B, C]",
+          "of: [C]",
+          "of: [D, E]",
+          "of: [E, F, G]",
+          "of: [F, G, H, I]",
+          "of: [G, H, I, J, K]",
+          "of: [H, I, J, K, L, M]",
+          "range: [I, J]",
+          "map: {J=1}",
+          "valueOf: K",
+          "phi: [B]");
+
+  public static class TestMain {
+    public enum EnumA {
+      A,
+      B
+    }
+
+    public enum EnumB {
+      B,
+      C
+    }
+
+    public enum EnumC {
+      C,
+      D
+    }
+
+    public enum EnumD {
+      D,
+      E
+    }
+
+    public enum EnumE {
+      E,
+      F,
+      G
+    }
+
+    public enum EnumF {
+      F,
+      G,
+      H,
+      I
+    }
+
+    public enum EnumG {
+      G,
+      H,
+      I,
+      J,
+      K
+    }
+
+    public enum EnumH {
+      H,
+      I,
+      J,
+      K,
+      L,
+      M
+    }
+
+    public enum EnumI {
+      I,
+      J
+    }
+
+    public enum EnumJ {
+      J,
+      K
+    }
+
+    public enum EnumK {
+      K,
+      L
+    }
+
+    @NeverInline
+    private static void noneOf() {
+      System.out.println("none: " + EnumSet.complementOf(EnumSet.noneOf(EnumA.class)));
+    }
+
+    @NeverInline
+    private static void allOf() {
+      System.out.println("all: " + EnumSet.allOf(EnumB.class));
+    }
+
+    @NeverInline
+    private static void of1() {
+      System.out.println("of: " + EnumSet.of(EnumC.C));
+    }
+
+    @NeverInline
+    private static void of2() {
+      System.out.println("of: " + EnumSet.of(EnumD.D, EnumD.E));
+    }
+
+    @NeverInline
+    private static void of3() {
+      System.out.println("of: " + EnumSet.of(EnumE.E, EnumE.F, EnumE.G));
+    }
+
+    @NeverInline
+    private static void of4() {
+      System.out.println("of: " + EnumSet.of(EnumF.F, EnumF.G, EnumF.H, EnumF.I));
+    }
+
+    @NeverInline
+    private static void of5() {
+      System.out.println("of: " + EnumSet.of(EnumG.G, EnumG.H, EnumG.I, EnumG.J, EnumG.K));
+    }
+
+    @NeverInline
+    private static void ofVarArgs() {
+      System.out.println("of: " + EnumSet.of(EnumH.H, EnumH.I, EnumH.J, EnumH.K, EnumH.L, EnumH.M));
+    }
+
+    @NeverInline
+    private static void range() {
+      System.out.println("range: " + EnumSet.range(EnumI.I, EnumI.J));
+    }
+
+    @NeverInline
+    private static void map() {
+      EnumMap<EnumJ, Integer> map = new EnumMap<>(EnumJ.class);
+      map.put(EnumJ.J, 1);
+      System.out.println("map: " + map);
+    }
+
+    @NeverInline
+    private static void valueOf() {
+      System.out.println("valueOf: " + EnumK.valueOf("K"));
+    }
+
+    public static void main(String[] args) {
+      // Use different methods to ensure Enqueuer.traceInvokeStatic() triggers for each one.
+      noneOf();
+      allOf();
+      of1();
+      of2();
+      of3();
+      of4();
+      of5();
+      ofVarArgs();
+      range();
+      map();
+      valueOf();
+      // Ensure phi as argument does not cause issues.
+      System.out.println(
+          "phi: " + EnumSet.of((Enum) (args.length > 10 ? (Object) EnumA.A : (Object) EnumB.B)));
+    }
+  }
+
+  @Test
+  public void testRuntime() throws Exception {
+    testForRuntime(parameters)
+        .addProgramClassesAndInnerClasses(TestMain.class)
+        .run(parameters.getRuntime(), TestMain.class)
+        .assertSuccessWithOutputLines(EXPECTED_OUTPUT);
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    testForR8(parameters.getBackend())
+        .setMinApi(parameters)
+        .addProgramClassesAndInnerClasses(TestMain.class)
+        .enableInliningAnnotations()
+        .addKeepMainRule(TestMain.class)
+        .compile()
+        .run(parameters.getRuntime(), TestMain.class)
+        .assertSuccessWithOutputLines(EXPECTED_OUTPUT);
+  }
+}