Fix cast to int in enum unboxer

Bug: b/406300397
Change-Id: Ief0a7c4ef0af60f1b2425f0f40398e4ffc18051d
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
index d68abb1..512d204 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
@@ -1199,8 +1199,11 @@
               .getArgumentInfo(argumentIndex)
               .asRewrittenTypeInfo();
       if (rewrittenTypeInfo != null && rewrittenTypeInfo.hasCastType()) {
-        iterator.previous();
         Value object = invoke.getArgument(argumentIndex);
+        if (object.getType().isNullType()) {
+          continue;
+        }
+        iterator.previous();
         CheckCast checkCast =
             SafeCheckCast.builder()
                 .setObject(object)
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
index 3c9111e..e114635 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
@@ -25,6 +25,7 @@
 import com.android.tools.r8.ir.code.ArrayPut;
 import com.android.tools.r8.ir.code.BasicBlock;
 import com.android.tools.r8.ir.code.BasicBlockIterator;
+import com.android.tools.r8.ir.code.CheckCast;
 import com.android.tools.r8.ir.code.ConstNumber;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.If;
@@ -181,8 +182,14 @@
           iterator.removeOrReplaceByDebugLocalRead();
           continue;
         }
-
-        if (instruction.isInitClass()) {
+        if (instruction.isCheckCast()) {
+          CheckCast checkCast = instruction.asCheckCast();
+          DexType enumType = getEnumClassTypeOrNull(checkCast.getType());
+          if (enumType != null) {
+            checkCast.outValue().replaceUsers(checkCast.object());
+            iterator.removeOrReplaceByDebugLocalRead();
+          }
+        } else if (instruction.isInitClass()) {
           InitClass initClass = instruction.asInitClass();
           DexType enumType = getEnumClassTypeOrNull(initClass.getClassValue());
           if (enumType != null) {
diff --git a/src/test/java/com/android/tools/r8/enumunboxing/CastToIntEnumUnboxingTest.java b/src/test/java/com/android/tools/r8/enumunboxing/CastToIntEnumUnboxingTest.java
index 86e341d..5825cfe 100644
--- a/src/test/java/com/android/tools/r8/enumunboxing/CastToIntEnumUnboxingTest.java
+++ b/src/test/java/com/android/tools/r8/enumunboxing/CastToIntEnumUnboxingTest.java
@@ -3,8 +3,6 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.enumunboxing;
 
-import static com.android.tools.r8.utils.codeinspector.AssertUtils.assertFailsCompilation;
-
 import com.android.tools.r8.NeverInline;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
@@ -23,22 +21,24 @@
 
   @Parameters(name = "{0}")
   public static TestParametersCollection data() {
-    return getTestParameters().withAllRuntimesAndApiLevels().build();
+    return getTestParameters().withAllRuntimesAndApiLevels().withPartialCompilation().build();
   }
 
   @Test
   public void test() throws Exception {
-    assertFailsCompilation(
-        () ->
-            testForR8(parameters.getBackend())
-                .addInnerClasses(getClass())
-                .addKeepMainRule(Main.class)
-                .addEnumUnboxingInspector(inspector -> inspector.assertUnboxed(MyEnum.class))
-                .enableInliningAnnotations()
-                .setMinApi(parameters)
-                .compile()
-                .run(parameters.getRuntime(), Main.class)
-                .assertSuccessWithOutputLines("B"));
+    testForR8(parameters)
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .addEnumUnboxingInspector(
+            inspector -> {
+              if (!parameters.isRandomPartialCompilation()) {
+                inspector.assertUnboxed(MyEnum.class);
+              }
+            })
+        .enableInliningAnnotations()
+        .compile()
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("B");
   }
 
   static class Main {
diff --git a/src/test/testbase/java/com/android/tools/r8/R8PartialTestBuilder.java b/src/test/testbase/java/com/android/tools/r8/R8PartialTestBuilder.java
index d20a667..9412134 100644
--- a/src/test/testbase/java/com/android/tools/r8/R8PartialTestBuilder.java
+++ b/src/test/testbase/java/com/android/tools/r8/R8PartialTestBuilder.java
@@ -13,6 +13,7 @@
 import com.android.tools.r8.utils.AndroidApp;
 import com.android.tools.r8.utils.Box;
 import com.android.tools.r8.utils.InternalOptions;
+import com.android.tools.r8.utils.codeinspector.EnumUnboxingInspector;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -190,6 +191,15 @@
   }
 
   @Override
+  public R8PartialTestBuilder addEnumUnboxingInspector(Consumer<EnumUnboxingInspector> inspector) {
+    return addR8PartialR8OptionsModification(
+        options ->
+            options.testing.unboxedEnumsConsumer =
+                (dexItemFactory, unboxedEnums) ->
+                    inspector.accept(new EnumUnboxingInspector(dexItemFactory, unboxedEnums)));
+  }
+
+  @Override
   public R8PartialTestBuilder allowUnnecessaryDontWarnWildcards() {
     return addR8PartialR8OptionsModification(
         options -> options.getTestingOptions().allowUnnecessaryDontWarnWildcards = true);
diff --git a/src/test/testbase/java/com/android/tools/r8/TestParameters.java b/src/test/testbase/java/com/android/tools/r8/TestParameters.java
index 25702c9..d980f2c 100644
--- a/src/test/testbase/java/com/android/tools/r8/TestParameters.java
+++ b/src/test/testbase/java/com/android/tools/r8/TestParameters.java
@@ -221,6 +221,10 @@
     return partialCompilationTestParameters;
   }
 
+  public boolean isRandomPartialCompilation() {
+    return getPartialCompilationTestParameters().isRandom();
+  }
+
   // Access to underlying runtime/wrapper.
   public TestRuntime getRuntime() {
     return runtime;