Reland "Add reproduction for assertion error in trivial checkcast removal"

This reverts commit 6d8e69054512c9f24e818c7b1638055e3ab777ec.

Change-Id: I53b31729b8e212915c77ed756ce4195282c1dd82
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index 8f4650a..4cb2fab 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -1432,7 +1432,7 @@
     // c = ...        // Even though we know c is of type A,
     // a' = (B) c;    // (this could be removed, since chained below.)
     // a'' = (A) a';  // this should remain for runtime verification.
-    assert !inTypeLattice.isDefinitelyNull();
+    assert !inTypeLattice.isDefinitelyNull() || (inValue.isPhi() && !inTypeLattice.isNullType());
     assert outTypeLattice.equalUpToNullability(castTypeLattice);
     return RemoveCheckCastInstructionIfTrivialResult.NO_REMOVALS;
   }
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/checkcast/CheckCastNullForTypeTest.java b/src/test/java/com/android/tools/r8/ir/optimize/checkcast/CheckCastNullForTypeTest.java
new file mode 100644
index 0000000..4f60021
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/checkcast/CheckCastNullForTypeTest.java
@@ -0,0 +1,88 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.checkcast;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static junit.framework.TestCase.assertEquals;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.StringUtils;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+/** This is a reproduction of b/160856783. */
+@RunWith(Parameterized.class)
+public class CheckCastNullForTypeTest extends TestBase {
+
+  private final TestParameters parameters;
+  private static final String EXPECTED = StringUtils.lines("null", "null");
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  public CheckCastNullForTypeTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void testRuntime() throws Exception {
+    testForRuntime(parameters)
+        .addProgramClasses(A.class, Main.class)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutput(EXPECTED);
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    testForR8(parameters.getBackend())
+        .addProgramClasses(A.class, Main.class)
+        .setMinApi(parameters.getApiLevel())
+        .addKeepClassAndMembersRules(Main.class)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutput(EXPECTED)
+        .inspect(
+            codeInspector -> {
+              ClassSubject main = codeInspector.clazz(Main.class);
+              assertThat(main, isPresent());
+              MethodSubject mainMethod = main.uniqueMethodWithName("main");
+              assertThat(mainMethod, isPresent());
+              // TODO(b/160856783): Investigate if this can be removed.
+              assertEquals(
+                  1,
+                  mainMethod
+                      .streamInstructions()
+                      .filter(instruction -> instruction.isCheckCast(Main.class.getTypeName()))
+                      .count());
+            });
+  }
+
+  public static class A {}
+
+  public static class Main {
+
+    private static void print(Main main) {
+      System.out.println(main);
+    }
+
+    public static void main(String[] args) {
+      A a = null;
+      Main main;
+      do {
+        main = (Main) (Object) a;
+      } while ((a = (A) null) != null);
+      System.out.println(a);
+      print(main);
+    }
+  }
+}