Fix Enum unboxing and Objects.equals

Bug: b/287193321
Change-Id: I7c6be8e40b43d2342b4ca011e6cbdf704ef187ef
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
index 867c2dc..06816b4 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
@@ -46,6 +46,7 @@
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.InternalOptions;
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
 import com.google.common.collect.Sets;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -508,13 +509,28 @@
         rewriteStringValueOf(invoke, context, convertedEnums, instructionIterator, eventConsumer);
       } else if (invokedMethod == factory.objectsMethods.equals) {
         assert invoke.arguments().size() == 2;
-        Value argument = invoke.getFirstArgument();
-        DexType enumType = getEnumClassTypeOrNull(argument, convertedEnums);
-        if (enumType != null) {
+        if (Iterables.any(
+            invoke.arguments(), arg -> getEnumClassTypeOrNull(arg, convertedEnums) != null)) {
+          // If any of the input is null, replace it by const 0.
+          // If both inputs are null, no rewriting happen here.
+          List<Value> newArguments = new ArrayList<>(invoke.arguments().size());
+          for (Value arg : invoke.arguments()) {
+            if (arg.getType().isNullType()) {
+              Value constZero = insertConstZero(code);
+              newArguments.add(constZero);
+            } else {
+              assert getEnumClassTypeOrNull(arg, convertedEnums) != null;
+              newArguments.add(arg);
+            }
+          }
           replaceEnumInvoke(
               instructionIterator,
               invoke,
-              getSharedUtilityClass().ensureObjectsEqualsMethod(appView, context, eventConsumer));
+              getSharedUtilityClass().ensureObjectsEqualsMethod(appView, context, eventConsumer),
+              newArguments);
+        } else {
+          assert invoke.getArgument(0).getType().isReferenceType();
+          assert invoke.getArgument(1).getType().isReferenceType();
         }
       }
       return;
@@ -687,11 +703,19 @@
 
   private void replaceEnumInvoke(
       InstructionListIterator iterator, InvokeMethod invoke, ProgramMethod method) {
+    replaceEnumInvoke(iterator, invoke, method, invoke.arguments());
+  }
+
+  private void replaceEnumInvoke(
+      InstructionListIterator iterator,
+      InvokeMethod invoke,
+      ProgramMethod method,
+      List<Value> arguments) {
     InvokeStatic replacement =
         new InvokeStatic(
             method.getReference(),
             invoke.hasUnusedOutValue() ? null : invoke.outValue(),
-            invoke.arguments());
+            arguments);
     assert !replacement.hasOutValue()
         || !replacement.getInvokedMethod().getReturnType().isVoidType();
     iterator.replaceCurrentInstruction(replacement);
diff --git a/src/test/java/com/android/tools/r8/enumunboxing/EnumUnboxNull2ArgumentTest.java b/src/test/java/com/android/tools/r8/enumunboxing/EnumUnboxNull2ArgumentTest.java
new file mode 100644
index 0000000..1130986
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/enumunboxing/EnumUnboxNull2ArgumentTest.java
@@ -0,0 +1,85 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.enumunboxing;
+
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import java.util.Objects;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+/** This is a regression test for b/287193321. */
+@RunWith(Parameterized.class)
+public class EnumUnboxNull2ArgumentTest extends TestBase {
+
+  @Parameter() public TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .setMinApi(parameters)
+        .addKeepMainRule(Main.class)
+        .addOptionsModification(options -> options.testing.disableLir())
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("true", "null");
+  }
+
+  public enum MyEnum {
+    FOO("1"),
+    BAR("2");
+
+    final String value;
+
+    MyEnum(String value) {
+      this.value = value;
+    }
+  }
+
+  public static class Main {
+
+    public static void main(String[] args) {
+      // Delay observing that arguments to bar is null until we've inlined foo() and getEnum().
+      String foo = foo();
+      String[] bar = bar(getEnum(), foo);
+      // To ensure bar(MyEnum,String) is not inlined in the first round we add a few additional
+      // calls that will be stripped during IR-processing of main.
+      if (foo != null) {
+        bar(MyEnum.FOO, foo);
+        bar(MyEnum.BAR, foo);
+      }
+      for (String b : bar) {
+        System.out.println(b);
+      }
+    }
+
+    public static String[] bar(MyEnum myEnum, String foo) {
+      if (System.currentTimeMillis() > 1) {
+        MyEnum e = System.currentTimeMillis() > 1 ? null : MyEnum.FOO;
+        // Ensure that the construction is in a separate block than entry() to have constant
+        // canonicalization align the two null values into one.
+        return new String[] {Objects.toString(Objects.equals(myEnum, e)), foo};
+      }
+      return new String[] {};
+    }
+
+    public static MyEnum getEnum() {
+      return null;
+    }
+
+    public static String foo() {
+      return null;
+    }
+  }
+}