Enum unboxing: fields improved support

- fix field put null
- fix dexValue on static fields

Bug: 147860220
Change-Id: I82dbe8638eafc83da67aa27671a150ccf8902feb
diff --git a/src/main/java/com/android/tools/r8/dex/FileWriter.java b/src/main/java/com/android/tools/r8/dex/FileWriter.java
index 8b0cd00..561e18f 100644
--- a/src/main/java/com/android/tools/r8/dex/FileWriter.java
+++ b/src/main/java/com/android/tools/r8/dex/FileWriter.java
@@ -606,6 +606,7 @@
     assert PresortedComparable.isSorted(fields);
     int currentOffset = 0;
     for (DexEncodedField field : fields) {
+      assert field.validateDexValue(application.dexItemFactory);
       int nextOffset = mapping.getOffsetFor(field.field);
       assert nextOffset - currentOffset >= 0;
       dest.putUleb128(nextOffset - currentOffset);
diff --git a/src/main/java/com/android/tools/r8/graph/DexEncodedField.java b/src/main/java/com/android/tools/r8/graph/DexEncodedField.java
index e8bd9dc..a22091c 100644
--- a/src/main/java/com/android/tools/r8/graph/DexEncodedField.java
+++ b/src/main/java/com/android/tools/r8/graph/DexEncodedField.java
@@ -254,4 +254,19 @@
             : DefaultFieldOptimizationInfo.getInstance();
     return result;
   }
+
+  public boolean validateDexValue(DexItemFactory factory) {
+    if (!accessFlags.isStatic() || staticValue == null) {
+      return true;
+    }
+    if (field.type.isPrimitiveType()) {
+      assert staticValue.getType(factory) == field.type
+          : "Static " + field + " has invalid static value " + staticValue + ".";
+    }
+    if (staticValue.isDexValueNull()) {
+      assert field.type.isReferenceType() : "Static " + field + " has invalid null static value.";
+    }
+    // TODO(b/150593449): Support non primitive DexValue (String, enum) and add assertions.
+    return true;
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/graph/DexValue.java b/src/main/java/com/android/tools/r8/graph/DexValue.java
index 794f7c8..be0cf37 100644
--- a/src/main/java/com/android/tools/r8/graph/DexValue.java
+++ b/src/main/java/com/android/tools/r8/graph/DexValue.java
@@ -304,6 +304,8 @@
     }
   }
 
+  public abstract DexType getType(DexItemFactory factory);
+
   public abstract Object getBoxedValue();
 
   /** Returns an instruction that can be used to materialize this {@link DexValue} (or null). */
@@ -390,6 +392,11 @@
     }
 
     @Override
+    public DexType getType(DexItemFactory factory) {
+      return factory.byteType;
+    }
+
+    @Override
     public long getRawValue() {
       return value;
     }
@@ -463,6 +470,11 @@
     }
 
     @Override
+    public DexType getType(DexItemFactory factory) {
+      return factory.shortType;
+    }
+
+    @Override
     public long getRawValue() {
       return value;
     }
@@ -535,6 +547,11 @@
     }
 
     @Override
+    public DexType getType(DexItemFactory factory) {
+      return factory.charType;
+    }
+
+    @Override
     public long getRawValue() {
       return value;
     }
@@ -611,6 +628,11 @@
     }
 
     @Override
+    public DexType getType(DexItemFactory factory) {
+      return factory.intType;
+    }
+
+    @Override
     public long getRawValue() {
       return value;
     }
@@ -683,6 +705,11 @@
     }
 
     @Override
+    public DexType getType(DexItemFactory factory) {
+      return factory.longType;
+    }
+
+    @Override
     public long getRawValue() {
       return value;
     }
@@ -755,6 +782,11 @@
     }
 
     @Override
+    public DexType getType(DexItemFactory factory) {
+      return factory.floatType;
+    }
+
+    @Override
     public long getRawValue() {
       return Float.floatToIntBits(value);
     }
@@ -833,6 +865,11 @@
     }
 
     @Override
+    public DexType getType(DexItemFactory factory) {
+      return factory.doubleType;
+    }
+
+    @Override
     public long getRawValue() {
       return Double.doubleToRawLongBits(value);
     }
