Make Record equals allocation free

Bug: b/378780054
Change-Id: Ie88fb8544e261a6c0b8033c71c39d43b65125a9d
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/CfInstructionDesugaringEventConsumer.java b/src/main/java/com/android/tools/r8/ir/desugar/CfInstructionDesugaringEventConsumer.java
index 437692a..ace4c1e 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/CfInstructionDesugaringEventConsumer.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/CfInstructionDesugaringEventConsumer.java
@@ -221,7 +221,7 @@
 
     @Override
     public void acceptRecordEqualsHelperMethod(ProgramMethod method, ProgramMethod context) {
-      // Intentionally empty. Added to the program using ProgramAdditions.
+      methodProcessor.scheduleMethodForProcessing(method, outermostEventConsumer);
     }
 
     @Override
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/records/RecordInstructionDesugaring.java b/src/main/java/com/android/tools/r8/ir/desugar/records/RecordInstructionDesugaring.java
index dee85f5..f484a69 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/records/RecordInstructionDesugaring.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/records/RecordInstructionDesugaring.java
@@ -15,6 +15,9 @@
 import com.android.tools.r8.cf.code.CfInvoke;
 import com.android.tools.r8.cf.code.CfInvokeDynamic;
 import com.android.tools.r8.cf.code.CfLoad;
+import com.android.tools.r8.cf.code.CfStackInstruction;
+import com.android.tools.r8.cf.code.CfStackInstruction.Opcode;
+import com.android.tools.r8.cf.code.CfStore;
 import com.android.tools.r8.contexts.CompilationContext.MethodProcessingContext;
 import com.android.tools.r8.dex.Constants;
 import com.android.tools.r8.errors.CompilationError;
@@ -36,10 +39,12 @@
 import com.android.tools.r8.ir.desugar.CfInstructionDesugaring;
 import com.android.tools.r8.ir.desugar.CfInstructionDesugaringEventConsumer;
 import com.android.tools.r8.ir.desugar.DesugarDescription;
+import com.android.tools.r8.ir.desugar.FreshLocalProvider;
 import com.android.tools.r8.ir.desugar.LocalStackAllocator;
 import com.android.tools.r8.ir.desugar.ProgramAdditions;
 import com.android.tools.r8.ir.desugar.records.RecordDesugaringEventConsumer.RecordInstructionDesugaringEventConsumer;
 import com.android.tools.r8.ir.desugar.records.RecordRewriterHelper.RecordInvokeDynamic;
+import com.android.tools.r8.ir.synthetic.RecordCfCodeProvider.RecordEqCfCodeProvider;
 import com.android.tools.r8.ir.synthetic.RecordCfCodeProvider.RecordEqualsCfCodeProvider;
 import com.android.tools.r8.ir.synthetic.RecordCfCodeProvider.RecordGetFieldsAsObjectsCfCodeProvider;
 import com.android.tools.r8.ir.synthetic.RecordCfCodeProvider.RecordHashCfCodeProvider;
