Use candidate model for methods which may be number unboxed

- Remove method from candidate list when non unboxable

Change-Id: I88c7b5c4fc71a6cb3e7b56f57b7584e70e2bc29c
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/MethodBoxingStatus.java b/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/MethodBoxingStatus.java
index 75d2340..6e1f5a7 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/MethodBoxingStatus.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/MethodBoxingStatus.java
@@ -9,12 +9,14 @@
 public class MethodBoxingStatus {
 
   public static final MethodBoxingStatus NONE_UNBOXABLE = new MethodBoxingStatus(null, null);
+  public static final MethodBoxingStatus UNPROCESSED_CANDIDATE = new MethodBoxingStatus(null, null);
 
   private final ValueBoxingStatus returnStatus;
   private final ValueBoxingStatus[] argStatuses;
 
   public static MethodBoxingStatus create(
       ValueBoxingStatus returnStatus, ValueBoxingStatus[] argStatuses) {
+    assert !ArrayUtils.contains(argStatuses, null);
     if (returnStatus.isNotUnboxable()
         && ArrayUtils.all(argStatuses, ValueBoxingStatus.NOT_UNBOXABLE)) {
       return NONE_UNBOXABLE;
@@ -31,6 +33,12 @@
     if (isNoneUnboxable() || other.isNoneUnboxable()) {
       return NONE_UNBOXABLE;
     }
+    if (isUnprocessedCandidate()) {
+      return other;
+    }
+    if (other.isUnprocessedCandidate()) {
+      return this;
+    }
     assert argStatuses.length == other.argStatuses.length;
     ValueBoxingStatus[] newArgStatuses = new ValueBoxingStatus[argStatuses.length];
     for (int i = 0; i < other.argStatuses.length; i++) {
@@ -43,6 +51,10 @@
     return this == NONE_UNBOXABLE;
   }
 
+  public boolean isUnprocessedCandidate() {
+    return this == UNPROCESSED_CANDIDATE;
+  }
+
   public ValueBoxingStatus getReturnStatus() {
     assert !isNoneUnboxable();
     return returnStatus;
@@ -50,6 +62,7 @@
 
   public ValueBoxingStatus getArgStatus(int i) {
     assert !isNoneUnboxable();
+    assert argStatuses[i] != null;
     return argStatuses[i];
   }
 
@@ -62,7 +75,9 @@
   public String toString() {
     StringBuilder sb = new StringBuilder();
     sb.append("MethodBoxingStatus[");
-    if (this == NONE_UNBOXABLE) {
+    if (isUnprocessedCandidate()) {
+      sb.append("UNPROCESSED_CANDIDATE");
+    } else if (isNoneUnboxable()) {
       sb.append("NONE_UNBOXABLE");
     } else {
       for (int i = 0; i < argStatuses.length; i++) {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/NumberUnboxerBoxingStatusResolution.java b/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/NumberUnboxerBoxingStatusResolution.java
index 4a9cbfd..ae23092 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/NumberUnboxerBoxingStatusResolution.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/NumberUnboxerBoxingStatusResolution.java
@@ -122,6 +122,7 @@
 
   public Map<DexMethod, MethodBoxingStatusResult> resolve(
       Map<DexMethod, MethodBoxingStatus> methodBoxingStatus) {
+    assert allProcessedAndUnboxable(methodBoxingStatus);
     List<DexMethod> methods = ListUtils.sort(methodBoxingStatus.keySet(), DexMethod::compareTo);
     for (DexMethod method : methods) {
       MethodBoxingStatus status = methodBoxingStatus.get(method);
@@ -153,6 +154,15 @@
     return boxingStatusResultMap;
   }
 
+  private boolean allProcessedAndUnboxable(Map<DexMethod, MethodBoxingStatus> methodBoxingStatus) {
+    methodBoxingStatus.forEach(
+        (k, v) -> {
+          assert !v.isNoneUnboxable() : v + " registered for " + k;
+          assert !v.isUnprocessedCandidate() : v + " registered for " + k;
+        });
+    return true;
+  }
+
   private void clearNoneUnboxable() {
     boxingStatusResultMap.values().removeIf(MethodBoxingStatusResult::isNoneUnboxable);
   }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/NumberUnboxerImpl.java b/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/NumberUnboxerImpl.java
index b171360..f8d2500 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/NumberUnboxerImpl.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/NumberUnboxerImpl.java
@@ -4,11 +4,13 @@
 
 package com.android.tools.r8.ir.optimize.numberunboxer;
 
+import static com.android.tools.r8.ir.optimize.numberunboxer.MethodBoxingStatus.UNPROCESSED_CANDIDATE;
 import static com.android.tools.r8.ir.optimize.numberunboxer.NumberUnboxerBoxingStatusResolution.MethodBoxingStatusResult.BoxingStatusResult.UNBOX;
 import static com.android.tools.r8.ir.optimize.numberunboxer.ValueBoxingStatus.NOT_UNBOXABLE;
 
 import com.android.tools.r8.errors.Unimplemented;
 import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexClassAndMember;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
@@ -31,6 +33,7 @@
 import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.Timing;
 import com.android.tools.r8.utils.collections.DexMethodSignatureMap;
+import com.google.common.collect.Iterables;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Comparator;
@@ -48,9 +51,10 @@
   private final DexItemFactory factory;
   private final Set<DexType> boxedTypes;
 
-  // Temporarily keep the information here, and not in the MethodOptimizationInfo as the
-  // optimization is developed and unfinished.
-  private final Map<DexMethod, MethodBoxingStatus> methodBoxingStatus = new ConcurrentHashMap<>();
+  // All candidate methods are initialized to UNPROCESSED_CANDIDATE (bottom) and methods not in
+  // this map are not subject to unboxing.
+  private final Map<DexMethod, MethodBoxingStatus> candidateBoxingStatus =
+      new ConcurrentHashMap<>();
   private Map<DexMethod, DexMethod> virtualMethodsRepresentative;
 
   public NumberUnboxerImpl(AppView<AppInfoWithLiveness> appView) {
@@ -89,22 +93,36 @@
   // TODO(b/307872552): Do not store irrelevant representative.
   private Map<DexMethod, DexMethod> computeVirtualMethodRepresentative(
       Set<DexProgramClass> component) {
-    DexMethodSignatureMap<List<DexMethod>> componentVirtualMethods = DexMethodSignatureMap.create();
+    DexMethodSignatureMap<List<ProgramMethod>> componentVirtualMethods =
+        DexMethodSignatureMap.create();
     for (DexProgramClass clazz : component) {
       for (ProgramMethod virtualProgramMethod : clazz.virtualProgramMethods()) {
-        DexMethod reference = virtualProgramMethod.getReference();
-        List<DexMethod> set =
+        List<ProgramMethod> set =
             componentVirtualMethods.computeIfAbsent(virtualProgramMethod, k -> new ArrayList<>());
-        set.add(reference);
+        set.add(virtualProgramMethod);
+      }
+      for (ProgramMethod candidate : clazz.directProgramMethods()) {
+        if (shouldConsiderForUnboxing(candidate)) {
+          candidateBoxingStatus.put(candidate.getReference(), UNPROCESSED_CANDIDATE);
+        }
       }
     }
     Map<DexMethod, DexMethod> vMethodRepresentative = new IdentityHashMap<>();
-    for (List<DexMethod> vMethods : componentVirtualMethods.values()) {
+    for (List<ProgramMethod> vMethods : componentVirtualMethods.values()) {
       if (vMethods.size() > 1) {
-        vMethods.sort(Comparator.naturalOrder());
-        DexMethod representative = vMethods.get(0);
-        for (int i = 1; i < vMethods.size(); i++) {
-          vMethodRepresentative.put(vMethods.get(i), representative);
+        if (Iterables.all(vMethods, this::shouldConsiderForUnboxing)) {
+          vMethods.sort(Comparator.comparing(DexClassAndMember::getReference));
+          ProgramMethod representative = vMethods.get(0);
+          for (int i = 1; i < vMethods.size(); i++) {
+            vMethodRepresentative.put(
+                vMethods.get(i).getReference(), representative.getReference());
+          }
+        }
+      } else {
+        assert vMethods.size() == 1;
+        ProgramMethod candidate = vMethods.get(0);
+        if (shouldConsiderForUnboxing(candidate)) {
+          candidateBoxingStatus.put(candidate.getReference(), UNPROCESSED_CANDIDATE);
         }
       }
     }
@@ -113,8 +131,12 @@
 
   private void registerMethodUnboxingStatusIfNeeded(
       ProgramMethod method, ValueBoxingStatus returnStatus, ValueBoxingStatus[] args) {
-    if (args == null && returnStatus == null) {
-      // We don't register anything if nothing unboxable was found.
+    DexMethod representative =
+        virtualMethodsRepresentative.getOrDefault(method.getReference(), method.getReference());
+    if (args == null && (returnStatus == null || returnStatus.isNotUnboxable())) {
+      // Effectively NOT_UNBOXABLE, remove the candidate.
+      // TODO(b/307872552): Do we need to remove at the end of the wave for determinism?
+      candidateBoxingStatus.remove(representative);
       return;
     }
     ValueBoxingStatus nonNullReturnStatus = returnStatus == null ? NOT_UNBOXABLE : returnStatus;
@@ -122,16 +144,13 @@
         args == null ? ValueBoxingStatus.notUnboxableArray(method.getReference().getArity()) : args;
     MethodBoxingStatus unboxingStatus = MethodBoxingStatus.create(nonNullReturnStatus, nonNullArgs);
     assert !unboxingStatus.isNoneUnboxable();
-    DexMethod representative =
-        virtualMethodsRepresentative.getOrDefault(method.getReference(), method.getReference());
-    methodBoxingStatus.compute(
-        representative,
-        (m, old) -> {
-          if (old == null) {
-            return unboxingStatus;
-          }
-          return old.merge(unboxingStatus);
-        });
+    MethodBoxingStatus newStatus =
+        candidateBoxingStatus.computeIfPresent(
+            representative, (m, old) -> old.merge(unboxingStatus));
+    if (newStatus != null && newStatus.isNoneUnboxable()) {
+      // TODO(b/307872552): Do we need to remove at the end of the wave for determinism?
+      candidateBoxingStatus.remove(representative);
+    }
   }
 
   /**
@@ -150,6 +169,7 @@
         if (unboxingStatus.mayBeUnboxable()) {
           if (args == null) {
             args = new ValueBoxingStatus[contextReference.getArity()];
+            Arrays.fill(args, NOT_UNBOXABLE);
           }
           args[next.asArgument().getIndex() - shift] = unboxingStatus;
         }
@@ -206,12 +226,25 @@
   }
 
   private boolean shouldConsiderForUnboxing(Value value) {
+    return value.getType().isClassType()
+        && shouldConsiderForUnboxing(value.getType().asClassType().getClassType());
+  }
+
+  private boolean shouldConsiderForUnboxing(ProgramMethod method) {
+    if (appView.getKeepInfo().isPinned(method, appView.options())) {
+      return false;
+    }
+    return shouldConsiderForUnboxing(method.getReturnType())
+        || Iterables.any(method.getParameters(), this::shouldConsiderForUnboxing);
+  }
+
+  private boolean shouldConsiderForUnboxing(DexType type) {
     // TODO(b/307872552): So far we consider only boxed type value to unbox them into their
     // corresponding primitive type, for example, Integer -> int. It would be nice to support
     // the pattern checkCast(BoxType) followed by a boxing operation, so that for example when
     // we have MyClass<T> and T is proven to be an Integer, we can unbox into int.
-    return value.getType().isClassType()
-        && boxedTypes.contains(value.getType().asClassType().getClassType());
+    // Types to consider: Object, Serializable, Comparable, Number.
+    return boxedTypes.contains(type);
   }
 
   // Inputs are values flowing into a method return, an invoke argument or a field write.
@@ -222,7 +255,7 @@
     DexType boxedType = inValue.getType().asClassType().getClassType();
     DexType primitiveType = factory.primitiveToBoxed.inverse().get(boxedType);
     DexMethod boxPrimitiveMethod = factory.getBoxPrimitiveMethod(primitiveType);
-    if (!inValue.isPhi()) {
+    if (!inValue.getAliasedValue().isPhi()) {
       Instruction definition = inValue.getAliasedValue().getDefinition();
       if (definition.isArgument()) {
         int shift = BooleanUtils.intValue(!context.getDefinition().isStatic());
@@ -307,7 +340,7 @@
       ExecutorService executorService)
       throws ExecutionException {
     Map<DexMethod, MethodBoxingStatusResult> unboxingResult =
-        new NumberUnboxerBoxingStatusResolution().resolve(methodBoxingStatus);
+        new NumberUnboxerBoxingStatusResolution().resolve(candidateBoxingStatus);
     if (unboxingResult.isEmpty()) {
       return;
     }
diff --git a/src/test/java/com/android/tools/r8/numberunboxing/CannotUnboxNumberUnboxingTest.java b/src/test/java/com/android/tools/r8/numberunboxing/CannotUnboxNumberUnboxingTest.java
new file mode 100644
index 0000000..04f6b1b
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/numberunboxing/CannotUnboxNumberUnboxingTest.java
@@ -0,0 +1,79 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.numberunboxing;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class CannotUnboxNumberUnboxingTest extends TestBase {
+
+  private final TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  public CannotUnboxNumberUnboxingTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void testNumberUnboxing() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .enableInliningAnnotations()
+        .addOptionsModification(opt -> opt.testing.enableNumberUnboxer = true)
+        .setMinApi(parameters)
+        .compile()
+        .inspect(this::assertUnboxing)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("null", "2", "2", "1");
+  }
+
+  private void assertUnboxing(CodeInspector codeInspector) {
+    ClassSubject mainClass = codeInspector.clazz(Main.class);
+    assertThat(mainClass, isPresent());
+
+    MethodSubject methodSubject = mainClass.uniqueMethodWithOriginalName("print");
+    assertThat(methodSubject, isPresent());
+    assertEquals("java.lang.Integer", methodSubject.getOriginalSignature().parameters[0]);
+    assertEquals(
+        "java.lang.Integer", methodSubject.getFinalSignature().asMethodSignature().parameters[0]);
+  }
+
+  static class Main {
+    public static void main(String[] args) {
+      try {
+        print(System.currentTimeMillis() > 0 ? null : -1);
+      } catch (NullPointerException npe) {
+        System.out.println("null");
+      }
+      print(System.currentTimeMillis() > 0 ? 1 : 0);
+      print(1);
+      print(0);
+    }
+
+    @NeverInline
+    private static void print(Integer boxed) {
+      System.out.println(boxed + 1);
+    }
+  }
+}