Enum unboxing: support fields

Bug: 147860220
Change-Id: I088b2a0366031b0eec58dc3fe34972561e2d38f3
diff --git a/src/main/java/com/android/tools/r8/graph/GraphLense.java b/src/main/java/com/android/tools/r8/graph/GraphLense.java
index e01f0ee..11bf08e 100644
--- a/src/main/java/com/android/tools/r8/graph/GraphLense.java
+++ b/src/main/java/com/android/tools/r8/graph/GraphLense.java
@@ -4,6 +4,7 @@
 package com.android.tools.r8.graph;
 
 import com.android.tools.r8.ir.code.Invoke.Type;
+import com.android.tools.r8.shaking.AppInfoWithLiveness.EnumValueInfo;
 import com.google.common.collect.BiMap;
 import com.google.common.collect.HashBiMap;
 import com.google.common.collect.ImmutableMap;
@@ -322,10 +323,22 @@
 
   public static <T extends DexReference, S> ImmutableMap<T, S> rewriteReferenceKeys(
       Map<T, S> original, Function<T, T> rewrite) {
-    ImmutableMap.Builder<T, S> builder = new ImmutableMap.Builder<>();
-    for (T item : original.keySet()) {
-      builder.put(rewrite.apply(item), original.get(item));
-    }
+    ImmutableMap.Builder<T, S> builder = ImmutableMap.builder();
+    original.forEach((item, value) -> builder.put(rewrite.apply(item), value));
+    return builder.build();
+  }
+
+  // TODO(b/150193407): Move to enumInfoMap and rewrite fields.
+  public static ImmutableMap<DexType, Map<DexField, EnumValueInfo>> rewriteEnumValueInfoMaps(
+      Map<DexType, Map<DexField, EnumValueInfo>> original, GraphLense lens) {
+    ImmutableMap.Builder<DexType, Map<DexField, EnumValueInfo>> builder = ImmutableMap.builder();
+    original.forEach(
+        (enumType, map) -> {
+          DexType dexType = lens.lookupType(enumType);
+          if (!dexType.isPrimitiveType()) {
+            builder.put(dexType, map);
+          }
+        });
     return builder.build();
   }
 
