Enum unboxing: Fix null returns

Bug: 147860220
Change-Id: I95f3206ccf73bfd791a67e9e8c7e7df09f99c261
diff --git a/src/main/java/com/android/tools/r8/ir/code/IRCode.java b/src/main/java/com/android/tools/r8/ir/code/IRCode.java
index 064842a..b7e5a0b 100644
--- a/src/main/java/com/android/tools/r8/ir/code/IRCode.java
+++ b/src/main/java/com/android/tools/r8/ir/code/IRCode.java
@@ -583,6 +583,25 @@
     return true;
   }
 
+  public boolean hasConsistentReturnTypes(AppView<?> appView) {
+    DexType returnType = method.method.proto.returnType;
+    for (BasicBlock block : blocks) {
+      Return returnInstruction = block.exit().asReturn();
+      if (returnInstruction != null) {
+        if (returnInstruction.isReturnVoid()) {
+          assert returnType == appView.dexItemFactory().voidType;
+        } else {
+          assert returnInstruction
+              .returnValue()
+              .getType()
+              .lessThanOrEqualUpToNullability(
+                  TypeElement.fromDexType(returnType, Nullability.maybeNull(), appView), appView);
+        }
+      }
+    }
+    return true;
+  }
+
   public boolean isConsistentGraph() {
     assert noColorsInUse();
     assert consistentBlockNumbering();
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
index 0d91bab..567394c 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
@@ -26,6 +26,7 @@
 import static com.android.tools.r8.ir.code.Opcodes.MOVE_EXCEPTION;
 import static com.android.tools.r8.ir.code.Opcodes.NEW_ARRAY_EMPTY;
 import static com.android.tools.r8.ir.code.Opcodes.NEW_INSTANCE;
+import static com.android.tools.r8.ir.code.Opcodes.RETURN;
 import static com.android.tools.r8.ir.code.Opcodes.STATIC_GET;
 import static com.android.tools.r8.ir.code.Opcodes.STATIC_PUT;
 
@@ -80,6 +81,7 @@
 import com.android.tools.r8.ir.code.NewArrayEmpty;
 import com.android.tools.r8.ir.code.NewInstance;
 import com.android.tools.r8.ir.code.Phi;
+import com.android.tools.r8.ir.code.Return;
 import com.android.tools.r8.ir.code.StaticGet;
 import com.android.tools.r8.ir.code.StaticPut;
 import com.android.tools.r8.ir.code.Value;
@@ -491,6 +493,27 @@
             }
             break;
 
+          case RETURN:
+            {
+              Return ret = current.asReturn();
+              if (ret.isReturnVoid()) {
+                break;
+              }
+              DexType returnType = code.method.method.proto.returnType;
+              Value retValue = ret.returnValue();
+              DexType initialType =
+                  retValue.getType().isPrimitiveType()
+                      ? retValue.getType().asPrimitiveType().toDexType(factory)
+                      : factory.objectType; // Place holder, any reference type will do.
+              Value rewrittenValue =
+                  rewriteValueIfDefault(code, iterator, initialType, returnType, retValue);
+              if (retValue != rewrittenValue) {
+                Return newReturn = new Return(rewrittenValue);
+                iterator.replaceCurrentInstruction(newReturn);
+              }
+            }
+            break;
+
           default:
             if (current.hasOutValue()) {
               // For all other instructions, substitute any changed type.
@@ -514,6 +537,7 @@
     }
     assert code.isConsistentSSABeforeTypesAreCorrect();
     assert code.hasNoVerticallyMergedClasses(appView);
+    assert code.hasConsistentReturnTypes(appView);
   }
 
   // If the initialValue is a default value and its type is rewritten from a reference type to a
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
index 61a7c13..87d28a6 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
@@ -136,7 +136,7 @@
             }
           }
           if (outValue.getType().isNullType()) {
-            addNullDependencies(outValue.uniqueUsers(), eligibleEnums);
+            addNullDependencies(code, outValue.uniqueUsers(), eligibleEnums);
           }
         } else {
           if (instruction.isInvokeMethod()) {
@@ -178,7 +178,7 @@
           }
         }
         if (phi.getType().isNullType()) {
-          addNullDependencies(phi.uniqueUsers(), eligibleEnums);
+          addNullDependencies(code, phi.uniqueUsers(), eligibleEnums);
         }
       }
     }
@@ -244,7 +244,7 @@
     eligibleEnums.add(constClass.getValue());
   }
 
