Enum unboxing working on r8

Bug: 147860220
Change-Id: I906bdf40a383c06dde3b1fc81ae6fb9bd57e095f
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/type/DestructivePhiTypeUpdater.java b/src/main/java/com/android/tools/r8/ir/analysis/type/DestructivePhiTypeUpdater.java
index 3af208a..6449294 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/type/DestructivePhiTypeUpdater.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/type/DestructivePhiTypeUpdater.java
@@ -64,7 +64,10 @@
         affectedValues.addAll(phi.affectedValues());
       }
     }
-    assert new TypeAnalysis(appView).verifyValuesUpToDate(affectedPhis);
+
+    // TODO(b/150409786): Move all code rewriting passes into the lens code rewriter.
+    // assert new TypeAnalysis(appView).verifyValuesUpToDate(affectedPhis);
+
     // Now that the types of all transitively type affected phis have been reset, we can
     // perform a narrowing, starting from the values that are affected by those phis.
     if (!affectedValues.isEmpty()) {
diff --git a/src/main/java/com/android/tools/r8/ir/code/IRCode.java b/src/main/java/com/android/tools/r8/ir/code/IRCode.java
index e620cb1..952bd5c 100644
--- a/src/main/java/com/android/tools/r8/ir/code/IRCode.java
+++ b/src/main/java/com/android/tools/r8/ir/code/IRCode.java
@@ -545,11 +545,16 @@
   }
 
   public boolean isConsistentSSA() {
+    isConsistentSSABeforeTypesAreCorrect();
+    assert verifyNoImpreciseOrBottomTypes();
+    return true;
+  }
+
+  public boolean isConsistentSSABeforeTypesAreCorrect() {
     assert isConsistentGraph();
     assert consistentDefUseChains();
     assert validThrowingInstructions();
     assert noCriticalEdges();
-    assert verifyNoImpreciseOrBottomTypes();
     assert verifyNoValueWithOnlyAssumeInstructionAsUsers();
     return true;
   }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 2b20b77..c49cb3d 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -300,8 +300,10 @@
       this.lambdaMerger =
           options.enableLambdaMerging ? new LambdaMerger(appViewWithLiveness) : null;
       this.lensCodeRewriter = new LensCodeRewriter(appViewWithLiveness);
+      this.enumUnboxer = options.enableEnumUnboxing ? new EnumUnboxer(appViewWithLiveness) : null;
       this.inliner =
-          new Inliner(appViewWithLiveness, mainDexClasses, lambdaMerger, lensCodeRewriter);
+          new Inliner(
+              appViewWithLiveness, mainDexClasses, lambdaMerger, lensCodeRewriter, enumUnboxer);
       this.outliner = new Outliner(appViewWithLiveness);
       this.memberValuePropagation =
           options.enableValuePropagation ? new MemberValuePropagation(appViewWithLiveness) : null;
@@ -331,7 +333,6 @@
               : null;
       this.enumValueOptimizer =
           options.enableEnumValueOptimization ? new EnumValueOptimizer(appViewWithLiveness) : null;