@@ -476,6 +489,7 @@
   // This lens clears all code rewriting (lookup methods mimics identity lens behavior) but still
   // relies on the previous lens for names (getRenamed/Original methods).
   public static class ClearCodeRewritingGraphLens extends IdentityGraphLense {
+
     private final GraphLense previous;
 
     public ClearCodeRewritingGraphLens(GraphLense previous) {
diff --git a/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollectionImpl.java b/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollectionImpl.java
index af2bdd4..53510bd 100644
--- a/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollectionImpl.java
+++ b/src/main/java/com/android/tools/r8/graph/ObjectAllocationInfoCollectionImpl.java
@@ -141,8 +141,11 @@
         GraphLense lens) {
       objectAllocationInfos.classesWithAllocationSiteTracking.forEach(
           (clazz, allocationSitesForClass) -> {
-            DexProgramClass rewrittenClass =
-                asProgramClassOrNull(definitions.definitionFor(lens.lookupType(clazz.type)));
+            DexType type = lens.lookupType(clazz.type);
+            if (type.isPrimitiveType()) {
+              return;
+            }
+            DexProgramClass rewrittenClass = asProgramClassOrNull(definitions.definitionFor(type));
             assert rewrittenClass != null;
             assert !classesWithAllocationSiteTracking.containsKey(rewrittenClass);
             classesWithAllocationSiteTracking.put(
@@ -152,8 +155,11 @@
           });
       objectAllocationInfos.classesWithoutAllocationSiteTracking.forEach(
           clazz -> {
-            DexProgramClass rewrittenClass =
-                asProgramClassOrNull(definitions.definitionFor(lens.lookupType(clazz.type)));
+            DexType type = lens.lookupType(clazz.type);
+            if (type.isPrimitiveType()) {
+              return;
+            }
+            DexProgramClass rewrittenClass = asProgramClassOrNull(definitions.definitionFor(type));
             assert rewrittenClass != null;
             assert !classesWithAllocationSiteTracking.containsKey(rewrittenClass);
             assert !classesWithoutAllocationSiteTracking.contains(rewrittenClass);
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
index d99d342..053c575 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
@@ -26,6 +26,7 @@
 import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
 import com.android.tools.r8.ir.code.BasicBlock;
 import com.android.tools.r8.ir.code.ConstClass;
+import com.android.tools.r8.ir.code.FieldInstruction;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.If;
 import com.android.tools.r8.ir.code.Instruction;
@@ -236,9 +237,15 @@
     }
     ImmutableSet<DexType> enumsToUnbox = ImmutableSet.copyOf(this.enumsUnboxingCandidates.keySet());
     appView.setUnboxedEnums(enumsToUnbox);
-    GraphLense enumUnboxingLens = new TreeFixer(enumsToUnbox).fixupTypeReferences();
-    appView.setGraphLense(enumUnboxingLens);
+    NestedGraphLense enumUnboxingLens = new TreeFixer(enumsToUnbox).fixupTypeReferences();
     enumUnboxerRewriter = new EnumUnboxingRewriter(appView, enumsToUnbox);
+    if (enumUnboxingLens != null) {
+      appView.setGraphLense(enumUnboxingLens);
+      appView.setAppInfo(
+          appView
+              .appInfo()
+              .rewrittenWithLens(appView.appInfo().app().asDirect(), enumUnboxingLens));
+    }
     postBuilder.put(this);
     postBuilder.mapDexEncodedMethods(appView);
   }
@@ -341,28 +348,27 @@
       return Reason.ELIGIBLE;
     }
 
-    // TODO(b/147860220): Re-enable enum unboxing with fields of enum types.
     // A field put is valid only if the field is not on an enum, and the field type and the valuePut
     // have identical enum type.
-    // if (instruction.isFieldPut()) {
-    //   FieldInstruction fieldInstruction = instruction.asFieldInstruction();
-    //   DexEncodedField field = appView.appInfo().resolveField(fieldInstruction.getField());
-    //   if (field == null) {
-    //     return Reason.INVALID_FIELD_PUT;
-    //   }
-    //   DexProgramClass dexClass = appView.definitionForProgramType(field.field.holder);
-    //   if (dexClass == null) {
-    //     return Reason.INVALID_FIELD_PUT;
-    //   }
-    //   if (dexClass.isEnum()) {
-    //     return Reason.FIELD_PUT_ON_ENUM;
-    //   }
-    //   // The put value has to be of the field type.
-    //   if (field.field.type != enumClass.type) {
-    //     return Reason.TYPE_MISSMATCH_FIELD_PUT;
-    //   }
-    //   return Reason.ELIGIBLE;
-    // }
+    if (instruction.isFieldPut()) {
+      FieldInstruction fieldInstruction = instruction.asFieldInstruction();
+      DexEncodedField field = appView.appInfo().resolveField(fieldInstruction.getField());
+      if (field == null) {
+        return Reason.INVALID_FIELD_PUT;
+      }
+      DexProgramClass dexClass = appView.definitionForProgramType(field.field.holder);
+      if (dexClass == null) {
+        return Reason.INVALID_FIELD_PUT;
+      }
+      if (dexClass.isEnum()) {
+        return Reason.FIELD_PUT_ON_ENUM;
+      }
+      // The put value has to be of the field type.
+      if (field.field.type != enumClass.type) {
+        return Reason.TYPE_MISSMATCH_FIELD_PUT;
+      }
+      return Reason.ELIGIBLE;
+    }
 
     // An If using enum as inValue is valid if it matches e == null
     // or e == X with X of same enum type as e. Ex: if (e == MyEnum.A).
@@ -499,7 +505,7 @@
       this.enumsToUnbox = enumsToUnbox;
     }
 
-    private GraphLense fixupTypeReferences() {
+    private NestedGraphLense fixupTypeReferences() {
       // Fix all methods and fields using enums to unbox.
       for (DexProgramClass clazz : appView.appInfo().classes()) {
         if (enumsToUnbox.contains(clazz.type)) {
@@ -662,10 +668,9 @@
             to, RewrittenPrototypeDescription.createForRewrittenTypes(returnInfo, builder.build()));
       }
 
-      @Override
-      public GraphLense build(DexItemFactory dexItemFactory, GraphLense previousLense) {
+      public EnumUnboxingLens build(DexItemFactory dexItemFactory, GraphLense previousLense) {
         if (typeMap.isEmpty() && methodMap.isEmpty() && fieldMap.isEmpty()) {
-          return previousLense;
+          return null;
         }
         return new EnumUnboxingLens(
             typeMap,
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
index 9888aec..1626d7f 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
@@ -36,6 +36,7 @@
 import com.android.tools.r8.origin.SynthesizedOrigin;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.shaking.AppInfoWithLiveness.EnumValueInfo;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Sets;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -52,7 +53,8 @@
 
   private final AppView<AppInfoWithLiveness> appView;
   private final DexItemFactory factory;
-  private final Set<DexType> enumsToUnbox;
+  // TODO(b/150193407): Use enumInfoMap instead of Map<DexField, EnumValueInfo>.
+  private final Map<DexType, Map<DexField, EnumValueInfo>> enumsToUnbox;
 
   private final DexType utilityClassType;
   private final DexMethod ordinalUtilityMethod;
@@ -62,7 +64,11 @@
   EnumUnboxingRewriter(AppView<AppInfoWithLiveness> appView, Set<DexType> enumsToUnbox) {
     this.appView = appView;
     this.factory = appView.dexItemFactory();
-    this.enumsToUnbox = enumsToUnbox;
+    ImmutableMap.Builder<DexType, Map<DexField, EnumValueInfo>> builder = ImmutableMap.builder();
+    for (DexType toUnbox : enumsToUnbox) {
+      builder.put(toUnbox, appView.appInfo().withLiveness().getEnumValueInfoMapFor(toUnbox));
+    }
+    this.enumsToUnbox = builder.build();
 
     this.utilityClassType = factory.createType("L" + ENUM_UNBOXING_UTILITY_CLASS_NAME + ";");
     this.ordinalUtilityMethod =
@@ -102,16 +108,15 @@
       if (instruction.isStaticGet()) {
         StaticGet staticGet = instruction.asStaticGet();
         DexType holder = staticGet.getField().holder;
-        if (enumsToUnbox.contains(holder)) {
+        if (enumsToUnbox.containsKey(holder)) {
           if (staticGet.outValue() == null) {
             iterator.removeInstructionIgnoreOutValue();
             continue;
           }
-          Map<DexField, EnumValueInfo> enumValueInfoMapFor =
-              appView.appInfo().withLiveness().getEnumValueInfoMapFor(holder);
-          assert enumValueInfoMapFor != null;
+          Map<DexField, EnumValueInfo> enumValueInfoMap = enumsToUnbox.get(holder);
+          assert enumValueInfoMap != null;
           // Replace by ordinal + 1 for null check (null is 0).
-          EnumValueInfo enumValueInfo = enumValueInfoMapFor.get(staticGet.getField());
+          EnumValueInfo enumValueInfo = enumValueInfoMap.get(staticGet.getField());
           assert enumValueInfo != null
               : "Invalid read to " + staticGet.getField().name + ", error during enum analysis";
           instruction = new ConstNumber(staticGet.outValue(), enumValueInfo.ordinal + 1);
@@ -137,12 +142,12 @@
     }
     TypeLatticeElement typeLattice = instruction.outValue().getTypeLattice();
     assert !typeLattice.isClassType()
-        || !enumsToUnbox.contains(typeLattice.asClassTypeLatticeElement().getClassType());
+        || !enumsToUnbox.containsKey(typeLattice.asClassTypeLatticeElement().getClassType());
     if (typeLattice.isArrayType()) {
       TypeLatticeElement arrayBaseTypeLattice =
           typeLattice.asArrayTypeLatticeElement().getArrayBaseTypeLattice();
       assert !arrayBaseTypeLattice.isClassType()
-          || !enumsToUnbox.contains(
+          || !enumsToUnbox.containsKey(
               arrayBaseTypeLattice.asClassTypeLatticeElement().getClassType());
     }
     return true;
@@ -190,7 +195,7 @@
 
   // TODO(b/150178516): Add a test for this case.
   private boolean utilityClassInMainDexList() {
-    for (DexType toUnbox : enumsToUnbox) {
+    for (DexType toUnbox : enumsToUnbox.keySet()) {
       if (appView.appInfo().isInMainDexList(toUnbox)) {
         return true;
       }
diff --git a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
index fd8742d..6e5899d 100644
--- a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
+++ b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
@@ -4,6 +4,7 @@
 package com.android.tools.r8.shaking;
 
 import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
+import static com.android.tools.r8.graph.GraphLense.rewriteEnumValueInfoMaps;
 import static com.android.tools.r8.graph.GraphLense.rewriteReferenceKeys;
 
 import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
@@ -1065,7 +1066,7 @@
         // Don't rewrite pruned types - the removed types are identified by their original name.
         prunedTypes,
         rewriteReferenceKeys(switchMaps, lens::lookupField),
-        rewriteReferenceKeys(enumValueInfoMaps, lens::lookupType),
+        rewriteEnumValueInfoMaps(enumValueInfoMaps, lens),
         rewriteItems(instantiatedLambdas, lens::lookupType),
         constClassReferences);
   }
diff --git a/src/test/java/com/android/tools/r8/enumunboxing/FieldPutEnumUnboxingTest.java b/src/test/java/com/android/tools/r8/enumunboxing/FieldPutEnumUnboxingTest.java
index f67cd54..129230a 100644
--- a/src/test/java/com/android/tools/r8/enumunboxing/FieldPutEnumUnboxingTest.java
+++ b/src/test/java/com/android/tools/r8/enumunboxing/FieldPutEnumUnboxingTest.java
@@ -12,7 +12,6 @@
 import com.android.tools.r8.R8TestRunResult;
 import com.android.tools.r8.TestParameters;
 import java.util.List;
-import org.junit.Assume;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -42,8 +41,6 @@
 
   @Test
   public void testEnumUnboxing() throws Exception {
-    // TODO(b/147860220): Fix fields of type enums.
-    Assume.assumeTrue("Fix fields", false);
     R8TestCompileResult compile =
         testForR8(parameters.getBackend())
             .addInnerClasses(FieldPutEnumUnboxingTest.class)