Introduce a new reverse constructor inlining pass

Change-Id: I2002fc60e641bf48aaad7556e749946957bef3cd
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/UndoConstructorInlining.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/UndoConstructorInlining.java
new file mode 100644
index 0000000..915259c
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/UndoConstructorInlining.java
@@ -0,0 +1,473 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.horizontalclassmerging;
+
+import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
+import static com.android.tools.r8.ir.analysis.type.Nullability.definitelyNotNull;
+import static com.android.tools.r8.utils.MapUtils.ignoreKey;
+
+import com.android.tools.r8.cf.CfVersion;
+import com.android.tools.r8.classmerging.ClassMergerMode;
+import com.android.tools.r8.classmerging.ClassMergerSharedData;
+import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
+import com.android.tools.r8.graph.MethodAccessFlags;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.analysis.type.TypeElement;
+import com.android.tools.r8.ir.code.IRMetadata;
+import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.lightir.ByteArrayWriter;
+import com.android.tools.r8.lightir.ByteUtils;
+import com.android.tools.r8.lightir.LirBuilder;
+import com.android.tools.r8.lightir.LirCode;
+import com.android.tools.r8.lightir.LirConstant;
+import com.android.tools.r8.lightir.LirEncodingStrategy;
+import com.android.tools.r8.lightir.LirInstructionView;
+import com.android.tools.r8.lightir.LirOpcodes;
+import com.android.tools.r8.lightir.LirStrategy;
+import com.android.tools.r8.lightir.LirWriter;
+import com.android.tools.r8.optimize.argumentpropagation.utils.ProgramClassesBidirectedGraph;
+import com.android.tools.r8.profile.rewriting.ProfileCollectionAdditions;
+import com.android.tools.r8.utils.ArrayUtils;
+import com.android.tools.r8.utils.ObjectUtils;
+import com.android.tools.r8.utils.ThreadUtils;
+import com.android.tools.r8.utils.Timing;
+import it.unimi.dsi.fastutil.ints.Int2ReferenceMap;
+import it.unimi.dsi.fastutil.ints.Int2ReferenceOpenHashMap;
+import it.unimi.dsi.fastutil.objects.Reference2IntMap;
+import it.unimi.dsi.fastutil.objects.Reference2IntOpenHashMap;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.IdentityHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+public class UndoConstructorInlining {
+
+  private final AppView<? extends AppInfoWithClassHierarchy> appView;
+  private final ClassMergerSharedData classMergerSharedData;
+  private final ImmediateProgramSubtypingInfo immediateSubtypingInfo;
+
+  public UndoConstructorInlining(
+      AppView<?> appView,
+      ClassMergerSharedData classMergerSharedData,
+      ImmediateProgramSubtypingInfo immediateSubtypingInfo) {
+    this.appView = appView.enableWholeProgramOptimizations() ? appView.withClassHierarchy() : null;
+    this.classMergerSharedData = classMergerSharedData;
+    this.immediateSubtypingInfo = immediateSubtypingInfo;
+  }
+
+  public void runIfNecessary(
+      Collection<HorizontalMergeGroup> groups,
+      ClassMergerMode mode,
+      ExecutorService executorService,
+      Timing timing)
+      throws ExecutionException {
+    if (shouldRun(mode)) {
+      timing.begin("Undo constructor inlining");
+      run(groups, executorService);
+      timing.end();
+    }
+  }
+
+  private boolean shouldRun(ClassMergerMode mode) {
+    // Only run when constructor inlining is enabled.
+    return appView != null
+        && appView.options().canInitNewInstanceUsingSuperclassConstructor()
+        && mode.isFinal();
+  }
+
+  private void run(Collection<HorizontalMergeGroup> groups, ExecutorService executorService)
+      throws ExecutionException {
+    // Find all classes in horizontal non-interface merge groups that have a class id. All
+    // instantiations of these classes *must* use a constructor on the class itself, since the
+    // constructor will be responsible for assigning the class id field. This property may not hold
+    // as a result of constructor inlining, so we need to restore it.
+    Map<DexType, DexProgramClass> ensureConstructorsOnClasses =
+        groups.stream()
+            .filter(group -> group.isClassGroup() && group.hasClassIdField())
+            .flatMap(HorizontalMergeGroup::stream)
+            .collect(Collectors.toMap(DexClass::getType, Function.identity()));
+    if (ensureConstructorsOnClasses.isEmpty()) {
+      return;
+    }
+
+    // Create a mapping from program classes to their strongly connected program component. When we
+    // need to synthesize a constructor on a class C we lock on the strongly connected component of
+    // C to ensure thread safety.
+    Map<DexProgramClass, StronglyConnectedComponent> stronglyConnectedComponents =
+        computeStronglyConnectedComponents();
+    new LirRewriter(appView, ensureConstructorsOnClasses, stronglyConnectedComponents)
+        .run(executorService);
+    appView.dexItemFactory().clearTypeElementsCache();
+  }
+
+  private Map<DexProgramClass, StronglyConnectedComponent> computeStronglyConnectedComponents() {
+    List<Set<DexProgramClass>> stronglyConnectedComponents =
+        new ProgramClassesBidirectedGraph(appView, immediateSubtypingInfo)
+            .computeStronglyConnectedComponents();
+    Map<DexProgramClass, StronglyConnectedComponent> stronglyConnectedComponentMap =
+        new IdentityHashMap<>();
+    for (Set<DexProgramClass> classes : stronglyConnectedComponents) {
+      StronglyConnectedComponent stronglyConnectedComponent = new StronglyConnectedComponent();
+      for (DexProgramClass clazz : classes) {
+        stronglyConnectedComponentMap.put(clazz, stronglyConnectedComponent);
+      }
+    }
+    return stronglyConnectedComponentMap;
+  }
+
+  private static class LirRewriter {
+
+    private final AppView<? extends AppInfoWithClassHierarchy> appView;
+    private final Map<DexType, DexProgramClass> ensureConstructorsOnClasses;
+    private final ProfileCollectionAdditions profileCollectionAdditions;
+    private final Map<DexProgramClass, StronglyConnectedComponent> stronglyConnectedComponents;
+
+    LirRewriter(
+        AppView<? extends AppInfoWithClassHierarchy> appView,
+        Map<DexType, DexProgramClass> ensureConstructorsOnClasses,
+        Map<DexProgramClass, StronglyConnectedComponent> stronglyConnectedComponents) {
+      this.appView = appView;
+      this.ensureConstructorsOnClasses = ensureConstructorsOnClasses;
+      this.profileCollectionAdditions = ProfileCollectionAdditions.create(appView);
+      this.stronglyConnectedComponents = stronglyConnectedComponents;
+    }
+
+    public void run(ExecutorService executorService) throws ExecutionException {
+      ThreadUtils.processItems(
+          appView.appInfo().classes(),
+          this::processClass,
+          appView.options().getThreadingModule(),
+          executorService);
+      profileCollectionAdditions.commit(appView);
+    }
+
+    private void processClass(DexProgramClass clazz) {
+      clazz.forEachProgramMethodMatching(this::filterMethod, this::processMethod);
+    }
+
+    private boolean filterMethod(DexEncodedMethod method) {
+      return method.hasCode()
+          && method.getCode().isLirCode()
+          && mayInstantiateClassOfInterest(
+              method.getCode().asLirCode(), ensureConstructorsOnClasses);
+    }
+
+    private void processMethod(ProgramMethod method) {
+      LirCode<Integer> code = method.getDefinition().getCode().asLirCode();
+      LirCode<Integer> rewritten = rewriteLir(method, code);
+      if (ObjectUtils.notIdentical(code, rewritten)) {
+        method.setCode(rewritten, appView);
+      }
+    }
+
+    private boolean mayInstantiateClassOfInterest(
+        LirCode<Integer> code, Map<DexType, DexProgramClass> ensureConstructorsOnClasses) {
+      return ArrayUtils.any(
+          code.getConstantPool(),
+          constant ->
+              constant instanceof DexType && ensureConstructorsOnClasses.containsKey(constant));
+    }
+
+    private LirCode<Integer> rewriteLir(ProgramMethod method, LirCode<Integer> code) {
+      // Create a mapping from new-instance value index -> new-instance type (limited to the types
+      // of interest).
+      Int2ReferenceMap<DexProgramClass> allocationsOfInterest =
+          getAllocationsOfInterest(method, code);
+      if (allocationsOfInterest.isEmpty()) {
+        return code;
+      }
+      ByteArrayWriter byteWriter = new ByteArrayWriter();
+      LirWriter lirWriter = new LirWriter(byteWriter);
+      List<LirConstant> methodsToAppend = new ArrayList<>();
+      Reference2IntMap<DexMethod> methodIndices = new Reference2IntOpenHashMap<>();
+      for (LirInstructionView view : code) {
+        InvokeDirectInfo info =
+            getAllocationOfInterest(method, code, lirWriter, view, allocationsOfInterest);
+        if (info == null) {
+          continue;
+        }
+        ProgramMethod newInvokedMethod =
+            getStronglyConnectedComponent(info.getProgramClass())
+                .getOrCreateConstructor(
+                    info.getProgramClass(),
+                    info.getInvokedMethod(),
+                    ensureConstructorsOnClasses,
+                    newConstructor ->
+                        profileCollectionAdditions.addMethodIfContextIsInProfile(
+                            newConstructor, method));
+        int constantIndex =
+            methodIndices.computeIfAbsent(
+                newInvokedMethod.getReference(),
+                ref -> {
+                  methodsToAppend.add(ref);
+                  return code.getConstantPool().length + methodsToAppend.size() - 1;
+                });
+        int constantIndexSize = ByteUtils.intEncodingSize(constantIndex);
+        int firstValueSize = ByteUtils.intEncodingSize(info.getFirstValue());
+        int remainingSize = view.getRemainingOperandSizeInBytes();
+        lirWriter.writeInstruction(
+            LirOpcodes.INVOKEDIRECT, constantIndexSize + firstValueSize + remainingSize);
+        ByteUtils.writeEncodedInt(constantIndex, lirWriter::writeOperand);
+        ByteUtils.writeEncodedInt(info.getFirstValue(), lirWriter::writeOperand);
+        while (remainingSize-- > 0) {
+          lirWriter.writeOperand(view.getNextU1());
+        }
+      }
+      return methodsToAppend.isEmpty()
+          ? code
+          : code.copyWithNewConstantsAndInstructions(
+              code.getMetadataForIR(),
+              ArrayUtils.appendElements(code.getConstantPool(), methodsToAppend),
+              byteWriter.toByteArray());
+    }
+
+    private Int2ReferenceMap<DexProgramClass> getAllocationsOfInterest(
+        ProgramMethod method, LirCode<Integer> code) {
+      Int2ReferenceMap<DexProgramClass> allocationsOfInterest = new Int2ReferenceOpenHashMap<>();
+      if (method.getDefinition().isInstanceInitializer()) {
+        DexProgramClass classOfInterest =
+            ensureConstructorsOnClasses.get(method.getHolder().getSuperType());
+        if (classOfInterest != null) {
+          allocationsOfInterest.put(0, classOfInterest);
+        }
+      }
+      for (LirInstructionView view : code) {
+        if (view.getOpcode() == LirOpcodes.NEW) {
+          DexType type = (DexType) code.getConstantItem(view.getNextConstantOperand());
+          DexProgramClass classOfInterest = ensureConstructorsOnClasses.get(type);
+          if (classOfInterest != null) {
+            allocationsOfInterest.put(view.getValueIndex(code), classOfInterest);
+          }
+        }
+      }
+      return allocationsOfInterest;
+    }
+
+    /**
+     * Returns non-null for constructor calls that need to be rewritten to a constructor call on the
+     * new-instance type. When this returns null, the current instruction is written to the {@param
+     * lirWriter}.
+     */
+    // TODO(b/225838009): Look into making it easier to "peek" data in the LirInstructionView to
+    //  avoid needing to keep track of how many operands have been consumed.
+    private InvokeDirectInfo getAllocationOfInterest(
+        ProgramMethod method,
+        LirCode<Integer> code,
+        LirWriter lirWriter,
+        LirInstructionView view,
+        Int2ReferenceMap<DexProgramClass> allocationsOfInterest) {
+      int opcode = view.getOpcode();
+      if (LirOpcodes.isOneByteInstruction(opcode)) {
+        lirWriter.writeOneByteInstruction(opcode);
+        return null;
+      }
+      int operandSizeInBytes = view.getRemainingOperandSizeInBytes();
+      int numReadOperands = 0;
+      int constantIndex = -1;
+      int firstValue = -1;
+      if (opcode == LirOpcodes.INVOKEDIRECT) {
+        constantIndex = view.getNextConstantOperand();
+        numReadOperands++;
+        DexMethod invokedMethod = (DexMethod) code.getConstantItem(constantIndex);
+        if (invokedMethod.isInstanceInitializer(appView.dexItemFactory())) {
+          firstValue = view.getNextValueOperand();
+          numReadOperands++;
+          int receiver = code.decodeValueIndex(firstValue, view.getValueIndex(code));
+          DexProgramClass classOfInterest = allocationsOfInterest.get(receiver);
+          if (classOfInterest != null
+              && classOfInterest.getType().isNotIdenticalTo(invokedMethod.getHolderType())
+              && !isForwardingConstructorCall(method, invokedMethod, receiver)) {
+            return new InvokeDirectInfo(invokedMethod, firstValue, classOfInterest);
+          }
+        }
+      }
+      lirWriter.writeInstruction(opcode, operandSizeInBytes);
+      assert numReadOperands <= 2;
+      if (numReadOperands > 0) {
+        ByteUtils.writeEncodedInt(constantIndex, lirWriter::writeOperand);
+        if (numReadOperands == 2) {
+          ByteUtils.writeEncodedInt(firstValue, lirWriter::writeOperand);
+        }
+      }
+      int size = view.getRemainingOperandSizeInBytes();
+      while (size-- > 0) {
+        lirWriter.writeOperand(view.getNextU1());
+      }
+      return null;
+    }
+
+    private boolean isForwardingConstructorCall(
+        ProgramMethod method, DexMethod invokedMethod, int receiver) {
+      assert invokedMethod.isInstanceInitializer(appView.dexItemFactory());
+      return method.getDefinition().isInstanceInitializer()
+          && invokedMethod.getHolderType().isIdenticalTo(method.getHolderType())
+          && receiver == 0;
+    }
+
+    private StronglyConnectedComponent getStronglyConnectedComponent(DexProgramClass clazz) {
+      return stronglyConnectedComponents.get(clazz);
+    }
+  }
+
+  private static class InvokeDirectInfo {
+
+    private final DexMethod invokedMethod;
+    private final int firstValue;
+    private final DexProgramClass programClass;
+
+    InvokeDirectInfo(DexMethod invokedMethod, int firstValue, DexProgramClass programClass) {
+      this.invokedMethod = invokedMethod;
+      this.firstValue = firstValue;
+      this.programClass = programClass;
+    }
+
+    DexMethod getInvokedMethod() {
+      return invokedMethod;
+    }
+
+    public int getFirstValue() {
+      return firstValue;
+    }
+
+    public DexProgramClass getProgramClass() {
+      return programClass;
+    }
+  }
+
+  private class StronglyConnectedComponent {
+
+    private final Map<DexProgramClass, Map<DexMethod, ProgramMethod>> constructorCache =
+        new IdentityHashMap<>();
+
+    // Get or create a constructor on the given class that calls target, which is a constructor on
+    // a parent class. Note that the returned constructor may not actually call the given target
+    // constructor directly, as constructors may also need to be synthesized between the current
+    // class and the target holder.
+    //
+    // Synchronized to ensure thread safety.
+    public synchronized ProgramMethod getOrCreateConstructor(
+        DexProgramClass clazz,
+        DexMethod target,
+        Map<DexType, DexProgramClass> ensureConstructorsOnClasses,
+        Consumer<ProgramMethod> creationConsumer) {
+      return constructorCache
+          .computeIfAbsent(clazz, ignoreKey(IdentityHashMap::new))
+          .computeIfAbsent(
+              target,
+              k -> createConstructor(clazz, target, ensureConstructorsOnClasses, creationConsumer));
+    }
+
+    private ProgramMethod createConstructor(
+        DexProgramClass clazz,
+        DexMethod target,
+        Map<DexType, DexProgramClass> ensureConstructorsOnClasses,
+        Consumer<ProgramMethod> creationConsumer) {
+      // Create a fresh constructor on the given class that calls target. If there is a class in the
+      // hierarchy inbetween `clazz` and `target.holder`, which is also subject to class merging,
+      // then we must create a constructor that calls a constructor on that intermediate class,
+      // which then calls target.
+      DexType currentType = clazz.getSuperType();
+      while (currentType.isNotIdenticalTo(target.getHolderType())) {
+        DexProgramClass currentClass = asProgramClassOrNull(appView.definitionFor(currentType));
+        if (currentClass == null) {
+          break;
+        }
+        if (ensureConstructorsOnClasses.containsKey(currentType)) {
+          target =
+              getOrCreateConstructor(
+                      currentClass, target, ensureConstructorsOnClasses, creationConsumer)
+                  .getReference();
+          break;
+        }
+        currentType = currentClass.getSuperType();
+      }
+
+      // Create a constructor that calls target.
+      DexItemFactory dexItemFactory = appView.dexItemFactory();
+      DexMethod candidateMethodReference = target.withHolder(clazz, dexItemFactory);
+      DexMethod methodReference =
+          dexItemFactory.createInstanceInitializerWithFreshProto(
+              candidateMethodReference,
+              classMergerSharedData.getExtraUnusedArgumentTypes(),
+              test -> clazz.lookupDirectMethod(test) == null);
+      DexEncodedMethod method =
+          DexEncodedMethod.syntheticBuilder()
+              .setMethod(methodReference)
+              .setAccessFlags(
+                  MethodAccessFlags.builder().setConstructor().setPublic().setSynthetic().build())
+              .setCode(createConstructorCode(methodReference, target))
+              .setClassFileVersion(CfVersion.V1_6)
+              .build();
+      clazz.addDirectMethod(method);
+      ProgramMethod programMethod = method.asProgramMethod(clazz);
+      creationConsumer.accept(programMethod);
+      return programMethod;
+    }
+
+    private LirCode<Integer> createConstructorCode(DexMethod methodReference, DexMethod target) {
+      LirEncodingStrategy<Value, Integer> strategy =
+          LirStrategy.getDefaultStrategy().getEncodingStrategy();
+      LirBuilder<Value, Integer> lirBuilder =
+          LirCode.builder(methodReference, true, strategy, appView.options())
+              .setMetadata(IRMetadata.unknown());
+
+      int instructionIndex = 0;
+      List<Value> argumentValues = new ArrayList<>();
+
+      // Add receiver argument.
+      DexType receiverType = methodReference.getHolderType();
+      TypeElement receiverTypeElement = receiverType.toTypeElement(appView, definitelyNotNull());
+      Value receiverValue = Value.createNoDebugLocal(instructionIndex, receiverTypeElement);
+      argumentValues.add(receiverValue);
+      strategy.defineValue(receiverValue, receiverValue.getNumber());
+      lirBuilder.addArgument(receiverValue.getNumber(), false);
+      instructionIndex++;
+
+      // Add non-receiver arguments.
+      for (;
+          instructionIndex < target.getNumberOfArgumentsForNonStaticMethod();
+          instructionIndex++) {
+        DexType argumentType = target.getArgumentTypeForNonStaticMethod(instructionIndex);
+        TypeElement argumentTypeElement = argumentType.toTypeElement(appView);
+        Value argumentValue = Value.createNoDebugLocal(instructionIndex, argumentTypeElement);
+        argumentValues.add(argumentValue);
+        strategy.defineValue(argumentValue, argumentValue.getNumber());
+        lirBuilder.addArgument(argumentValue.getNumber(), argumentType.isBooleanType());
+      }
+
+      // Add remaining unused non-receiver arguments.
+      for (;
+          instructionIndex < methodReference.getNumberOfArgumentsForNonStaticMethod();
+          instructionIndex++) {
+        DexType argumentType = methodReference.getArgumentTypeForNonStaticMethod(instructionIndex);
+        lirBuilder.addArgument(instructionIndex, argumentType.isBooleanType());
+      }
+
+      // Invoke parent constructor.
+      lirBuilder.addInvokeDirect(target, argumentValues, false);
+      instructionIndex++;
+
+      // Return.
+      lirBuilder.addReturnVoid();
+      instructionIndex++;
+
+      return lirBuilder.build();
+    }
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/lightir/LirInstructionView.java b/src/main/java/com/android/tools/r8/lightir/LirInstructionView.java
index 7fb4e59..166c1d7 100644
--- a/src/main/java/com/android/tools/r8/lightir/LirInstructionView.java
+++ b/src/main/java/com/android/tools/r8/lightir/LirInstructionView.java
@@ -15,6 +15,9 @@
   /** Convenience method to forward control to a callback. */
   void accept(LirInstructionCallback eventCallback);
 
+  /** Get the current value index. */
+  int getValueIndex(LirCode<Integer> code);
+
   /** Get the instruction index. */
   int getInstructionIndex();
 
diff --git a/src/main/java/com/android/tools/r8/lightir/LirIterator.java b/src/main/java/com/android/tools/r8/lightir/LirIterator.java
index aac44bb..18c2bbe 100644
--- a/src/main/java/com/android/tools/r8/lightir/LirIterator.java
+++ b/src/main/java/com/android/tools/r8/lightir/LirIterator.java
@@ -66,6 +66,11 @@
   }
 
   @Override
+  public int getValueIndex(LirCode<Integer> code) {
+    return code.getArgumentCount() + currentInstructionIndex;
+  }
+
+  @Override
   public int getInstructionIndex() {
     return currentInstructionIndex;
   }