Remove Objects#requireNonNull for definitely null reference.
Bug: 124246610
Change-Id: Ib593f20c01cc64905da4c32e8baf937446237c7b
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index e47b1b9..7b398f3 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -1611,62 +1611,69 @@
AppInfoWithLiveness appInfoWithLiveness = appView.appInfo().withLiveness();
Set<Value> needToWidenValues = Sets.newIdentityHashSet();
Set<Value> needToNarrowValues = Sets.newIdentityHashSet();
- InstructionIterator iterator = code.instructionIterator();
- while (iterator.hasNext()) {
- Instruction current = iterator.next();
- if (current.isInvokeMethod()) {
- InvokeMethod invoke = current.asInvokeMethod();
- Value outValue = invoke.outValue();
- // TODO(b/124246610): extend to other variants that receive error messages or supplier.
- if (invoke.getInvokedMethod() == dexItemFactory.objectsMethods.requireNonNull) {
- Value obj = invoke.arguments().get(0);
- if ((outValue == null && obj.hasLocalInfo())
- || (outValue != null && !obj.hasSameOrNoLocal(outValue))) {
- continue;
- }
- Nullability nullability = obj.getTypeLattice().nullability();
- if (nullability.isDefinitelyNotNull()) {
- if (outValue != null) {
- outValue.replaceUsers(obj);
- needToNarrowValues.addAll(outValue.affectedValues());
+ Set<BasicBlock> blocksToBeRemoved = Sets.newIdentityHashSet();
+ ListIterator<BasicBlock> blockIterator = code.listIterator();
+ while (blockIterator.hasNext()) {
+ BasicBlock block = blockIterator.next();
+ if (blocksToBeRemoved.contains(block)) {
+ continue;
+ }
+ InstructionListIterator iterator = block.listIterator();
+ while (iterator.hasNext()) {
+ Instruction current = iterator.next();
+ if (current.isInvokeMethod()) {
+ InvokeMethod invoke = current.asInvokeMethod();
+ Value outValue = invoke.outValue();
+ // TODO(b/124246610): extend to other variants that receive error messages or supplier.
+ if (invoke.getInvokedMethod() == dexItemFactory.objectsMethods.requireNonNull) {
+ Value obj = invoke.arguments().get(0);
+ if ((outValue == null && obj.hasLocalInfo())
+ || (outValue != null && !obj.hasSameOrNoLocal(outValue))) {
+ continue;
}
- iterator.removeOrReplaceByDebugLocalRead();
- } else if (nullability.isDefinitelyNull()) {
- // TODO(b/124246610): throw NPE.
- // Refactor UninstantiatedTypeOptimization#replaceCurrentInstructionWithThrowNull
- // and move it to iterator.
- }
- } else if (outValue != null && !outValue.hasLocalInfo()) {
- if (appView
- .dexItemFactory()
- .libraryMethodsReturningReceiver
- .contains(invoke.getInvokedMethod())) {
- if (checkArgumentType(invoke, 0)) {
- outValue.replaceUsers(invoke.arguments().get(0));
- invoke.setOutValue(null);
+ Nullability nullability = obj.getTypeLattice().nullability();
+ if (nullability.isDefinitelyNotNull()) {
+ if (outValue != null) {
+ outValue.replaceUsers(obj);
+ needToNarrowValues.addAll(outValue.affectedValues());
+ }
+ iterator.removeOrReplaceByDebugLocalRead();
+ } else if (obj.isAlwaysNull(appView) && appView.appInfo().hasSubtyping()) {
+ iterator.replaceCurrentInstructionWithThrowNull(
+ appView.withSubtyping(), code, blockIterator, blocksToBeRemoved);
}
- } else if (appInfoWithLiveness != null) {
- DexEncodedMethod target =
- invoke.lookupSingleTarget(appInfoWithLiveness, code.method.method.holder);
- if (target != null) {
- DexMethod invokedMethod = target.method;
- // Check if the invoked method is known to return one of its arguments.
- DexEncodedMethod definition = appView.definitionFor(invokedMethod);
- if (definition != null && definition.getOptimizationInfo().returnsArgument()) {
- int argumentIndex = definition.getOptimizationInfo().getReturnedArgument();
- // Replace the out value of the invoke with the argument and ignore the out value.
- if (argumentIndex >= 0 && checkArgumentType(invoke, argumentIndex)) {
- Value argument = invoke.arguments().get(argumentIndex);
- assert outValue.verifyCompatible(argument.outType());
- if (argument
- .getTypeLattice()
- .lessThanOrEqual(outValue.getTypeLattice(), appView)) {
- needToNarrowValues.addAll(outValue.affectedValues());
- } else {
- needToWidenValues.addAll(outValue.affectedValues());
+ } else if (outValue != null && !outValue.hasLocalInfo()) {
+ if (appView
+ .dexItemFactory()
+ .libraryMethodsReturningReceiver
+ .contains(invoke.getInvokedMethod())) {
+ if (checkArgumentType(invoke, 0)) {
+ outValue.replaceUsers(invoke.arguments().get(0));
+ invoke.setOutValue(null);
+ }
+ } else if (appInfoWithLiveness != null) {
+ DexEncodedMethod target =
+ invoke.lookupSingleTarget(appInfoWithLiveness, code.method.method.holder);
+ if (target != null) {
+ DexMethod invokedMethod = target.method;
+ // Check if the invoked method is known to return one of its arguments.
+ DexEncodedMethod definition = appView.definitionFor(invokedMethod);
+ if (definition != null && definition.getOptimizationInfo().returnsArgument()) {
+ int argumentIndex = definition.getOptimizationInfo().getReturnedArgument();
+ // Replace the out value of the invoke with the argument and ignore the out value.
+ if (argumentIndex >= 0 && checkArgumentType(invoke, argumentIndex)) {
+ Value argument = invoke.arguments().get(argumentIndex);
+ assert outValue.verifyCompatible(argument.outType());
+ if (argument
+ .getTypeLattice()
+ .lessThanOrEqual(outValue.getTypeLattice(), appView)) {
+ needToNarrowValues.addAll(outValue.affectedValues());
+ } else {
+ needToWidenValues.addAll(outValue.affectedValues());
+ }
+ outValue.replaceUsers(argument);
+ invoke.setOutValue(null);
}
- outValue.replaceUsers(argument);
- invoke.setOutValue(null);
}
}
}
@@ -1674,6 +1681,11 @@
}
}
}
+ if (!blocksToBeRemoved.isEmpty()) {
+ code.removeBlocks(blocksToBeRemoved);
+ code.removeAllTrivialPhis();
+ assert code.getUnreachableBlocks().isEmpty();
+ }
if (!needToWidenValues.isEmpty() || !needToNarrowValues.isEmpty()) {
TypeAnalysis analysis = new TypeAnalysis(appView, code.method);
// If out value of invoke < argument (e.g., losing non-null info), widen users type.
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/ObjectsRequireNonNullTest.java b/src/test/java/com/android/tools/r8/ir/optimize/ObjectsRequireNonNullTest.java
index f8ef59e..3d80826 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/ObjectsRequireNonNullTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/ObjectsRequireNonNullTest.java
@@ -4,27 +4,47 @@
package com.android.tools.r8.ir.optimize;
import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertThat;
+import static org.junit.Assume.assumeTrue;
+import com.android.tools.r8.D8TestRunResult;
import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NeverPropagateValue;
+import com.android.tools.r8.R8TestRunResult;
import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
import com.android.tools.r8.TestRunResult;
import com.android.tools.r8.ToolHelper.DexVm.Version;
-import com.android.tools.r8.VmTestRunner;
-import com.android.tools.r8.VmTestRunner.IgnoreIfVmOlderThan;
import com.android.tools.r8.graph.DexMethod;
import com.android.tools.r8.utils.StringUtils;
import com.android.tools.r8.utils.codeinspector.ClassSubject;
import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
import com.android.tools.r8.utils.codeinspector.MethodSubject;
import com.google.common.collect.Streams;
import java.util.Objects;
import org.junit.Test;
import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
class ObjectsRequireNonNullTestMain {
+ static class Uninitialized {
+ void noWayToCall() {
+ System.out.println("Uninitialized, hence no way to call this.");
+ }
+ }
+
+ @NeverPropagateValue
+ @NeverInline
+ static void consumeUninitialized(Uninitialized arg) {
+ Uninitialized nonNullArg = Objects.requireNonNull(arg);
+ // Dead code.
+ nonNullArg.noWayToCall();
+ }
+
static class Foo {
@NeverInline
void bar() {
@@ -60,22 +80,50 @@
} catch (NullPointerException npe) {
System.out.println("Expected NPE");
}
+
+ try {
+ consumeUninitialized(null);
+ } catch (NullPointerException npe) {
+ System.out.println("Expected NPE");
+ }
}
}
-@RunWith(VmTestRunner.class)
+@RunWith(Parameterized.class)
public class ObjectsRequireNonNullTest extends TestBase {
private static final String JAVA_OUTPUT = StringUtils.lines(
"Foo::toString",
"Foo::bar",
"Foo::bar",
+ "Expected NPE",
"Expected NPE"
);
private static final Class<?> MAIN = ObjectsRequireNonNullTestMain.class;
+ @Parameterized.Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters()
+ .withCfRuntimes()
+ // Objects#requireNonNull will be desugared VMs older than API level K.
+ .withDexRuntimesStartingFromExcluding(Version.V4_4_4)
+ .build();
+ }
+
+ private final TestParameters parameters;
+
+ public ObjectsRequireNonNullTest(TestParameters parameters) {
+ this.parameters = parameters;
+ }
+
@Test
public void testJvmOutput() throws Exception {
- testForJvm().addTestClasspath().run(MAIN).assertSuccessWithOutput(JAVA_OUTPUT);
+ assumeTrue(
+ "Only run JVM reference once (for CF backend)",
+ parameters.getBackend() == Backend.CF);
+ testForJvm()
+ .addTestClasspath()
+ .run(parameters.getRuntime(), MAIN)
+ .assertSuccessWithOutput(JAVA_OUTPUT);
}
private static boolean isObjectsRequireNonNull(DexMethod method) {
@@ -92,48 +140,64 @@
})).count();
}
- private void test(TestRunResult result, int expectedCount) throws Exception {
+ private void test(
+ TestRunResult result,
+ int expectedCountInMain,
+ int expectedCountInConsumer) throws Exception {
CodeInspector codeInspector = result.inspector();
ClassSubject mainClass = codeInspector.clazz(MAIN);
MethodSubject mainMethod = mainClass.mainMethod();
assertThat(mainMethod, isPresent());
- long count = countObjectsRequireNonNull(mainMethod);
- assertEquals(expectedCount, count);
+ assertEquals(expectedCountInMain, countObjectsRequireNonNull(mainMethod));
MethodSubject unknownArg = mainClass.uniqueMethodWithName("unknownArg");
assertThat(unknownArg, isPresent());
// Due to the nullable argument, requireNonNull should remain.
assertEquals(1, countObjectsRequireNonNull(unknownArg));
+
+ MethodSubject uninit = mainClass.uniqueMethodWithName("consumeUninitialized");
+ assertThat(uninit, isPresent());
+ assertEquals(expectedCountInConsumer, countObjectsRequireNonNull(uninit));
+ if (expectedCountInConsumer == 0) {
+ assertEquals(
+ 0, Streams.stream(uninit.iterateInstructions(InstructionSubject::isInvoke)).count());
+ assertEquals(
+ 1, Streams.stream(uninit.iterateInstructions(InstructionSubject::isThrow)).count());
+ }
}
@Test
- @IgnoreIfVmOlderThan(Version.V4_4_4)
public void testD8() throws Exception {
- TestRunResult result = testForD8()
+ assumeTrue("Only run D8 for Dex backend", parameters.getBackend() == Backend.DEX);
+ D8TestRunResult result = testForD8()
.debug()
.addProgramClassesAndInnerClasses(MAIN)
- .run(MAIN)
+ .setMinApi(parameters.getRuntime())
+ .run(parameters.getRuntime(), MAIN)
.assertSuccessWithOutput(JAVA_OUTPUT);
- test(result, 2);
+ test(result, 2, 1);
result = testForD8()
.release()
.addProgramClassesAndInnerClasses(MAIN)
- .run(MAIN)
+ .setMinApi(parameters.getRuntime())
+ .run(parameters.getRuntime(), MAIN)
.assertSuccessWithOutput(JAVA_OUTPUT);
- test(result, 0);
+ test(result, 0, 1);
}
@Test
- @IgnoreIfVmOlderThan(Version.V4_4_4)
public void testR8() throws Exception {
- // CF disables move result optimization.
- TestRunResult result = testForR8(Backend.DEX)
+ assumeTrue("CF disables move result optimization", parameters.getBackend() == Backend.DEX);
+ R8TestRunResult result = testForR8(parameters.getBackend())
.addProgramClassesAndInnerClasses(MAIN)
.enableInliningAnnotations()
+ .enableMemberValuePropagationAnnotations()
.addKeepMainRule(MAIN)
- .run(MAIN)
+ .noMinification()
+ .setMinApi(parameters.getRuntime())
+ .run(parameters.getRuntime(), MAIN)
.assertSuccessWithOutput(JAVA_OUTPUT);
- test(result, 0);
+ test(result, 0, 0);
}
}