-      this.enumUnboxer = options.enableEnumUnboxing ? new EnumUnboxer(appViewWithLiveness) : null;
     } else {
       this.classInliner = null;
       this.classStaticizer = null;
@@ -1182,6 +1183,7 @@
     // check. In the latter case, the type checker should be extended to detect the issue such that
     // we will return with finalizeEmptyThrowingCode() above.
     assert code.verifyTypes(appView);
+    assert code.isConsistentSSA();
 
     assertionsRewriter.run(method, code, timing);
 
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
index 0266798..d7cfc6b 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
@@ -410,7 +410,7 @@
     if (!affectedPhis.isEmpty()) {
       new DestructivePhiTypeUpdater(appView).recomputeAndPropagateTypes(code, affectedPhis);
     }
-    assert code.isConsistentSSA();
+    assert code.isConsistentSSABeforeTypesAreCorrect();
     assert code.hasNoVerticallyMergedClasses(appView);
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java b/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
index ee62dc3..0403117 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
@@ -44,6 +44,7 @@
 import com.android.tools.r8.ir.conversion.MethodProcessor;
 import com.android.tools.r8.ir.conversion.PostOptimization;
 import com.android.tools.r8.ir.desugar.TwrCloseResourceRewriter;
+import com.android.tools.r8.ir.optimize.enums.EnumUnboxer;
 import com.android.tools.r8.ir.optimize.info.OptimizationFeedback;
 import com.android.tools.r8.ir.optimize.info.OptimizationFeedbackIgnore;
 import com.android.tools.r8.ir.optimize.inliner.DefaultInliningReasonStrategy;
@@ -78,6 +79,7 @@
   private final Set<DexMethod> blacklist;
   private final LambdaMerger lambdaMerger;
   private final LensCodeRewriter lensCodeRewriter;
+  private final EnumUnboxer enumUnboxer;
   final MainDexClasses mainDexClasses;
 
   // State for inlining methods which are known to be called twice.
@@ -90,7 +92,8 @@
       AppView<AppInfoWithLiveness> appView,
       MainDexClasses mainDexClasses,
       LambdaMerger lambdaMerger,
-      LensCodeRewriter lensCodeRewriter) {
+      LensCodeRewriter lensCodeRewriter,
+      EnumUnboxer enumUnboxer) {
     Kotlin.Intrinsics intrinsics = appView.dexItemFactory().kotlin.intrinsics;
     this.appView = appView;
     this.blacklist =
@@ -99,6 +102,7 @@
             : ImmutableSet.of(intrinsics.throwNpe, intrinsics.throwParameterIsNullException);
     this.lambdaMerger = lambdaMerger;
     this.lensCodeRewriter = lensCodeRewriter;
+    this.enumUnboxer = enumUnboxer;
     this.mainDexClasses = mainDexClasses;
   }
 
@@ -587,7 +591,8 @@
         DexEncodedMethod context,
         InliningIRProvider inliningIRProvider,
         LambdaMerger lambdaMerger,
