Fix NumberUnboxer for non unboxable transitive deps

Bug: b/307872552
Change-Id: I257decaacfcc3a12785b855ad9cc0e3dd4d099f4
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/NumberUnboxerBoxingStatusResolution.java b/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/NumberUnboxerBoxingStatusResolution.java
index ae23092..a685b00 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/NumberUnboxerBoxingStatusResolution.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/NumberUnboxerBoxingStatusResolution.java
@@ -24,15 +24,16 @@
 
   // TODO(b/307872552): Add threshold to NumberUnboxing options.
   private static final int UNBOX_DELTA_THRESHOLD = 0;
+  private final Map<DexMethod, MethodBoxingStatus> methodBoxingStatus;
   private final Map<DexMethod, MethodBoxingStatusResult> boxingStatusResultMap =
       new IdentityHashMap<>();
 
-  static class MethodBoxingStatusResult {
+  public NumberUnboxerBoxingStatusResolution(
+      Map<DexMethod, MethodBoxingStatus> methodBoxingStatus) {
+    this.methodBoxingStatus = methodBoxingStatus;
+  }
 
-    public static MethodBoxingStatusResult createNonUnboxable(DexMethod method) {
-      // Replace by singleton.
-      return new MethodBoxingStatusResult(method, NO_UNBOX);
-    }
+  static class MethodBoxingStatusResult {
 
     public static MethodBoxingStatusResult create(DexMethod method) {
       return new MethodBoxingStatusResult(method, TO_PROCESS);
@@ -88,18 +89,18 @@
     }
   }
 
-  void markNoneUnboxable(DexMethod method) {
-    boxingStatusResultMap.put(method, MethodBoxingStatusResult.createNonUnboxable(method));
-  }
-
   private MethodBoxingStatusResult getMethodBoxingStatusResult(DexMethod method) {
+    assert methodBoxingStatus.containsKey(method);
     return boxingStatusResultMap.computeIfAbsent(method, MethodBoxingStatusResult::create);
   }
 
   BoxingStatusResult get(TransitiveDependency transitiveDependency) {
     assert transitiveDependency.isMethodDependency();
-    MethodBoxingStatusResult methodBoxingStatusResult =
-        getMethodBoxingStatusResult(transitiveDependency.asMethodDependency().getMethod());
+    DexMethod method = transitiveDependency.asMethodDependency().getMethod();
+    if (!methodBoxingStatus.containsKey(method)) {
+      return NO_UNBOX;
+    }
+    MethodBoxingStatusResult methodBoxingStatusResult = getMethodBoxingStatusResult(method);
     if (transitiveDependency.isMethodRet()) {
       return methodBoxingStatusResult.getRet();
     }
@@ -109,8 +110,14 @@
 
   void register(TransitiveDependency transitiveDependency, BoxingStatusResult boxingStatusResult) {
     assert transitiveDependency.isMethodDependency();
-    MethodBoxingStatusResult methodBoxingStatusResult =
-        getMethodBoxingStatusResult(transitiveDependency.asMethodDependency().getMethod());
+    DexMethod method = transitiveDependency.asMethodDependency().getMethod();
+    if (boxingStatusResult == NO_UNBOX) {
+      if (!methodBoxingStatus.containsKey(method)) {
+        // Nothing to unbox, nothing to register.
+        return;
+      }
+    }
+    MethodBoxingStatusResult methodBoxingStatusResult = getMethodBoxingStatusResult(method);
     if (transitiveDependency.isMethodRet()) {
       methodBoxingStatusResult.setRet(boxingStatusResult);
       return;
@@ -120,22 +127,18 @@
         boxingStatusResult, transitiveDependency.asMethodArg().getParameterIndex());
   }
 
-  public Map<DexMethod, MethodBoxingStatusResult> resolve(
-      Map<DexMethod, MethodBoxingStatus> methodBoxingStatus) {
+  public Map<DexMethod, MethodBoxingStatusResult> resolve() {
     assert allProcessedAndUnboxable(methodBoxingStatus);
     List<DexMethod> methods = ListUtils.sort(methodBoxingStatus.keySet(), DexMethod::compareTo);
     for (DexMethod method : methods) {
       MethodBoxingStatus status = methodBoxingStatus.get(method);
-      if (status.isNoneUnboxable()) {
-        markNoneUnboxable(method);
-        continue;
-      }
+      assert !status.isNoneUnboxable();
       MethodBoxingStatusResult methodBoxingStatusResult = getMethodBoxingStatusResult(method);
       if (status.getReturnStatus().isNotUnboxable()) {
         methodBoxingStatusResult.setRet(NO_UNBOX);
       } else {
         if (methodBoxingStatusResult.getRet() == TO_PROCESS) {
-          resolve(methodBoxingStatus, new MethodRet(method));
+          resolve(new MethodRet(method));
         }
       }
       for (int i = 0; i < status.getArgStatuses().length; i++) {
@@ -144,16 +147,25 @@
           methodBoxingStatusResult.setArg(NO_UNBOX, i);
         } else {
           if (methodBoxingStatusResult.getArg(i) == TO_PROCESS) {
-            resolve(methodBoxingStatus, new MethodArg(i, method));
+            resolve(new MethodArg(i, method));
           }
         }
       }
     }
+    assert noResultForNoneUnboxable();
     assert allProcessed();
     clearNoneUnboxable();
     return boxingStatusResultMap;
   }
 
+  private boolean noResultForNoneUnboxable() {
+    boxingStatusResultMap.forEach(
+        (k, v) -> {
+          assert methodBoxingStatus.containsKey(k);
+        });
+    return true;
+  }
+
   private boolean allProcessedAndUnboxable(Map<DexMethod, MethodBoxingStatus> methodBoxingStatus) {
     methodBoxingStatus.forEach(
         (k, v) -> {
@@ -178,11 +190,14 @@
     return true;
   }
 
-  private ValueBoxingStatus getValueBoxingStatus(
-      TransitiveDependency dep, Map<DexMethod, MethodBoxingStatus> methodBoxingStatus) {
+  private ValueBoxingStatus getValueBoxingStatus(TransitiveDependency dep) {
     // Later we will implement field dependencies.
     assert dep.isMethodDependency();
     MethodBoxingStatus status = methodBoxingStatus.get(dep.asMethodDependency().getMethod());
+    if (status == null) {
+      // Nothing was recorded because nothing was unboxable.
+      return ValueBoxingStatus.NOT_UNBOXABLE;
+    }
     if (dep.isMethodRet()) {
       return status.getReturnStatus();
     }
@@ -190,8 +205,7 @@
     return status.getArgStatus(dep.asMethodArg().getParameterIndex());
   }
 
-  private void resolve(
-      Map<DexMethod, MethodBoxingStatus> methodBoxingStatus, TransitiveDependency dep) {
+  private void resolve(TransitiveDependency dep) {
     WorkList<TransitiveDependency> workList = WorkList.newIdentityWorkList(dep);
     int delta = 0;
     while (workList.hasNext()) {
@@ -201,7 +215,7 @@
         delta++;
         continue;
       }
-      ValueBoxingStatus valueBoxingStatus = getValueBoxingStatus(next, methodBoxingStatus);
+      ValueBoxingStatus valueBoxingStatus = getValueBoxingStatus(next);
       if (boxingStatusResult == NO_UNBOX || valueBoxingStatus.isNotUnboxable()) {
         // TODO(b/307872552): Unbox when a non unboxable non null dependency is present.
         // If a dependency is not unboxable, we need to prove it's non-null, else we cannot unbox.
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/NumberUnboxerImpl.java b/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/NumberUnboxerImpl.java
index f8d2500..0c471b0 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/NumberUnboxerImpl.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/numberunboxer/NumberUnboxerImpl.java
@@ -340,7 +340,7 @@
       ExecutorService executorService)
       throws ExecutionException {
     Map<DexMethod, MethodBoxingStatusResult> unboxingResult =
-        new NumberUnboxerBoxingStatusResolution().resolve(candidateBoxingStatus);
+        new NumberUnboxerBoxingStatusResolution(candidateBoxingStatus).resolve();
     if (unboxingResult.isEmpty()) {
       return;
     }
diff --git a/src/main/java/com/android/tools/r8/utils/ArrayUtils.java b/src/main/java/com/android/tools/r8/utils/ArrayUtils.java
index 27d7516..176ca67 100644
--- a/src/main/java/com/android/tools/r8/utils/ArrayUtils.java
+++ b/src/main/java/com/android/tools/r8/utils/ArrayUtils.java
@@ -3,6 +3,7 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.utils;
 
+import com.android.tools.r8.graph.DexType;
 import java.lang.reflect.Array;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -235,4 +236,13 @@
     optionals[ts.length] = Optional.empty();
     return optionals;
   }
+
+  public static boolean any(DexType[] values, Predicate<DexType> predicate) {
+    for (DexType value : values) {
+      if (predicate.test(value)) {
+        return true;
+      }
+    }
+    return false;
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/utils/ListUtils.java b/src/main/java/com/android/tools/r8/utils/ListUtils.java
index 4be2a4e..769905a 100644
--- a/src/main/java/com/android/tools/r8/utils/ListUtils.java
+++ b/src/main/java/com/android/tools/r8/utils/ListUtils.java
@@ -284,6 +284,15 @@
     return existingMappedRanges == null ? null : last(existingMappedRanges);
   }
 
+  public static <T> boolean all(List<T> items, Predicate<T> predicate) {
+    for (T item : items) {
+      if (predicate.test(item)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
   public interface ReferenceAndIntConsumer<T> {
     void accept(T item, int index);
   }
diff --git a/src/test/java/com/android/tools/r8/numberunboxing/CannotUnboxNumberUnboxingTest.java b/src/test/java/com/android/tools/r8/numberunboxing/CannotUnboxNumberUnboxingTest.java
index 04f6b1b..a143236 100644
--- a/src/test/java/com/android/tools/r8/numberunboxing/CannotUnboxNumberUnboxingTest.java
+++ b/src/test/java/com/android/tools/r8/numberunboxing/CannotUnboxNumberUnboxingTest.java
@@ -45,7 +45,7 @@
         .compile()
         .inspect(this::assertUnboxing)
         .run(parameters.getRuntime(), Main.class)
-        .assertSuccessWithOutputLines("null", "2", "2", "1");
+        .assertSuccessWithOutputLines("null", "2", "2", "1", "null", "2", "1");
   }
 
   private void assertUnboxing(CodeInspector codeInspector) {
@@ -60,7 +60,36 @@
   }
 
   static class Main {
+
     public static void main(String[] args) {
+      cannotUnboxPrint();
+      depsNonUnboxable();
+    }
+
+    @NeverInline
+    private static void depsNonUnboxable() {
+      try {
+        forward(null);
+      } catch (NullPointerException npe) {
+        System.out.println("null");
+      }
+      forward(1);
+      forward(0);
+    }
+
+    @NeverInline
+    private static void forward(Integer i) {
+      // Here print2 will get i as a deps which is non-unboxable.
+      print2(i);
+    }
+
+    @NeverInline
+    private static void print2(Integer boxed) {
+      System.out.println(boxed + 1);
+    }
+
+    @NeverInline
+    private static void cannotUnboxPrint() {
       try {
         print(System.currentTimeMillis() > 0 ? null : -1);
       } catch (NullPointerException npe) {