-  private void addNullDependencies(Set<Instruction> uses, Set<DexType> eligibleEnums) {
+  private void addNullDependencies(IRCode code, Set<Instruction> uses, Set<DexType> eligibleEnums) {
     for (Instruction use : uses) {
       if (use.isInvokeMethod()) {
         InvokeMethod invokeMethod = use.asInvokeMethod();
@@ -260,12 +260,16 @@
             markEnumAsUnboxable(Reason.ENUM_METHOD_CALLED_WITH_NULL_RECEIVER, enumClass);
           }
         }
-      }
-      if (use.isFieldPut()) {
+      } else if (use.isFieldPut()) {
         DexType type = use.asFieldInstruction().getField().type;
         if (enumsUnboxingCandidates.containsKey(type)) {
           eligibleEnums.add(type);
         }
+      } else if (use.isReturn()) {
+        DexType returnType = code.method.method.proto.returnType;
+        if (enumsUnboxingCandidates.containsKey(returnType)) {
+          eligibleEnums.add(returnType);
+        }
       }
     }
   }
diff --git a/src/test/java/com/android/tools/r8/enumunboxing/EnumUnboxingReturnNullTest.java b/src/test/java/com/android/tools/r8/enumunboxing/EnumUnboxingReturnNullTest.java
new file mode 100644
index 0000000..2a1be5f
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/enumunboxing/EnumUnboxingReturnNullTest.java
@@ -0,0 +1,102 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.enumunboxing;
+
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.utils.StringUtils;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class EnumUnboxingReturnNullTest extends EnumUnboxingTestBase {
+
+  private static final Class<?> ENUM_CLASS = MyEnum.class;
+  private static final String EXPECTED_RESULT =
+      StringUtils.lines(
+          "print1", "true", "print2", "true", "print2", "false", "0", "print3", "true");
+
+  private final TestParameters parameters;
+  private final boolean enumValueOptimization;
+  private final KeepRule enumKeepRules;
+
+  @Parameterized.Parameters(name = "{0} valueOpt: {1} keep: {2}")
+  public static List<Object[]> data() {
+    return enumUnboxingTestParameters();
+  }
+
+  public EnumUnboxingReturnNullTest(
+      TestParameters parameters, boolean enumValueOptimization, KeepRule enumKeepRules) {
+    this.parameters = parameters;
+    this.enumValueOptimization = enumValueOptimization;
+    this.enumKeepRules = enumKeepRules;
+  }
+
+  @Test
+  public void testEnumUnboxing() throws Exception {
+    Class<?> classToTest = ReturnNull.class;
+    testForR8(parameters.getBackend())
+        .addProgramClasses(classToTest, ENUM_CLASS)
+        .addKeepMainRule(classToTest)
+        .addKeepRules(enumKeepRules.getKeepRule())
+        .enableNeverClassInliningAnnotations()
+        .enableInliningAnnotations()
+        .addOptionsModification(opt -> enableEnumOptions(opt, enumValueOptimization))
+        .allowDiagnosticInfoMessages()
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .inspectDiagnosticMessages(
+            m -> assertEnumIsUnboxed(ENUM_CLASS, classToTest.getSimpleName(), m))
+        .run(parameters.getRuntime(), classToTest)
+        .assertSuccessWithOutput(EXPECTED_RESULT);
+  }
+
+  @NeverClassInline
+  enum MyEnum {
+    A,
+    B,
+    C
+  }
+
+  static class ReturnNull {
+
+    public static void main(String[] args) {
+      MyEnum myEnum1 = printAndReturnNull();
+      System.out.println(myEnum1 == null);
+      MyEnum myEnum2 = printAndReturnMaybeNull(true);
+      System.out.println(myEnum2 == null);
+      MyEnum myEnum3 = printAndReturnMaybeNull(false);
+      System.out.println(myEnum3 == null);
+      System.out.println(MyEnum.A.ordinal());
+      MyEnum[] myEnums = printAndReturnNullArray();
+      System.out.println(myEnums == null);
+    }
+
+    @NeverInline
+    static MyEnum printAndReturnNull() {
+      System.out.println("print1");
+      return null;
+    }
+
+    @NeverInline
+    static MyEnum printAndReturnMaybeNull(boolean bool) {
+      System.out.println("print2");
+      if (bool) {
+        return null;
+      } else {
+        return MyEnum.B;
+      }
+    }
+
+    @NeverInline
+    static MyEnum[] printAndReturnNullArray() {
+      System.out.println("print3");
+      return null;
+    }
+  }
+}