-        LensCodeRewriter lensCodeRewriter) {
+        LensCodeRewriter lensCodeRewriter,
+        EnumUnboxer enumUnboxer) {
       DexItemFactory dexItemFactory = appView.dexItemFactory();
       InternalOptions options = appView.options();
 
@@ -721,6 +726,9 @@
       if (inliningIRProvider.shouldApplyCodeRewritings(code.method)) {
         assert lensCodeRewriter != null;
         lensCodeRewriter.rewrite(code, target);
+        if (enumUnboxer != null) {
+          enumUnboxer.rewriteCode(code);
+        }
       }
       if (lambdaMerger != null) {
         lambdaMerger.rewriteCodeForInlining(target, code, context, inliningIRProvider);
@@ -964,7 +972,13 @@
 
           InlineeWithReason inlinee =
               action.buildInliningIR(
-                  appView, invoke, context, inliningIRProvider, lambdaMerger, lensCodeRewriter);
+                  appView,
+                  invoke,
+                  context,
+                  inliningIRProvider,
+                  lambdaMerger,
+                  lensCodeRewriter,
+                  enumUnboxer);
           if (strategy.willExceedBudget(
               code, invoke, inlinee, block, whyAreYouNotInliningReporter)) {
             assert whyAreYouNotInliningReporter.unsetReasonHasBeenReportedFlag();
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
index 8bbc20e..5bcfac5 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
@@ -84,9 +84,12 @@
 
   public void analyzeEnums(IRCode code) {
     // Enum <clinit> and <init> are analyzed in between the two processing phases using optimization
-    // feedback.
+    // feedback. Methods valueOf and values are generated by javac and are analyzed differently.
     DexClass dexClass = appView.definitionFor(code.method.method.holder);
-    if (dexClass.isEnum() && code.method.isInitializer()) {
+    if (dexClass.isEnum()
+        && (code.method.isInitializer()
+            || appView.dexItemFactory().enumMethods.isValueOfMethod(code.method.method, dexClass)
+            || appView.dexItemFactory().enumMethods.isValuesMethod(code.method.method, dexClass))) {
       return;
     }
     analyzeEnumsInMethod(code);
@@ -129,10 +132,13 @@
         if (outValue != null) {
           DexProgramClass enumClass = getEnumUnboxingCandidateOrNull(outValue.getTypeLattice());
           if (enumClass != null) {
-            Reason reason =
-                validateEnumUsages(
-                    code, outValue.uniqueUsers(), outValue.uniquePhiUsers(), enumClass);
+            Reason reason = validateEnumUsages(code, outValue, enumClass);
             if (reason == Reason.ELIGIBLE) {
+              if (instruction.isCheckCast()) {
+                // We are doing a type check, which typically means the in-value is of an upper
+                // type and cannot be dealt with.
+                markEnumAsUnboxable(Reason.DOWN_CAST, enumClass);
+              }
               eligibleEnums.add(enumClass.type);
             }
           }
@@ -146,13 +152,6 @@
         if (instruction.isConstClass()) {
           ConstClass constClass = instruction.asConstClass();
           if (enumsUnboxingCandidates.containsKey(constClass.getValue())) {
-            DexMethod context = code.method.method;
-            DexClass dexClass = appView.definitionFor(context.holder);
-            if (dexClass != null
-                && dexClass.isEnum()
-                && factory.enumMethods.isValueOfMethod(context, dexClass)) {
-              continue;
-            }
             markEnumAsUnboxable(
                 Reason.CONST_CLASS, appView.definitionForProgramType(constClass.getValue()));
           }
@@ -161,8 +160,7 @@
       for (Phi phi : block.getPhis()) {
         DexProgramClass enumClass = getEnumUnboxingCandidateOrNull(phi.getTypeLattice());
         if (enumClass != null) {
-          Reason reason =
-              validateEnumUsages(code, phi.uniqueUsers(), phi.uniquePhiUsers(), enumClass);
+          Reason reason = validateEnumUsages(code, phi, enumClass);
           if (reason == Reason.ELIGIBLE) {
             eligibleEnums.add(enumClass.type);
           }
@@ -210,16 +208,15 @@
     }
   }
 
-  private Reason validateEnumUsages(
-      IRCode code, Set<Instruction> uses, Set<Phi> phiUses, DexProgramClass enumClass) {
-    for (Instruction user : uses) {
-      Reason reason = instructionAllowEnumUnboxing(user, code, enumClass);
+  private Reason validateEnumUsages(IRCode code, Value value, DexProgramClass enumClass) {
+    for (Instruction user : value.uniqueUsers()) {
+      Reason reason = instructionAllowEnumUnboxing(user, code, enumClass, value);
       if (reason != Reason.ELIGIBLE) {
         markEnumAsUnboxable(reason, enumClass);
         return reason;
       }
     }
-    for (Phi phi : phiUses) {
+    for (Phi phi : value.uniquePhiUsers()) {
       for (Value operand : phi.getOperands()) {
         if (getEnumUnboxingCandidateOrNull(operand.getTypeLattice()) != enumClass) {
           markEnumAsUnboxable(Reason.INVALID_PHI, enumClass);
@@ -289,7 +286,7 @@
   }
 
   private Reason instructionAllowEnumUnboxing(
-      Instruction instruction, IRCode code, DexProgramClass enumClass) {
+      Instruction instruction, IRCode code, DexProgramClass enumClass, Value enumValue) {
 
     // All invokes in the library are invalid, besides a few cherry picked cases such as ordinal().
     if (instruction.isInvokeMethod()) {
@@ -301,25 +298,33 @@
         }
         return Reason.INVALID_INVOKE_ON_ARRAY;
       }
-      DexEncodedMethod invokedEncodedMethod =
+      DexEncodedMethod encodedSingleTarget =
           invokeMethod.lookupSingleTarget(appView, code.method.method.holder);
-      if (invokedEncodedMethod == null) {
+      if (encodedSingleTarget == null) {
         return Reason.INVALID_INVOKE;
       }
-      DexMethod invokedMethod = invokedEncodedMethod.method;
-      DexClass dexClass = appView.definitionFor(invokedMethod.holder);
+      DexMethod singleTarget = encodedSingleTarget.method;
+      DexClass dexClass = appView.definitionFor(singleTarget.holder);
       if (dexClass == null) {
         return Reason.INVALID_INVOKE;
       }
       if (dexClass.isProgramClass()) {
         // All invokes in the program are generally valid, but specific care is required
         // for values() and valueOf().
-        if (dexClass.isEnum() && factory.enumMethods.isValuesMethod(invokedMethod, dexClass)) {
+        if (dexClass.isEnum() && factory.enumMethods.isValuesMethod(singleTarget, dexClass)) {
           return Reason.VALUES_INVOKE;
         }
-        if (dexClass.isEnum() && factory.enumMethods.isValueOfMethod(invokedMethod, dexClass)) {
+        if (dexClass.isEnum() && factory.enumMethods.isValueOfMethod(singleTarget, dexClass)) {
           return Reason.VALUE_OF_INVOKE;
         }
+        int offset = BooleanUtils.intValue(!encodedSingleTarget.isStatic());
+        for (int i = 0; i < singleTarget.proto.parameters.size(); i++) {
+          if (invokeMethod.inValues().get(offset + i) == enumValue) {
+            if (singleTarget.proto.parameters.values[i] != enumClass.type) {
+              return Reason.GENERIC_INVOKE;
+            }
+          }
+        }
         return Reason.ELIGIBLE;
       }
       if (dexClass.isClasspathClass()) {
@@ -332,17 +337,17 @@
       // TODO(b/147860220): Methods toString(), name(), compareTo(), EnumSet and EnumMap may be
       // interesting to model. A the moment rewrite only Enum#ordinal().
       if (debugLogEnabled) {
-        if (invokedMethod == factory.enumMethods.compareTo) {
+        if (singleTarget == factory.enumMethods.compareTo) {
           return Reason.COMPARE_TO_INVOKE;
         }
-        if (invokedMethod == factory.enumMethods.name) {
+        if (singleTarget == factory.enumMethods.name) {
           return Reason.NAME_INVOKE;
         }
-        if (invokedMethod == factory.enumMethods.toString) {
+        if (singleTarget == factory.enumMethods.toString) {
           return Reason.TO_STRING_INVOKE;
         }
       }
-      if (invokedMethod != factory.enumMethods.ordinal) {
+      if (singleTarget != factory.enumMethods.ordinal) {
         return Reason.UNSUPPORTED_LIBRARY_CALL;
       }
       return Reason.ELIGIBLE;
@@ -393,7 +398,7 @@
 
     if (instruction.isAssume()) {
       Value outValue = instruction.outValue();
-      return validateEnumUsages(code, outValue.uniqueUsers(), outValue.uniquePhiUsers(), enumClass);
+      return validateEnumUsages(code, outValue, enumClass);
     }
 
     // Return is used for valueOf methods.
@@ -467,9 +472,11 @@
 
   public enum Reason {
     ELIGIBLE,
+    DOWN_CAST,
     SUBTYPES,
     INTERFACE,
     INSTANCE_FIELD,
+    GENERIC_INVOKE,
     UNEXPECTED_STATIC_FIELD,
     VIRTUAL_METHOD,
     UNEXPECTED_DIRECT_METHOD,
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
index 264d61f..ea26434 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
@@ -129,9 +129,8 @@
     }
     if (!affectedPhis.isEmpty()) {
       new DestructivePhiTypeUpdater(appView).recomputeAndPropagateTypes(code, affectedPhis);
-      assert code.verifyTypes(appView);
     }
-    assert code.isConsistentSSA();
+    assert code.isConsistentSSABeforeTypesAreCorrect();
   }
 
   private boolean validateEnumToUnboxRemoved(Instruction instruction) {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackDelayed.java b/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackDelayed.java
index c241c0a..3c689a7 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackDelayed.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/info/OptimizationFeedbackDelayed.java
@@ -177,7 +177,13 @@
   public synchronized void methodReturnsAbstractValue(
       DexEncodedMethod method, AppView<AppInfoWithLiveness> appView, AbstractValue value) {
     if (appView.appInfo().mayPropagateValueFor(method.method)) {
-      getMethodOptimizationInfoForUpdating(method).markReturnsAbstractValue(value);
+      UpdatableMethodOptimizationInfo info = getMethodOptimizationInfoForUpdating(method);
+      assert !info.getAbstractReturnValue().isSingleValue()
+              || info.getAbstractReturnValue().asSingleValue() == value
+              || appView.graphLense().lookupPrototypeChanges(method.method).getRewrittenReturnInfo()
+                  != null
+          : "return single value changed from " + info.getAbstractReturnValue() + " to " + value;
+      info.markReturnsAbstractValue(value);
     }
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/info/UpdatableMethodOptimizationInfo.java b/src/main/java/com/android/tools/r8/ir/optimize/info/UpdatableMethodOptimizationInfo.java
index af494c8..3b2a85b 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/info/UpdatableMethodOptimizationInfo.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/info/UpdatableMethodOptimizationInfo.java
@@ -384,8 +384,6 @@
   }
 
   void markReturnsAbstractValue(AbstractValue value) {
-    assert !abstractReturnValue.isSingleValue() || abstractReturnValue.asSingleValue() == value
-        : "return single value changed from " + abstractReturnValue + " to " + value;
     abstractReturnValue = value;
   }
 
diff --git a/src/main/java/com/android/tools/r8/utils/AndroidApp.java b/src/main/java/com/android/tools/r8/utils/AndroidApp.java
index cccaa07..8dd1b92 100644
--- a/src/main/java/com/android/tools/r8/utils/AndroidApp.java
+++ b/src/main/java/com/android/tools/r8/utils/AndroidApp.java
@@ -210,10 +210,15 @@
 
   public int applicationSize() throws IOException, ResourceException {
     int bytes = 0;
+    assert getDexProgramResourcesForTesting().size() == 0
+        || getClassProgramResourcesForTesting().size() == 0;
     try (Closer closer = Closer.create()) {
       for (ProgramResource dex : getDexProgramResourcesForTesting()) {
         bytes += ByteStreams.toByteArray(closer.register(dex.getByteStream())).length;
       }
+      for (ProgramResource cf : getClassProgramResourcesForTesting()) {
+        bytes += ByteStreams.toByteArray(closer.register(cf.getByteStream())).length;
+      }
     }
     return bytes;
   }
diff --git a/src/test/java/com/android/tools/r8/enumunboxing/FailingMethodEnumUnboxingTest.java b/src/test/java/com/android/tools/r8/enumunboxing/FailingMethodEnumUnboxingTest.java
index 0391875..6a3ffea 100644
--- a/src/test/java/com/android/tools/r8/enumunboxing/FailingMethodEnumUnboxingTest.java
+++ b/src/test/java/com/android/tools/r8/enumunboxing/FailingMethodEnumUnboxingTest.java
@@ -28,7 +28,9 @@
     StaticFieldPutObject.class,
     ToString.class,
     EnumSetTest.class,
-    FailingPhi.class
+    FailingPhi.class,
+      FailingReturnType.class,
+      FailingParameterType.class
   };
 
   private final TestParameters parameters;
@@ -200,4 +202,48 @@
       }
     }
   }
+
+  static class FailingReturnType {
+
+    @NeverClassInline
+    enum MyEnum {
+      A,
+      B,
+      C
+    }
+
+    public static void main(String[] args) {
+      System.out.println(returnObject(MyEnum.A) == MyEnum.A);
+      System.out.println("true");
+      System.out.println(returnObject(MyEnum.B) == MyEnum.B);
+      System.out.println("true");
+    }
+
+    @NeverInline
+    static Object returnObject(MyEnum e) {
+      return System.currentTimeMillis() >= 0 ? e : new Object();
+    }
+  }
+
+  static class FailingParameterType {
+
+    @NeverClassInline
+    enum MyEnum {
+      A,
+      B,
+      C
+    }
+
+    public static void main(String[] args) {
+      System.out.println(objectToInt(MyEnum.A));
+      System.out.println("0");
+      System.out.println(objectToInt(MyEnum.B));
+      System.out.println("1");
+    }
+
+    @NeverInline
+    static int objectToInt(Object e) {
+      return e instanceof Enum ? ((Enum) e).ordinal() : e.hashCode();
+    }
+  }
 }