@@ -902,6 +939,11 @@
 
     protected abstract DexValueKind getValueKind();
 
+    @Override
+    public DexType getType(DexItemFactory factory) {
+      throw new Unreachable();
+    }
+
     public T getValue() {
       return value;
     }
@@ -987,6 +1029,11 @@
     }
 
     @Override
+    public DexType getType(DexItemFactory factory) {
+      return factory.stringType;
+    }
+
+    @Override
     public ConstInstruction asConstInstruction(
         AppView<? extends AppInfoWithSubtyping> appView, IRCode code, DebugLocalInfo local) {
       TypeLatticeElement type = TypeLatticeElement.stringClassType(appView, definitelyNotNull());
@@ -1040,6 +1087,11 @@
     }
 
     @Override
+    public DexType getType(DexItemFactory factory) {
+      return factory.stringType;
+    }
+
+    @Override
     public ConstInstruction asConstInstruction(
         AppView<? extends AppInfoWithSubtyping> appView, IRCode code, DebugLocalInfo local) {
       TypeLatticeElement type = TypeLatticeElement.stringClassType(appView, definitelyNotNull());
@@ -1195,6 +1247,11 @@
     }
 
     @Override
+    public DexType getType(DexItemFactory factory) {
+      throw new Unreachable();
+    }
+
+    @Override
     public Object getBoxedValue() {
       throw new Unreachable("No boxed value for DexValueArray");
     }
@@ -1265,6 +1322,11 @@
     }
 
     @Override
+    public DexType getType(DexItemFactory factory) {
+      throw new Unreachable();
+    }
+
+    @Override
     public Object getBoxedValue() {
       throw new Unreachable("No boxed value for DexValueAnnotation");
     }
@@ -1315,6 +1377,11 @@
     }
 
     @Override
+    public DexType getType(DexItemFactory factory) {
+      throw new Unreachable();
+    }
+
+    @Override
     public long getRawValue() {
       return 0;
     }
@@ -1391,6 +1458,11 @@
     }
 
     @Override
+    public DexType getType(DexItemFactory factory) {
+      return factory.booleanType;
+    }
+
+    @Override
     public long getRawValue() {
       return BooleanUtils.longValue(value);
     }
diff --git a/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java b/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java
index 663524d..dd2cde4 100644
--- a/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java
+++ b/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java
@@ -4,11 +4,9 @@
 
 package com.android.tools.r8.graph;
 
-import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
 import com.android.tools.r8.ir.code.ConstInstruction;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.Position;
-import com.android.tools.r8.ir.code.ValueType;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.BooleanUtils;
 import com.google.common.collect.Ordering;
@@ -150,28 +148,6 @@
       return newType == appView.dexItemFactory().voidType;
     }
 
