Enum unboxing: allow nulls in phis
Bug: 166397278
Change-Id: I042f74a811d4aacf013956fd9c230dbd0e408153
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
index d4adf44..fbe8514 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxer.java
@@ -330,7 +330,8 @@
}
for (Phi phi : value.uniquePhiUsers()) {
for (Value operand : phi.getOperands()) {
- if (getEnumUnboxingCandidateOrNull(operand.getType()) != enumClass) {
+ if (!operand.getType().isNullType()
+ && getEnumUnboxingCandidateOrNull(operand.getType()) != enumClass) {
markEnumAsUnboxable(Reason.INVALID_PHI, enumClass);
return Reason.INVALID_PHI;
}
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
index c67c231..e09dab4 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingRewriter.java
@@ -31,6 +31,7 @@
import com.android.tools.r8.ir.analysis.type.TypeElement;
import com.android.tools.r8.ir.analysis.value.AbstractValue;
import com.android.tools.r8.ir.code.ArrayAccess;
+import com.android.tools.r8.ir.code.BasicBlock;
import com.android.tools.r8.ir.code.ConstNumber;
import com.android.tools.r8.ir.code.IRCode;
import com.android.tools.r8.ir.code.InstanceGet;
@@ -55,6 +56,7 @@
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
+import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
@@ -148,182 +150,213 @@
assert code.isConsistentSSABeforeTypesAreCorrect();
Map<Instruction, DexType> convertedEnums = new IdentityHashMap<>();
Set<Phi> affectedPhis = Sets.newIdentityHashSet();
- InstructionListIterator iterator = code.instructionListIterator();
- while (iterator.hasNext()) {
- Instruction instruction = iterator.next();
- // Rewrites specific enum methods, such as ordinal, into their corresponding enum unboxed
- // counterpart.
- if (instruction.isInvokeMethodWithReceiver()) {
- InvokeMethodWithReceiver invokeMethod = instruction.asInvokeMethodWithReceiver();
- DexMethod invokedMethod = invokeMethod.getInvokedMethod();
- DexType enumType = getEnumTypeOrNull(invokeMethod.getReceiver(), convertedEnums);
- if (enumType != null) {
- if (invokedMethod == factory.enumMembers.ordinalMethod
- || invokedMethod == factory.enumMembers.hashCode) {
- replaceEnumInvoke(
- iterator, invokeMethod, ordinalUtilityMethod, m -> synthesizeOrdinalMethod());
- continue;
- } else if (invokedMethod == factory.enumMembers.equals) {
- replaceEnumInvoke(
- iterator, invokeMethod, equalsUtilityMethod, m -> synthesizeEqualsMethod());
- continue;
- } else if (invokedMethod == factory.enumMembers.compareTo) {
- replaceEnumInvoke(
- iterator, invokeMethod, compareToUtilityMethod, m -> synthesizeCompareToMethod());
- continue;
- } else if (invokedMethod == factory.enumMembers.nameMethod
- || invokedMethod == factory.enumMembers.toString) {
- DexMethod toStringMethod =
- computeInstanceFieldUtilityMethod(enumType, factory.enumMembers.nameField);
- iterator.replaceCurrentInstruction(
- new InvokeStatic(
- toStringMethod, invokeMethod.outValue(), invokeMethod.arguments()));
- continue;
- } else if (invokedMethod == factory.objectMembers.getClass) {
- assert !invokeMethod.hasOutValue() || !invokeMethod.outValue().hasAnyUsers();
- replaceEnumInvoke(
- iterator, invokeMethod, zeroCheckMethod, m -> synthesizeZeroCheckMethod());
- }
- }
- // TODO(b/147860220): rewrite also other enum methods.
- } else if (instruction.isInvokeStatic()) {
- InvokeStatic invokeStatic = instruction.asInvokeStatic();
- DexMethod invokedMethod = invokeStatic.getInvokedMethod();
- if (invokedMethod == factory.enumMembers.valueOf
- && invokeStatic.inValues().get(0).isConstClass()) {
- DexType enumType =
- invokeStatic.inValues().get(0).getConstInstruction().asConstClass().getValue();
- if (enumsToUnbox.containsEnum(enumType)) {
- DexMethod valueOfMethod = computeValueOfUtilityMethod(enumType);
- Value outValue = invokeStatic.outValue();
- Value rewrittenOutValue = null;
- if (outValue != null) {
- rewrittenOutValue = code.createValue(TypeElement.getInt());
- affectedPhis.addAll(outValue.uniquePhiUsers());
+ ListIterator<BasicBlock> blocks = code.listIterator();
+ Value zeroConstValue = null;
+ while (blocks.hasNext()) {
+ BasicBlock block = blocks.next();
+ zeroConstValue = fixNullsInBlockPhis(code, block, zeroConstValue);
+ InstructionListIterator iterator = block.listIterator(code);
+ while (iterator.hasNext()) {
+ Instruction instruction = iterator.next();
+ // Rewrites specific enum methods, such as ordinal, into their corresponding enum unboxed
+ // counterpart.
+ if (instruction.isInvokeMethodWithReceiver()) {
+ InvokeMethodWithReceiver invokeMethod = instruction.asInvokeMethodWithReceiver();
+ DexMethod invokedMethod = invokeMethod.getInvokedMethod();
+ DexType enumType = getEnumTypeOrNull(invokeMethod.getReceiver(), convertedEnums);
+ if (enumType != null) {
+ if (invokedMethod == factory.enumMembers.ordinalMethod
+ || invokedMethod == factory.enumMembers.hashCode) {
+ replaceEnumInvoke(
+ iterator, invokeMethod, ordinalUtilityMethod, m -> synthesizeOrdinalMethod());
+ continue;
+ } else if (invokedMethod == factory.enumMembers.equals) {
+ replaceEnumInvoke(
+ iterator, invokeMethod, equalsUtilityMethod, m -> synthesizeEqualsMethod());
+ continue;
+ } else if (invokedMethod == factory.enumMembers.compareTo) {
+ replaceEnumInvoke(
+ iterator, invokeMethod, compareToUtilityMethod, m -> synthesizeCompareToMethod());
+ continue;
+ } else if (invokedMethod == factory.enumMembers.nameMethod
+ || invokedMethod == factory.enumMembers.toString) {
+ DexMethod toStringMethod =
+ computeInstanceFieldUtilityMethod(enumType, factory.enumMembers.nameField);
+ iterator.replaceCurrentInstruction(
+ new InvokeStatic(
+ toStringMethod, invokeMethod.outValue(), invokeMethod.arguments()));
+ continue;
+ } else if (invokedMethod == factory.objectMembers.getClass) {
+ assert !invokeMethod.hasOutValue() || !invokeMethod.outValue().hasAnyUsers();
+ replaceEnumInvoke(
+ iterator, invokeMethod, zeroCheckMethod, m -> synthesizeZeroCheckMethod());
}
- InvokeStatic invoke =
- new InvokeStatic(
- valueOfMethod,
- rewrittenOutValue,
- Collections.singletonList(invokeStatic.inValues().get(1)));
- iterator.replaceCurrentInstruction(invoke);
- convertedEnums.put(invoke, enumType);
- continue;
}
- } else if (invokedMethod == factory.javaLangSystemMethods.identityHashCode) {
- assert invokeStatic.arguments().size() == 1;
- Value argument = invokeStatic.getArgument(0);
- DexType enumType = getEnumTypeOrNull(argument, convertedEnums);
- if (enumType != null) {
- invokeStatic.outValue().replaceUsers(argument);
- iterator.removeOrReplaceByDebugLocalRead();
- }
- } else if (invokedMethod == factory.stringMembers.valueOf) {
- assert invokeStatic.arguments().size() == 1;
- Value argument = invokeStatic.getArgument(0);
- DexType enumType = getEnumTypeOrNull(argument, convertedEnums);
- if (enumType != null) {
- DexMethod stringValueOfMethod = computeStringValueOfUtilityMethod(enumType);
- iterator.replaceCurrentInstruction(
- new InvokeStatic(
- stringValueOfMethod, invokeStatic.outValue(), invokeStatic.arguments()));
- continue;
- }
- } else if (invokedMethod == factory.objectsMethods.requireNonNull) {
- assert invokeStatic.arguments().size() == 1;
- Value argument = invokeStatic.getArgument(0);
- DexType enumType = getEnumTypeOrNull(argument, convertedEnums);
- if (enumType != null) {
- replaceEnumInvoke(
- iterator, invokeStatic, zeroCheckMethod, m -> synthesizeZeroCheckMethod());
- }
- } else if (invokedMethod == factory.objectsMethods.requireNonNullWithMessage) {
- assert invokeStatic.arguments().size() == 2;
- Value argument = invokeStatic.getArgument(0);
- DexType enumType = getEnumTypeOrNull(argument, convertedEnums);
- if (enumType != null) {
- replaceEnumInvoke(
- iterator,
- invokeStatic,
- zeroCheckMessageMethod,
- m -> synthesizeZeroCheckMessageMethod());
+ // TODO(b/147860220): rewrite also other enum methods.
+ } else if (instruction.isInvokeStatic()) {
+ InvokeStatic invokeStatic = instruction.asInvokeStatic();
+ DexMethod invokedMethod = invokeStatic.getInvokedMethod();
+ if (invokedMethod == factory.enumMembers.valueOf
+ && invokeStatic.inValues().get(0).isConstClass()) {
+ DexType enumType =
+ invokeStatic.inValues().get(0).getConstInstruction().asConstClass().getValue();
+ if (enumsToUnbox.containsEnum(enumType)) {
+ DexMethod valueOfMethod = computeValueOfUtilityMethod(enumType);
+ Value outValue = invokeStatic.outValue();
+ Value rewrittenOutValue = null;
+ if (outValue != null) {
+ rewrittenOutValue = code.createValue(TypeElement.getInt());
+ affectedPhis.addAll(outValue.uniquePhiUsers());
+ }
+ InvokeStatic invoke =
+ new InvokeStatic(
+ valueOfMethod,
+ rewrittenOutValue,
+ Collections.singletonList(invokeStatic.inValues().get(1)));
+ iterator.replaceCurrentInstruction(invoke);
+ convertedEnums.put(invoke, enumType);
+ continue;
+ }
+ } else if (invokedMethod == factory.javaLangSystemMethods.identityHashCode) {
+ assert invokeStatic.arguments().size() == 1;
+ Value argument = invokeStatic.getArgument(0);
+ DexType enumType = getEnumTypeOrNull(argument, convertedEnums);
+ if (enumType != null) {
+ invokeStatic.outValue().replaceUsers(argument);
+ iterator.removeOrReplaceByDebugLocalRead();
+ }
+ } else if (invokedMethod == factory.stringMembers.valueOf) {
+ assert invokeStatic.arguments().size() == 1;
+ Value argument = invokeStatic.getArgument(0);
+ DexType enumType = getEnumTypeOrNull(argument, convertedEnums);
+ if (enumType != null) {
+ DexMethod stringValueOfMethod = computeStringValueOfUtilityMethod(enumType);
+ iterator.replaceCurrentInstruction(
+ new InvokeStatic(
+ stringValueOfMethod, invokeStatic.outValue(), invokeStatic.arguments()));
+ continue;
+ }
+ } else if (invokedMethod == factory.objectsMethods.requireNonNull) {
+ assert invokeStatic.arguments().size() == 1;
+ Value argument = invokeStatic.getArgument(0);
+ DexType enumType = getEnumTypeOrNull(argument, convertedEnums);
+ if (enumType != null) {
+ replaceEnumInvoke(
+ iterator, invokeStatic, zeroCheckMethod, m -> synthesizeZeroCheckMethod());
+ }
+ } else if (invokedMethod == factory.objectsMethods.requireNonNullWithMessage) {
+ assert invokeStatic.arguments().size() == 2;
+ Value argument = invokeStatic.getArgument(0);
+ DexType enumType = getEnumTypeOrNull(argument, convertedEnums);
+ if (enumType != null) {
+ replaceEnumInvoke(
+ iterator,
+ invokeStatic,
+ zeroCheckMessageMethod,
+ m -> synthesizeZeroCheckMessageMethod());
+ }
}
}
- }
- if (instruction.isStaticGet()) {
- StaticGet staticGet = instruction.asStaticGet();
- DexType holder = staticGet.getField().holder;
- if (enumsToUnbox.containsEnum(holder)) {
- if (staticGet.outValue() == null) {
- iterator.removeOrReplaceByDebugLocalRead();
- continue;
+ if (instruction.isStaticGet()) {
+ StaticGet staticGet = instruction.asStaticGet();
+ DexType holder = staticGet.getField().holder;
+ if (enumsToUnbox.containsEnum(holder)) {
+ if (staticGet.outValue() == null) {
+ iterator.removeOrReplaceByDebugLocalRead();
+ continue;
+ }
+ EnumValueInfoMap enumValueInfoMap = enumsToUnbox.getEnumValueInfoMap(holder);
+ assert enumValueInfoMap != null;
+ affectedPhis.addAll(staticGet.outValue().uniquePhiUsers());
+ EnumValueInfo enumValueInfo = enumValueInfoMap.getEnumValueInfo(staticGet.getField());
+ if (enumValueInfo == null && staticGet.getField().name == factory.enumValuesFieldName) {
+ utilityMethods.computeIfAbsent(
+ valuesUtilityMethod, m -> synthesizeValuesUtilityMethod());
+ DexField fieldValues = createValuesField(holder);
+ utilityFields.computeIfAbsent(fieldValues, this::computeValuesEncodedField);
+ DexMethod methodValues = createValuesMethod(holder);
+ utilityMethods.computeIfAbsent(
+ methodValues,
+ m -> computeValuesEncodedMethod(m, fieldValues, enumValueInfoMap.size()));
+ Value rewrittenOutValue =
+ code.createValue(
+ ArrayTypeElement.create(TypeElement.getInt(), definitelyNotNull()));
+ InvokeStatic invoke =
+ new InvokeStatic(methodValues, rewrittenOutValue, ImmutableList.of());
+ iterator.replaceCurrentInstruction(invoke);
+ convertedEnums.put(invoke, holder);
+ } else {
+ // Replace by ordinal + 1 for null check (null is 0).
+ assert enumValueInfo != null
+ : "Invalid read to " + staticGet.getField().name + ", error during enum analysis";
+ ConstNumber intConstant = code.createIntConstant(enumValueInfo.convertToInt());
+ iterator.replaceCurrentInstruction(intConstant);
+ convertedEnums.put(intConstant, holder);
+ }
}
- EnumValueInfoMap enumValueInfoMap = enumsToUnbox.getEnumValueInfoMap(holder);
- assert enumValueInfoMap != null;
- affectedPhis.addAll(staticGet.outValue().uniquePhiUsers());
- EnumValueInfo enumValueInfo = enumValueInfoMap.getEnumValueInfo(staticGet.getField());
- if (enumValueInfo == null && staticGet.getField().name == factory.enumValuesFieldName) {
- utilityMethods.computeIfAbsent(
- valuesUtilityMethod, m -> synthesizeValuesUtilityMethod());
- DexField fieldValues = createValuesField(holder);
- utilityFields.computeIfAbsent(fieldValues, this::computeValuesEncodedField);
- DexMethod methodValues = createValuesMethod(holder);
- utilityMethods.computeIfAbsent(
- methodValues,
- m -> computeValuesEncodedMethod(m, fieldValues, enumValueInfoMap.size()));
+ }
+
+ if (instruction.isInstanceGet()) {
+ InstanceGet instanceGet = instruction.asInstanceGet();
+ DexType holder = instanceGet.getField().holder;
+ if (enumsToUnbox.containsEnum(holder)) {
+ DexMethod fieldMethod = computeInstanceFieldMethod(instanceGet.getField());
Value rewrittenOutValue =
code.createValue(
- ArrayTypeElement.create(TypeElement.getInt(), definitelyNotNull()));
+ TypeElement.fromDexType(
+ fieldMethod.proto.returnType, Nullability.maybeNull(), appView));
InvokeStatic invoke =
- new InvokeStatic(methodValues, rewrittenOutValue, ImmutableList.of());
+ new InvokeStatic(
+ fieldMethod, rewrittenOutValue, ImmutableList.of(instanceGet.object()));
iterator.replaceCurrentInstruction(invoke);
- convertedEnums.put(invoke, holder);
- } else {
- // Replace by ordinal + 1 for null check (null is 0).
- assert enumValueInfo != null
- : "Invalid read to " + staticGet.getField().name + ", error during enum analysis";
- ConstNumber intConstant = code.createIntConstant(enumValueInfo.convertToInt());
- iterator.replaceCurrentInstruction(intConstant);
- convertedEnums.put(intConstant, holder);
+ if (enumsToUnbox.containsEnum(instanceGet.getField().type)) {
+ convertedEnums.put(invoke, instanceGet.getField().type);
+ }
}
}
- }
- if (instruction.isInstanceGet()) {
- InstanceGet instanceGet = instruction.asInstanceGet();
- DexType holder = instanceGet.getField().holder;
- if (enumsToUnbox.containsEnum(holder)) {
- DexMethod fieldMethod = computeInstanceFieldMethod(instanceGet.getField());
- Value rewrittenOutValue =
- code.createValue(
- TypeElement.fromDexType(
- fieldMethod.proto.returnType, Nullability.maybeNull(), appView));
- InvokeStatic invoke =
- new InvokeStatic(
- fieldMethod, rewrittenOutValue, ImmutableList.of(instanceGet.object()));
- iterator.replaceCurrentInstruction(invoke);
- if (enumsToUnbox.containsEnum(instanceGet.getField().type)) {
- convertedEnums.put(invoke, instanceGet.getField().type);
+ // Rewrite array accesses from MyEnum[] (OBJECT) to int[] (INT).
+ if (instruction.isArrayAccess()) {
+ ArrayAccess arrayAccess = instruction.asArrayAccess();
+ DexType enumType = getEnumTypeOrNull(arrayAccess);
+ if (enumType != null) {
+ instruction = arrayAccess.withMemberType(MemberType.INT);
+ iterator.replaceCurrentInstruction(instruction);
+ convertedEnums.put(instruction, enumType);
}
+ assert validateArrayAccess(arrayAccess);
}
}
-
- // Rewrite array accesses from MyEnum[] (OBJECT) to int[] (INT).
- if (instruction.isArrayAccess()) {
- ArrayAccess arrayAccess = instruction.asArrayAccess();
- DexType enumType = getEnumTypeOrNull(arrayAccess);
- if (enumType != null) {
- instruction = arrayAccess.withMemberType(MemberType.INT);
- iterator.replaceCurrentInstruction(instruction);
- convertedEnums.put(instruction, enumType);
- }
- assert validateArrayAccess(arrayAccess);
- }
}
assert code.isConsistentSSABeforeTypesAreCorrect();
return affectedPhis;
}
+ private Value fixNullsInBlockPhis(IRCode code, BasicBlock block, Value zeroConstValue) {
+ for (Phi phi : block.getPhis()) {
+ if (getEnumTypeOrNull(phi.getType()) != null) {
+ for (int i = 0; i < phi.getOperands().size(); i++) {
+ Value operand = phi.getOperand(i);
+ if (operand.getType().isNullType()) {
+ if (zeroConstValue == null) {
+ zeroConstValue = insertConstZero(code);
+ }
+ phi.replaceOperandAt(i, zeroConstValue);
+ }
+ }
+ }
+ }
+ return zeroConstValue;
+ }
+
+ private Value insertConstZero(IRCode code) {
+ InstructionListIterator iterator = code.entryBlock().listIterator(code);
+ while (iterator.hasNext() && iterator.peekNext().isArgument()) {
+ iterator.next();
+ }
+ return iterator.insertConstNumberInstruction(code, appView.options(), 0, TypeElement.getInt());
+ }
+
private DexMethod computeInstanceFieldMethod(DexField field) {
EnumInstanceFieldKnownData enumFieldKnownData =
unboxedEnumsInstanceFieldData.getInstanceFieldData(field.holder, field);
@@ -362,6 +395,10 @@
if (type.isInt()) {
return convertedEnums.get(receiver.definition);
}
+ return getEnumTypeOrNull(type);
+ }
+
+ private DexType getEnumTypeOrNull(TypeElement type) {
if (!type.isClassType()) {
return null;
}
diff --git a/src/test/java/com/android/tools/r8/enumunboxing/PhiEnumUnboxingTest.java b/src/test/java/com/android/tools/r8/enumunboxing/PhiEnumUnboxingTest.java
index 8aa2d5e..7a0c8bb 100644
--- a/src/test/java/com/android/tools/r8/enumunboxing/PhiEnumUnboxingTest.java
+++ b/src/test/java/com/android/tools/r8/enumunboxing/PhiEnumUnboxingTest.java
@@ -17,8 +17,6 @@
@RunWith(Parameterized.class)
public class PhiEnumUnboxingTest extends EnumUnboxingTestBase {
- private static final Class<?> ENUM_CLASS = MyEnum.class;
-
private final TestParameters parameters;
private final boolean enumValueOptimization;
private final EnumKeepRules enumKeepRules;
@@ -37,11 +35,10 @@
@Test
public void testEnumUnboxing() throws Exception {
- Class<?> classToTest = Phi.class;
R8TestRunResult run =
testForR8(parameters.getBackend())
- .addProgramClasses(classToTest, ENUM_CLASS)
- .addKeepMainRule(classToTest)
+ .addProgramClasses(Phi.class, MyEnum.class)
+ .addKeepMainRule(Phi.class)
.addKeepRules(enumKeepRules.getKeepRules())
.enableInliningAnnotations()
.enableNeverClassInliningAnnotations()
@@ -50,8 +47,8 @@
.setMinApi(parameters.getApiLevel())
.compile()
.inspectDiagnosticMessages(
- m -> assertEnumIsUnboxed(ENUM_CLASS, classToTest.getSimpleName(), m))
- .run(parameters.getRuntime(), classToTest)
+ m -> assertEnumIsUnboxed(MyEnum.class, Phi.class.getSimpleName(), m))
+ .run(parameters.getRuntime(), Phi.class)
.assertSuccess();
assertLines2By2Correct(run.getStdOut());
}
@@ -66,23 +63,54 @@
static class Phi {
public static void main(String[] args) {
+ nonNullTest();
+ nullTest();
+ }
+
+ private static void nonNullTest() {
System.out.println(switchOn(1).ordinal());
System.out.println(1);
System.out.println(switchOn(2).ordinal());
System.out.println(2);
}
- // Avoid removing the switch entirely.
+ private static void nullTest() {
+ System.out.println(switchOnWithNull(1).ordinal());
+ System.out.println(1);
+ System.out.println(switchOnWithNull(2) == null);
+ System.out.println(true);
+ }
+
@NeverInline
static MyEnum switchOn(int i) {
+ MyEnum returnValue;
switch (i) {
case 0:
- return MyEnum.A;
+ returnValue = MyEnum.A;
+ break;
case 1:
- return MyEnum.B;
+ returnValue = MyEnum.B;
+ break;
default:
- return MyEnum.C;
+ returnValue = MyEnum.C;
}
+ return returnValue;
+ }
+
+ @NeverInline
+ static MyEnum switchOnWithNull(int i) {
+ MyEnum returnValue;
+ switch (i) {
+ case 0:
+ returnValue = MyEnum.A;
+ break;
+ case 1:
+ returnValue = MyEnum.B;
+ break;
+ default:
+ returnValue = null;
+ }
+ return returnValue;
}
}
}