Make Record equals allocation free
Bug: b/378780054
Change-Id: Ie88fb8544e261a6c0b8033c71c39d43b65125a9d
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/CfInstructionDesugaringEventConsumer.java b/src/main/java/com/android/tools/r8/ir/desugar/CfInstructionDesugaringEventConsumer.java
index 437692a..ace4c1e 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/CfInstructionDesugaringEventConsumer.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/CfInstructionDesugaringEventConsumer.java
@@ -221,7 +221,7 @@
@Override
public void acceptRecordEqualsHelperMethod(ProgramMethod method, ProgramMethod context) {
- // Intentionally empty. Added to the program using ProgramAdditions.
+ methodProcessor.scheduleMethodForProcessing(method, outermostEventConsumer);
}
@Override
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/records/RecordInstructionDesugaring.java b/src/main/java/com/android/tools/r8/ir/desugar/records/RecordInstructionDesugaring.java
index dee85f5..f484a69 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/records/RecordInstructionDesugaring.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/records/RecordInstructionDesugaring.java
@@ -15,6 +15,9 @@
import com.android.tools.r8.cf.code.CfInvoke;
import com.android.tools.r8.cf.code.CfInvokeDynamic;
import com.android.tools.r8.cf.code.CfLoad;
+import com.android.tools.r8.cf.code.CfStackInstruction;
+import com.android.tools.r8.cf.code.CfStackInstruction.Opcode;
+import com.android.tools.r8.cf.code.CfStore;
import com.android.tools.r8.contexts.CompilationContext.MethodProcessingContext;
import com.android.tools.r8.dex.Constants;
import com.android.tools.r8.errors.CompilationError;
@@ -36,10 +39,12 @@
import com.android.tools.r8.ir.desugar.CfInstructionDesugaring;
import com.android.tools.r8.ir.desugar.CfInstructionDesugaringEventConsumer;
import com.android.tools.r8.ir.desugar.DesugarDescription;
+import com.android.tools.r8.ir.desugar.FreshLocalProvider;
import com.android.tools.r8.ir.desugar.LocalStackAllocator;
import com.android.tools.r8.ir.desugar.ProgramAdditions;
import com.android.tools.r8.ir.desugar.records.RecordDesugaringEventConsumer.RecordInstructionDesugaringEventConsumer;
import com.android.tools.r8.ir.desugar.records.RecordRewriterHelper.RecordInvokeDynamic;
+import com.android.tools.r8.ir.synthetic.RecordCfCodeProvider.RecordEqCfCodeProvider;
import com.android.tools.r8.ir.synthetic.RecordCfCodeProvider.RecordEqualsCfCodeProvider;
import com.android.tools.r8.ir.synthetic.RecordCfCodeProvider.RecordGetFieldsAsObjectsCfCodeProvider;
import com.android.tools.r8.ir.synthetic.RecordCfCodeProvider.RecordHashCfCodeProvider;
@@ -179,6 +184,7 @@
dexItemFactory) ->
desugarInvokeDynamicOnRecord(
instruction.asInvokeDynamic(),
+ freshLocalProvider,
localStackAllocator,
eventConsumer,
context,
@@ -206,6 +212,7 @@
@SuppressWarnings("ReferenceEquality")
private List<CfInstruction> desugarInvokeDynamicOnRecord(
CfInvokeDynamic invokeDynamic,
+ FreshLocalProvider freshLocalProvider,
LocalStackAllocator localStackAllocator,
CfInstructionDesugaringEventConsumer eventConsumer,
ProgramMethod context,
@@ -223,6 +230,7 @@
if (recordInvokeDynamic.getMethodName() == factory.hashCodeMethodName) {
return desugarInvokeRecordHashCode(
recordInvokeDynamic,
+ freshLocalProvider,
localStackAllocator,
eventConsumer,
context,
@@ -234,12 +242,6 @@
throw new Unreachable("Invoke dynamic needs record desugaring but could not be desugared.");
}
- private ProgramMethod synthesizeEqualsRecordMethod(
- DexProgramClass clazz, DexMethod getFieldsAsObjects, DexMethod method) {
- return synthesizeMethod(
- clazz, new RecordEqualsCfCodeProvider(appView, clazz.type, getFieldsAsObjects), method);
- }
-
private ProgramMethod synthesizeGetFieldsAsObjectsMethod(
DexProgramClass clazz, DexField[] fields, DexMethod method) {
return synthesizeMethod(
@@ -290,14 +292,18 @@
ProgramAdditions programAdditions,
ProgramMethod context,
RecordInstructionDesugaringEventConsumer eventConsumer) {
- DexMethod getFieldsAsObjects =
- ensureGetFieldsAsObjects(recordInvokeDynamic, programAdditions, context, eventConsumer);
DexProgramClass clazz = recordInvokeDynamic.getRecordClass();
DexMethod method = equalsRecordMethod(clazz.type);
assert clazz.lookupProgramMethod(method) == null;
+ Pair<List<DexField>, List<DexType>> pair = sortedInstanceFields(clazz.instanceFields());
ProgramMethod equalsHelperMethod =
programAdditions.ensureMethod(
- method, () -> synthesizeEqualsRecordMethod(clazz, getFieldsAsObjects, method));
+ method,
+ () ->
+ synthesizeMethod(
+ clazz,
+ new RecordEqCfCodeProvider(appView, clazz.type, pair.getFirst()),
+ method));
eventConsumer.acceptRecordEqualsHelperMethod(equalsHelperMethod, context);
}
@@ -358,8 +364,36 @@
return recordClass.instanceFields().size() < MAX_FIELDS_FOR_OUTLINE;
}
+ // Answers an ordered map of the fields with the type for the hashCode proto.
+ private Pair<List<DexField>, List<DexType>> sortedInstanceFields(
+ List<DexEncodedField> instanceFields) {
+ Map<DexType, List<DexField>> temp = new IdentityHashMap<>();
+ for (DexEncodedField instanceField : instanceFields) {
+ DexType protoType =
+ instanceField.getType().isBooleanType()
+ ? instanceField.getType()
+ : ValueType.fromDexType(instanceField.getType()).toDexType(factory);
+ temp.computeIfAbsent(protoType, ignored -> new ArrayList<>())
+ .add(instanceField.getReference());
+ }
+ Pair<List<DexField>, List<DexType>> pair = new Pair<>(new ArrayList<>(), new ArrayList<>());
+ for (DexType orderedSharedType : orderedSharedTypes) {
+ List<DexField> dexFields = temp.get(orderedSharedType);
+ if (dexFields != null) {
+ for (DexField dexField : dexFields) {
+ pair.getFirst().add(dexField);
+ pair.getSecond().add(orderedSharedType);
+ }
+ }
+ }
+ assert pair.getFirst().size() == instanceFields.size();
+ assert pair.getSecond().size() == instanceFields.size();
+ return pair;
+ }
+
private List<CfInstruction> desugarInvokeRecordHashCode(
RecordInvokeDynamic recordInvokeDynamic,
+ FreshLocalProvider freshLocalProvider,
LocalStackAllocator localStackAllocator,
RecordInstructionDesugaringEventConsumer eventConsumer,
ProgramMethod context,
@@ -370,11 +404,14 @@
if (shouldOutlineMethods(recordClass)) {
Pair<List<DexField>, List<DexType>> sortedFields =
sortedInstanceFields(recordClass.instanceFields());
+ int freshLocal = freshLocalProvider.getFreshLocal(ValueType.OBJECT.requiredRegisters());
+ instructions.add(new CfStackInstruction(Opcode.Dup));
+ instructions.add(new CfStore(ValueType.OBJECT, freshLocal));
DexField field = sortedFields.getFirst().get(0);
int extraStack = field.getType().getRequiredRegisters();
instructions.add(new CfInstanceFieldRead(field));
for (int i = 1; i < sortedFields.getFirst().size(); i++) {
- instructions.add(new CfLoad(ValueType.OBJECT, 0));
+ instructions.add(new CfLoad(ValueType.OBJECT, freshLocal));
field = sortedFields.getFirst().get(i);
instructions.add(new CfInstanceFieldRead(field));
extraStack += field.getType().getRequiredRegisters();
@@ -411,37 +448,13 @@
return instructions;
}
- // Answers an ordered map of the fields with the type for the hashCode proto.
- private Pair<List<DexField>, List<DexType>> sortedInstanceFields(
- List<DexEncodedField> instanceFields) {
- Map<DexType, List<DexField>> temp = new IdentityHashMap<>();
- for (DexEncodedField instanceField : instanceFields) {
- DexType protoType =
- instanceField.getType().isBooleanType()
- ? instanceField.getType()
- : ValueType.fromDexType(instanceField.getType()).toDexType(factory);
- temp.computeIfAbsent(protoType, ignored -> new ArrayList<>())
- .add(instanceField.getReference());
- }
- Pair<List<DexField>, List<DexType>> pair = new Pair<>(new ArrayList<>(), new ArrayList<>());
- for (DexType orderedSharedType : orderedSharedTypes) {
- List<DexField> dexFields = temp.get(orderedSharedType);
- if (dexFields != null) {
- for (DexField dexField : dexFields) {
- pair.getFirst().add(dexField);
- pair.getSecond().add(orderedSharedType);
- }
- }
- }
- assert pair.getFirst().size() == instanceFields.size();
- assert pair.getSecond().size() == instanceFields.size();
- return pair;
- }
-
private List<CfInstruction> desugarInvokeRecordEquals(RecordInvokeDynamic recordInvokeDynamic) {
- DexMethod equalsRecord = equalsRecordMethod(recordInvokeDynamic.getRecordType());
- assert recordInvokeDynamic.getRecordClass().lookupProgramMethod(equalsRecord) != null;
- return Collections.singletonList(new CfInvoke(Opcodes.INVOKESPECIAL, equalsRecord, false));
+ ArrayList<CfInstruction> instructions = new ArrayList<>();
+ DexProgramClass recordClass = recordInvokeDynamic.getRecordClass();
+ DexMethod equalsMethod = equalsRecordMethod(recordInvokeDynamic.getRecordType());
+ assert recordClass.lookupProgramMethod(equalsMethod) != null;
+ instructions.add(new CfInvoke(Opcodes.INVOKESPECIAL, equalsMethod, false));
+ return instructions;
}
private List<CfInstruction> desugarInvokeRecordToString(
diff --git a/src/main/java/com/android/tools/r8/ir/synthetic/RecordCfCodeProvider.java b/src/main/java/com/android/tools/r8/ir/synthetic/RecordCfCodeProvider.java
index 1dbbbbc..8ba50fe 100644
--- a/src/main/java/com/android/tools/r8/ir/synthetic/RecordCfCodeProvider.java
+++ b/src/main/java/com/android/tools/r8/ir/synthetic/RecordCfCodeProvider.java
@@ -8,11 +8,13 @@
import com.android.tools.r8.cf.code.CfArithmeticBinop.Opcode;
import com.android.tools.r8.cf.code.CfArrayStore;
import com.android.tools.r8.cf.code.CfCheckCast;
+import com.android.tools.r8.cf.code.CfCmp;
import com.android.tools.r8.cf.code.CfConstNumber;
import com.android.tools.r8.cf.code.CfFrame;
import com.android.tools.r8.cf.code.CfIf;
import com.android.tools.r8.cf.code.CfIfCmp;
import com.android.tools.r8.cf.code.CfInstanceFieldRead;
+import com.android.tools.r8.cf.code.CfInstanceOf;
import com.android.tools.r8.cf.code.CfInstruction;
import com.android.tools.r8.cf.code.CfInvoke;
import com.android.tools.r8.cf.code.CfLabel;
@@ -28,6 +30,7 @@
import com.android.tools.r8.graph.DexItemFactory;
import com.android.tools.r8.graph.DexMethod;
import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.ir.code.Cmp.Bias;
import com.android.tools.r8.ir.code.IfType;
import com.android.tools.r8.ir.code.MemberType;
import com.android.tools.r8.ir.code.NumericType;
@@ -58,6 +61,14 @@
this.outline = outline;
}
+ public static void registerSynthesizedCodeReferences(DexItemFactory factory) {
+ factory.createSynthesizedType("Ljava/lang/Objects;");
+ factory.createSynthesizedType("Ljava/lang/Double;");
+ factory.createSynthesizedType("Ljava/lang/Float;");
+ factory.createSynthesizedType("Ljava/lang/Boolean;");
+ factory.createSynthesizedType("Ljava/lang/Long;");
+ }
+
private void addInvokeStatic(List<CfInstruction> instructions, DexMethod method) {
instructions.add(new CfInvoke(Opcodes.INVOKESTATIC, method, false));
}
@@ -103,6 +114,80 @@
}
}
+ public static class RecordEqCfCodeProvider extends RecordCfCodeProvider {
+
+ private final List<DexField> fieldsToCompare;
+
+ public RecordEqCfCodeProvider(
+ AppView<?> appView, DexType holder, List<DexField> fieldsToCompare) {
+ super(appView, holder);
+ this.fieldsToCompare = fieldsToCompare;
+ }
+
+ public static void registerSynthesizedCodeReferences(DexItemFactory factory) {
+ factory.createSynthesizedType("Ljava/lang/Objects;");
+ }
+
+ private void addInvokeStatic(List<CfInstruction> instructions, DexMethod method) {
+ instructions.add(new CfInvoke(Opcodes.INVOKESTATIC, method, false));
+ }
+
+ private void pushComparison(
+ List<CfInstruction> instructions, DexField field, CfLabel falseLabel) {
+ ValueType valueType = ValueType.fromDexType(field.getType());
+ instructions.add(new CfLoad(ValueType.OBJECT, 0));
+ instructions.add(new CfInstanceFieldRead(field));
+ instructions.add(new CfLoad(ValueType.OBJECT, 2));
+ instructions.add(new CfInstanceFieldRead(field));
+ if (valueType == ValueType.DOUBLE) {
+ instructions.add(new CfCmp(Bias.LT, NumericType.DOUBLE));
+ instructions.add(new CfIf(IfType.NE, ValueType.INT, falseLabel));
+ } else if (valueType == ValueType.FLOAT) {
+ instructions.add(new CfCmp(Bias.LT, NumericType.FLOAT));
+ instructions.add(new CfIf(IfType.NE, ValueType.INT, falseLabel));
+ } else if (valueType == ValueType.LONG) {
+ instructions.add(new CfCmp(Bias.NONE, NumericType.LONG));
+ instructions.add(new CfIf(IfType.NE, ValueType.INT, falseLabel));
+ } else if (valueType.isObject()) {
+ addInvokeStatic(instructions, appView.dexItemFactory().objectsMethods.equals);
+ instructions.add(new CfIf(IfType.EQ, ValueType.INT, falseLabel));
+ } else {
+ assert valueType == ValueType.INT;
+ instructions.add(new CfIfCmp(IfType.NE, ValueType.INT, falseLabel));
+ }
+ }
+
+ @Override
+ public CfCode generateCfCode() {
+ CfFrame frame = buildFrame();
+ List<CfInstruction> instructions = new ArrayList<>();
+ CfLabel falseLabel = new CfLabel();
+ instructions.add(new CfLoad(ValueType.OBJECT, 1));
+ instructions.add(new CfInstanceOf(getHolder()));
+ instructions.add(new CfIf(IfType.EQ, ValueType.INT, falseLabel));
+ instructions.add(new CfLoad(ValueType.OBJECT, 1));
+ instructions.add(new CfCheckCast(getHolder()));
+ instructions.add(new CfStore(ValueType.OBJECT, 2));
+ for (int i = 0; i < fieldsToCompare.size(); i++) {
+ pushComparison(instructions, fieldsToCompare.get(i), falseLabel);
+ }
+ instructions.add(new CfConstNumber(1, ValueType.INT));
+ instructions.add(new CfReturn(ValueType.INT));
+ instructions.add(falseLabel);
+ instructions.add(frame);
+ instructions.add(new CfConstNumber(0, ValueType.INT));
+ instructions.add(new CfReturn(ValueType.INT));
+ return standardCfCodeFromInstructions(instructions);
+ }
+
+ private CfFrame buildFrame() {
+ return CfFrame.builder()
+ .appendLocal(FrameType.initialized(getHolder()))
+ .appendLocal(FrameType.initialized(appView.dexItemFactory().objectType))
+ .build();
+ }
+ }
+
/**
* Generates a method which answers all field values as an array of objects. If the field value is
* a primitive type, it uses the primitive wrapper to wrap it.