-    public boolean defaultValueHasChanged() {
-      if (newType.isPrimitiveType()) {
-        if (oldType.isPrimitiveType()) {
-          return ValueType.fromDexType(newType) != ValueType.fromDexType(oldType);
-        }
-        return true;
-      } else if (oldType.isPrimitiveType()) {
-        return true;
-      }
-      // All reference types uses null as default value.
-      assert newType.isReferenceType();
-      assert oldType.isReferenceType();
-      return false;
-    }
-
-    public TypeLatticeElement defaultValueLatticeElement(AppView<?> appView) {
-      if (newType.isPrimitiveType()) {
-        return TypeLatticeElement.fromDexType(newType, null, appView);
-      }
-      return TypeLatticeElement.getNull();
-    }
-
     @Override
     public boolean isRewrittenTypeInfo() {
       return true;
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/value/SingleFieldValue.java b/src/main/java/com/android/tools/r8/ir/analysis/value/SingleFieldValue.java
index dffd9e0..5479737 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/value/SingleFieldValue.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/value/SingleFieldValue.java
@@ -111,7 +111,7 @@
   public SingleValue rewrittenWithLens(AppView<AppInfoWithLiveness> appView, GraphLense lens) {
     DexField rewrittenField = lens.lookupField(field);
     assert !appView.unboxedEnums().containsEnum(field.holder)
-        || !appView.definitionFor(rewrittenField).accessFlags.isEnum();
+        || !appView.appInfo().resolveField(rewrittenField).accessFlags.isEnum();
     return appView.abstractValueFactory().createSingleFieldValue(rewrittenField);
   }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
index d7cfc6b..b8b3c4c 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
@@ -61,6 +61,7 @@
 import com.android.tools.r8.ir.code.StaticGet;
 import com.android.tools.r8.ir.code.StaticPut;
 import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.ir.code.ValueType;
 import com.android.tools.r8.logging.Log;
 import com.google.common.collect.Sets;
 import java.util.ArrayList;
@@ -195,22 +196,14 @@
                 ArgumentInfo argumentInfo = argumentInfoCollection.getArgumentInfo(i);
                 if (argumentInfo.isRewrittenTypeInfo()) {
                   RewrittenTypeInfo argInfo = argumentInfo.asRewrittenTypeInfo();
-                  Value value = invoke.inValues().get(i);
-                  // When converting types the default value may change (for example default value
-                  // of a reference type is null while default value of int is 0).
-                  if (argInfo.defaultValueHasChanged()
-                      && value.isConstNumber()
-                      && value.definition.asConstNumber().isZero()) {
-                    iterator.previous();
-                    // TODO(b/150188380): Add API to insert a const instruction with a type lattice.
-                    Value rewrittenNull =
-                        iterator.insertConstIntInstruction(code, appView.options(), 0);
-                    iterator.next();
-                    rewrittenNull.setTypeLattice(argInfo.defaultValueLatticeElement(appView));
-                    newInValues.add(rewrittenNull);
-                  } else {
-                    newInValues.add(invoke.inValues().get(i));
-                  }
+                  Value rewrittenValue =
+                      rewriteValueIfDefault(
+                          code,
+                          iterator,
+                          argInfo.getOldType(),
+                          argInfo.getNewType(),
+                          invoke.inValues().get(i));
+                  newInValues.add(rewrittenValue);
                 } else if (!argumentInfo.isRemovedArgumentInfo()) {
                   newInValues.add(invoke.inValues().get(i));
                 }
@@ -311,9 +304,12 @@
             iterator.replaceCurrentInstruction(
                 new InvokeStatic(replacementMethod, null, current.inValues()));
           } else if (actualField != field) {
+            Value rewrittenValue =
+                rewriteValueIfDefault(
+                    code, iterator, field.type, actualField.type, instancePut.value());
             InstancePut newInstancePut =
                 InstancePut.createPotentiallyInvalid(
-                    actualField, instancePut.object(), instancePut.value());
+                    actualField, instancePut.object(), rewrittenValue);
             iterator.replaceCurrentInstruction(newInstancePut);
           }
         } else if (current.isStaticGet()) {
@@ -348,7 +344,10 @@
             iterator.replaceCurrentInstruction(
                 new InvokeStatic(replacementMethod, current.outValue(), current.inValues()));
           } else if (actualField != field) {
-            StaticPut newStaticPut = new StaticPut(staticPut.value(), actualField);
+            Value rewrittenValue =
+                rewriteValueIfDefault(
+                    code, iterator, field.type, actualField.type, staticPut.value());
+            StaticPut newStaticPut = new StaticPut(rewrittenValue, actualField);
             iterator.replaceCurrentInstruction(newStaticPut);
           }
         } else if (current.isCheckCast()) {
@@ -414,6 +413,49 @@
     assert code.hasNoVerticallyMergedClasses(appView);
   }
 