@@ -179,6 +184,7 @@
                 dexItemFactory) ->
                 desugarInvokeDynamicOnRecord(
                     instruction.asInvokeDynamic(),
+                    freshLocalProvider,
                     localStackAllocator,
                     eventConsumer,
                     context,
@@ -206,6 +212,7 @@
   @SuppressWarnings("ReferenceEquality")
   private List<CfInstruction> desugarInvokeDynamicOnRecord(
       CfInvokeDynamic invokeDynamic,
+      FreshLocalProvider freshLocalProvider,
       LocalStackAllocator localStackAllocator,
       CfInstructionDesugaringEventConsumer eventConsumer,
       ProgramMethod context,
@@ -223,6 +230,7 @@
     if (recordInvokeDynamic.getMethodName() == factory.hashCodeMethodName) {
       return desugarInvokeRecordHashCode(
           recordInvokeDynamic,
+          freshLocalProvider,
           localStackAllocator,
           eventConsumer,
           context,
@@ -234,12 +242,6 @@
     throw new Unreachable("Invoke dynamic needs record desugaring but could not be desugared.");
   }
 
-  private ProgramMethod synthesizeEqualsRecordMethod(
-      DexProgramClass clazz, DexMethod getFieldsAsObjects, DexMethod method) {
-    return synthesizeMethod(
-        clazz, new RecordEqualsCfCodeProvider(appView, clazz.type, getFieldsAsObjects), method);
-  }
-
   private ProgramMethod synthesizeGetFieldsAsObjectsMethod(
       DexProgramClass clazz, DexField[] fields, DexMethod method) {
     return synthesizeMethod(
@@ -290,14 +292,18 @@
       ProgramAdditions programAdditions,
       ProgramMethod context,
       RecordInstructionDesugaringEventConsumer eventConsumer) {
-    DexMethod getFieldsAsObjects =
-        ensureGetFieldsAsObjects(recordInvokeDynamic, programAdditions, context, eventConsumer);
     DexProgramClass clazz = recordInvokeDynamic.getRecordClass();
     DexMethod method = equalsRecordMethod(clazz.type);
     assert clazz.lookupProgramMethod(method) == null;
+    Pair<List<DexField>, List<DexType>> pair = sortedInstanceFields(clazz.instanceFields());
     ProgramMethod equalsHelperMethod =
         programAdditions.ensureMethod(
-            method, () -> synthesizeEqualsRecordMethod(clazz, getFieldsAsObjects, method));
+            method,
+            () ->
+                synthesizeMethod(
+                    clazz,
+                    new RecordEqCfCodeProvider(appView, clazz.type, pair.getFirst()),
+                    method));
     eventConsumer.acceptRecordEqualsHelperMethod(equalsHelperMethod, context);
   }
 
@@ -358,8 +364,36 @@
     return recordClass.instanceFields().size() < MAX_FIELDS_FOR_OUTLINE;
   }
 
+  // Answers an ordered map of the fields with the type for the hashCode proto.
+  private Pair<List<DexField>, List<DexType>> sortedInstanceFields(
+      List<DexEncodedField> instanceFields) {
+    Map<DexType, List<DexField>> temp = new IdentityHashMap<>();
+    for (DexEncodedField instanceField : instanceFields) {
+      DexType protoType =
+          instanceField.getType().isBooleanType()
+              ? instanceField.getType()
+              : ValueType.fromDexType(instanceField.getType()).toDexType(factory);
+      temp.computeIfAbsent(protoType, ignored -> new ArrayList<>())
+          .add(instanceField.getReference());
+    }
+    Pair<List<DexField>, List<DexType>> pair = new Pair<>(new ArrayList<>(), new ArrayList<>());
+    for (DexType orderedSharedType : orderedSharedTypes) {
+      List<DexField> dexFields = temp.get(orderedSharedType);
+      if (dexFields != null) {
+        for (DexField dexField : dexFields) {
+          pair.getFirst().add(dexField);
+          pair.getSecond().add(orderedSharedType);
+        }
+      }
+    }
+    assert pair.getFirst().size() == instanceFields.size();
+    assert pair.getSecond().size() == instanceFields.size();
+    return pair;
+  }
+
   private List<CfInstruction> desugarInvokeRecordHashCode(
       RecordInvokeDynamic recordInvokeDynamic,
+      FreshLocalProvider freshLocalProvider,
       LocalStackAllocator localStackAllocator,
       RecordInstructionDesugaringEventConsumer eventConsumer,
       ProgramMethod context,
@@ -370,11 +404,14 @@
     if (shouldOutlineMethods(recordClass)) {
       Pair<List<DexField>, List<DexType>> sortedFields =
           sortedInstanceFields(recordClass.instanceFields());
+      int freshLocal = freshLocalProvider.getFreshLocal(ValueType.OBJECT.requiredRegisters());
+      instructions.add(new CfStackInstruction(Opcode.Dup));
+      instructions.add(new CfStore(ValueType.OBJECT, freshLocal));
       DexField field = sortedFields.getFirst().get(0);
       int extraStack = field.getType().getRequiredRegisters();
       instructions.add(new CfInstanceFieldRead(field));
       for (int i = 1; i < sortedFields.getFirst().size(); i++) {
-        instructions.add(new CfLoad(ValueType.OBJECT, 0));
+        instructions.add(new CfLoad(ValueType.OBJECT, freshLocal));
         field = sortedFields.getFirst().get(i);
         instructions.add(new CfInstanceFieldRead(field));
         extraStack += field.getType().getRequiredRegisters();
@@ -411,37 +448,13 @@
     return instructions;
   }
 
-  // Answers an ordered map of the fields with the type for the hashCode proto.
-  private Pair<List<DexField>, List<DexType>> sortedInstanceFields(
-      List<DexEncodedField> instanceFields) {
-    Map<DexType, List<DexField>> temp = new IdentityHashMap<>();
-    for (DexEncodedField instanceField : instanceFields) {
-      DexType protoType =
-          instanceField.getType().isBooleanType()
-              ? instanceField.getType()
-              : ValueType.fromDexType(instanceField.getType()).toDexType(factory);
-      temp.computeIfAbsent(protoType, ignored -> new ArrayList<>())
-          .add(instanceField.getReference());
-    }
-    Pair<List<DexField>, List<DexType>> pair = new Pair<>(new ArrayList<>(), new ArrayList<>());
-    for (DexType orderedSharedType : orderedSharedTypes) {
-      List<DexField> dexFields = temp.get(orderedSharedType);
-      if (dexFields != null) {
-        for (DexField dexField : dexFields) {
-          pair.getFirst().add(dexField);
-          pair.getSecond().add(orderedSharedType);
-        }
-      }
-    }
-    assert pair.getFirst().size() == instanceFields.size();
-    assert pair.getSecond().size() == instanceFields.size();
-    return pair;
-  }
-
   private List<CfInstruction> desugarInvokeRecordEquals(RecordInvokeDynamic recordInvokeDynamic) {
-    DexMethod equalsRecord = equalsRecordMethod(recordInvokeDynamic.getRecordType());
-    assert recordInvokeDynamic.getRecordClass().lookupProgramMethod(equalsRecord) != null;
-    return Collections.singletonList(new CfInvoke(Opcodes.INVOKESPECIAL, equalsRecord, false));
+    ArrayList<CfInstruction> instructions = new ArrayList<>();
+    DexProgramClass recordClass = recordInvokeDynamic.getRecordClass();
+    DexMethod equalsMethod = equalsRecordMethod(recordInvokeDynamic.getRecordType());
+    assert recordClass.lookupProgramMethod(equalsMethod) != null;
+    instructions.add(new CfInvoke(Opcodes.INVOKESPECIAL, equalsMethod, false));
+    return instructions;
   }
 
   private List<CfInstruction> desugarInvokeRecordToString(
diff --git a/src/main/java/com/android/tools/r8/ir/synthetic/RecordCfCodeProvider.java b/src/main/java/com/android/tools/r8/ir/synthetic/RecordCfCodeProvider.java
index 1dbbbbc..8ba50fe 100644
--- a/src/main/java/com/android/tools/r8/ir/synthetic/RecordCfCodeProvider.java
+++ b/src/main/java/com/android/tools/r8/ir/synthetic/RecordCfCodeProvider.java
@@ -8,11 +8,13 @@
 import com.android.tools.r8.cf.code.CfArithmeticBinop.Opcode;
 import com.android.tools.r8.cf.code.CfArrayStore;
 import com.android.tools.r8.cf.code.CfCheckCast;
+import com.android.tools.r8.cf.code.CfCmp;
 import com.android.tools.r8.cf.code.CfConstNumber;
 import com.android.tools.r8.cf.code.CfFrame;
 import com.android.tools.r8.cf.code.CfIf;
 import com.android.tools.r8.cf.code.CfIfCmp;
 import com.android.tools.r8.cf.code.CfInstanceFieldRead;
+import com.android.tools.r8.cf.code.CfInstanceOf;
 import com.android.tools.r8.cf.code.CfInstruction;
 import com.android.tools.r8.cf.code.CfInvoke;
 import com.android.tools.r8.cf.code.CfLabel;
@@ -28,6 +30,7 @@
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.ir.code.Cmp.Bias;
 import com.android.tools.r8.ir.code.IfType;
 import com.android.tools.r8.ir.code.MemberType;
 import com.android.tools.r8.ir.code.NumericType;
@@ -58,6 +61,14 @@
       this.outline = outline;
     }
 
+    public static void registerSynthesizedCodeReferences(DexItemFactory factory) {
+      factory.createSynthesizedType("Ljava/lang/Objects;");
+      factory.createSynthesizedType("Ljava/lang/Double;");
+      factory.createSynthesizedType("Ljava/lang/Float;");
+      factory.createSynthesizedType("Ljava/lang/Boolean;");
+      factory.createSynthesizedType("Ljava/lang/Long;");
+    }
+
     private void addInvokeStatic(List<CfInstruction> instructions, DexMethod method) {
       instructions.add(new CfInvoke(Opcodes.INVOKESTATIC, method, false));
     }
@@ -103,6 +114,80 @@
     }
   }
 
+  public static class RecordEqCfCodeProvider extends RecordCfCodeProvider {
+
+    private final List<DexField> fieldsToCompare;
+
+    public RecordEqCfCodeProvider(
+        AppView<?> appView, DexType holder, List<DexField> fieldsToCompare) {
+      super(appView, holder);
+      this.fieldsToCompare = fieldsToCompare;
+    }
+
+    public static void registerSynthesizedCodeReferences(DexItemFactory factory) {
+      factory.createSynthesizedType("Ljava/lang/Objects;");
+    }
+
+    private void addInvokeStatic(List<CfInstruction> instructions, DexMethod method) {
+      instructions.add(new CfInvoke(Opcodes.INVOKESTATIC, method, false));
+    }
+
+    private void pushComparison(
+        List<CfInstruction> instructions, DexField field, CfLabel falseLabel) {
+      ValueType valueType = ValueType.fromDexType(field.getType());
+      instructions.add(new CfLoad(ValueType.OBJECT, 0));
+      instructions.add(new CfInstanceFieldRead(field));
+      instructions.add(new CfLoad(ValueType.OBJECT, 2));
+      instructions.add(new CfInstanceFieldRead(field));
+      if (valueType == ValueType.DOUBLE) {
+        instructions.add(new CfCmp(Bias.LT, NumericType.DOUBLE));
+        instructions.add(new CfIf(IfType.NE, ValueType.INT, falseLabel));
+      } else if (valueType == ValueType.FLOAT) {
+        instructions.add(new CfCmp(Bias.LT, NumericType.FLOAT));
+        instructions.add(new CfIf(IfType.NE, ValueType.INT, falseLabel));
+      } else if (valueType == ValueType.LONG) {
+        instructions.add(new CfCmp(Bias.NONE, NumericType.LONG));
+        instructions.add(new CfIf(IfType.NE, ValueType.INT, falseLabel));
+      } else if (valueType.isObject()) {
+        addInvokeStatic(instructions, appView.dexItemFactory().objectsMethods.equals);
+        instructions.add(new CfIf(IfType.EQ, ValueType.INT, falseLabel));
+      } else {
+        assert valueType == ValueType.INT;
+        instructions.add(new CfIfCmp(IfType.NE, ValueType.INT, falseLabel));
+      }
+    }
+
+    @Override
+    public CfCode generateCfCode() {
+      CfFrame frame = buildFrame();
+      List<CfInstruction> instructions = new ArrayList<>();
+      CfLabel falseLabel = new CfLabel();
+      instructions.add(new CfLoad(ValueType.OBJECT, 1));
+      instructions.add(new CfInstanceOf(getHolder()));
+      instructions.add(new CfIf(IfType.EQ, ValueType.INT, falseLabel));
+      instructions.add(new CfLoad(ValueType.OBJECT, 1));
+      instructions.add(new CfCheckCast(getHolder()));
+      instructions.add(new CfStore(ValueType.OBJECT, 2));
+      for (int i = 0; i < fieldsToCompare.size(); i++) {
+        pushComparison(instructions, fieldsToCompare.get(i), falseLabel);
+      }
+      instructions.add(new CfConstNumber(1, ValueType.INT));
+      instructions.add(new CfReturn(ValueType.INT));
+      instructions.add(falseLabel);
+      instructions.add(frame);
+      instructions.add(new CfConstNumber(0, ValueType.INT));
+      instructions.add(new CfReturn(ValueType.INT));
+      return standardCfCodeFromInstructions(instructions);
+    }
+
+    private CfFrame buildFrame() {
+      return CfFrame.builder()
+          .appendLocal(FrameType.initialized(getHolder()))
+          .appendLocal(FrameType.initialized(appView.dexItemFactory().objectType))
+          .build();
+    }
+  }
+
   /**
    * Generates a method which answers all field values as an array of objects. If the field value is
    * a primitive type, it uses the primitive wrapper to wrap it.