+  // If the initialValue is a default value and its type is rewritten from a reference type to a
+  // primitive type, then the default value type lattice needs to be changed.
+  private Value rewriteValueIfDefault(
+      IRCode code,
+      InstructionListIterator iterator,
+      DexType oldType,
+      DexType newType,
+      Value initialValue) {
+    if (initialValue.isConstNumber()
+        && initialValue.definition.asConstNumber().isZero()
+        && defaultValueHasChanged(oldType, newType)) {
+      iterator.previous();
+      // TODO(b/150188380): Add API to insert a const instruction with a type lattice.
+      Value rewrittenDefaultValue = iterator.insertConstIntInstruction(code, appView.options(), 0);
+      iterator.next();
+      rewrittenDefaultValue.setTypeLattice(defaultValueLatticeElement(newType));
+      return rewrittenDefaultValue;
+    }
+    return initialValue;
+  }
+
+  private boolean defaultValueHasChanged(DexType oldType, DexType newType) {
+    if (newType.isPrimitiveType()) {
+      if (oldType.isPrimitiveType()) {
+        return ValueType.fromDexType(newType) != ValueType.fromDexType(oldType);
+      }
+      return true;
+    } else if (oldType.isPrimitiveType()) {
+      return true;
+    }
+    // All reference types uses null as default value.
+    assert newType.isReferenceType();
+    assert oldType.isReferenceType();
+    return false;
+  }
+
+  private TypeLatticeElement defaultValueLatticeElement(DexType type) {
+    if (type.isPrimitiveType()) {
+      return TypeLatticeElement.fromDexType(type, null, appView);
+    }
+    return TypeLatticeElement.getNull();
+  }
+
   public DexCallSite rewriteCallSite(DexCallSite callSite, DexEncodedMethod context) {
     DexItemFactory dexItemFactory = appView.dexItemFactory();
     DexProto newMethodProto =
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
index e96b211..f1adda9 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
@@ -17,6 +17,8 @@
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexProto;
 import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.DexValue.DexValueInt;
+import com.android.tools.r8.graph.DexValue.DexValueNull;
 import com.android.tools.r8.graph.EnumValueInfoMapCollection.EnumValueInfoMap;
 import com.android.tools.r8.graph.GraphLense;
 import com.android.tools.r8.graph.GraphLense.NestedGraphLense;
@@ -612,7 +614,13 @@
         if (newType != field.type) {
           DexField newField = factory.createField(field.holder, newType, field.name);
           lensBuilder.move(field, newField);
-          setter.setField(i, encodedField.toTypeSubstitutedField(newField));
+          DexEncodedField newEncodedField = encodedField.toTypeSubstitutedField(newField);
+          setter.setField(i, newEncodedField);
+          if (encodedField.isStatic() && encodedField.hasExplicitStaticValue()) {
+            assert encodedField.getStaticValue() == DexValueNull.NULL;
+            newEncodedField.setStaticValue(DexValueInt.DEFAULT);
+            // TODO(b/150593449): Support conversion from DexValueEnum to DexValueInt.
+          }
         }
       }
     }
diff --git a/src/test/java/com/android/tools/r8/enumunboxing/FieldPutEnumUnboxingTest.java b/src/test/java/com/android/tools/r8/enumunboxing/FieldPutEnumUnboxingTest.java
index f2ed08b..eedc8f1 100644
--- a/src/test/java/com/android/tools/r8/enumunboxing/FieldPutEnumUnboxingTest.java
+++ b/src/test/java/com/android/tools/r8/enumunboxing/FieldPutEnumUnboxingTest.java
@@ -81,10 +81,12 @@
       C
     }
 
-    MyEnum e;
+    MyEnum e = null;
 
     public static void main(String[] args) {
       InstanceFieldPut fieldPut = new InstanceFieldPut();
+      System.out.println(fieldPut.e == null);
+      System.out.println("true");
       fieldPut.setA();
       System.out.println(fieldPut.e.ordinal());
       System.out.println(0);
@@ -113,9 +115,11 @@
       C
     }
 
-    static MyEnum e;
+    static MyEnum e = null;
 
     public static void main(String[] args) {
+      System.out.println(StaticFieldPut.e == null);
+      System.out.println("true");
       setA();
       System.out.println(StaticFieldPut.e.ordinal());
       System.out.println(0);
diff --git a/src/test/java/com/android/tools/r8/rewrite/enums/EnumOptimizationTest.java b/src/test/java/com/android/tools/r8/rewrite/enums/EnumOptimizationTest.java
index 78a36f2..91f0c4d 100644
--- a/src/test/java/com/android/tools/r8/rewrite/enums/EnumOptimizationTest.java
+++ b/src/test/java/com/android/tools/r8/rewrite/enums/EnumOptimizationTest.java
@@ -49,6 +49,7 @@
   private void configure(InternalOptions options) {
     options.enableEnumValueOptimization = enableOptimization;
     options.enableEnumSwitchMapRemoval = enableOptimization;
+    options.enableEnumUnboxing = false;
   }
 
   @Test