Ensure subclass instantiations does not skip required constructors

Change-Id: I1f72dde344fc3cd75926fe2399441419dc79ebb6
diff --git a/src/main/java/com/android/tools/r8/graph/DefaultInstanceInitializerCode.java b/src/main/java/com/android/tools/r8/graph/DefaultInstanceInitializerCode.java
index d2f05df..2ce1db6 100644
--- a/src/main/java/com/android/tools/r8/graph/DefaultInstanceInitializerCode.java
+++ b/src/main/java/com/android/tools/r8/graph/DefaultInstanceInitializerCode.java
@@ -91,7 +91,7 @@
     DexEncodedMethod definition = method.getDefinition();
     assert definition.getCode().isDefaultInstanceInitializerCode();
     if (appView.testing().isSupportedLirPhase()) {
-      method.setCode(get().toLirCode(appView, method), appView);
+      method.setCode(get().toLirCode(appView, method, superType), appView);
     } else {
       assert appView.testing().isPreLirPhase();
       method.setCode(get().toCfCode(method, appView.dexItemFactory(), superType), appView);
@@ -378,12 +378,11 @@
     return new CfCode(method.getHolderType(), getMaxStack(), getMaxLocals(method), instructions);
   }
 
-  public LirCode<?> toLirCode(AppView<?> appView, ProgramMethod method) {
+  public LirCode<?> toLirCode(AppView<?> appView, ProgramMethod method, DexType supertype) {
     TypeElement receiverType =
         method.getHolder().getType().toTypeElement(appView, Nullability.definitelyNotNull());
     Value receiver = new Value(0, receiverType, null);
-    DexMethod invokedMethod =
-        appView.dexItemFactory().createInstanceInitializer(method.getHolder().getSuperType());
+    DexMethod invokedMethod = appView.dexItemFactory().createInstanceInitializer(supertype);
     LirEncodingStrategy<Value, Integer> strategy =
         LirStrategy.getDefaultStrategy().getEncodingStrategy();
     strategy.defineValue(receiver, 0);
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/UndoConstructorInlining.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/UndoConstructorInlining.java
index 915259c..f10c18b 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/UndoConstructorInlining.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/UndoConstructorInlining.java
@@ -40,6 +40,7 @@
 import com.android.tools.r8.utils.ObjectUtils;
 import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.Timing;
+import com.android.tools.r8.utils.WorkList;
 import it.unimi.dsi.fastutil.ints.Int2ReferenceMap;
 import it.unimi.dsi.fastutil.ints.Int2ReferenceOpenHashMap;
 import it.unimi.dsi.fastutil.objects.Reference2IntMap;
@@ -106,6 +107,9 @@
       return;
     }
 
+    // Extend the mapping to include subclasses.
+    ensureConstructorsOnSubclasses(ensureConstructorsOnClasses);
+
     // Create a mapping from program classes to their strongly connected program component. When we
     // need to synthesize a constructor on a class C we lock on the strongly connected component of
     // C to ensure thread safety.
@@ -116,6 +120,25 @@
     appView.dexItemFactory().clearTypeElementsCache();
   }
 
+  private void ensureConstructorsOnSubclasses(
+      Map<DexType, DexProgramClass> ensureConstructorsOnClasses) {
+    // Perform a top-down traversal from each merge class and record that instantiations of
+    // subclasses of the merge class must not skip any constructors on the merge class.
+    Map<DexType, DexProgramClass> ensureConstructorsOnSubclasses = new IdentityHashMap<>();
+    for (DexProgramClass mergeClass : ensureConstructorsOnClasses.values()) {
+      WorkList.newIdentityWorkList(immediateSubtypingInfo.getSubclasses(mergeClass))
+          .process(
+              (current, worklist) -> {
+                if (ensureConstructorsOnClasses.containsKey(current.getType())) {
+                  return;
+                }
+                ensureConstructorsOnSubclasses.put(current.getType(), mergeClass);
+                worklist.addIfNotSeen(immediateSubtypingInfo.getSubclasses(current));
+              });
+    }
+    ensureConstructorsOnClasses.putAll(ensureConstructorsOnSubclasses);
+  }
+
   private Map<DexProgramClass, StronglyConnectedComponent> computeStronglyConnectedComponents() {
     List<Set<DexProgramClass>> stronglyConnectedComponents =
         new ProgramClassesBidirectedGraph(appView, immediateSubtypingInfo)
@@ -158,14 +181,15 @@
     }
 
     private void processClass(DexProgramClass clazz) {
-      clazz.forEachProgramMethodMatching(this::filterMethod, this::processMethod);
+      clazz.forEachProgramMethodMatching(
+          method -> filterMethod(clazz, method), this::processMethod);
     }
 
-    private boolean filterMethod(DexEncodedMethod method) {
+    private boolean filterMethod(DexProgramClass clazz, DexEncodedMethod method) {
       return method.hasCode()
           && method.getCode().isLirCode()
           && mayInstantiateClassOfInterest(
-              method.getCode().asLirCode(), ensureConstructorsOnClasses);
+              clazz, method, method.getCode().asLirCode(), ensureConstructorsOnClasses);
     }
 
     private void processMethod(ProgramMethod method) {
@@ -177,7 +201,15 @@
     }
 
     private boolean mayInstantiateClassOfInterest(
-        LirCode<Integer> code, Map<DexType, DexProgramClass> ensureConstructorsOnClasses) {
+        DexProgramClass clazz,
+        DexEncodedMethod method,
+        LirCode<Integer> code,
+        Map<DexType, DexProgramClass> ensureConstructorsOnClasses) {
+      // Treat constructors as allocations of the super class.
+      if (method.isInstanceInitializer()
+          && ensureConstructorsOnClasses.containsKey(clazz.getSuperType())) {
+        return true;
+      }
       return ArrayUtils.any(
           code.getConstantPool(),
           constant ->
@@ -187,8 +219,7 @@
     private LirCode<Integer> rewriteLir(ProgramMethod method, LirCode<Integer> code) {
       // Create a mapping from new-instance value index -> new-instance type (limited to the types
       // of interest).
-      Int2ReferenceMap<DexProgramClass> allocationsOfInterest =
-          getAllocationsOfInterest(method, code);
+      Int2ReferenceMap<DexType> allocationsOfInterest = getAllocationsOfInterest(method, code);
       if (allocationsOfInterest.isEmpty()) {
         return code;
       }
@@ -237,22 +268,19 @@
               byteWriter.toByteArray());
     }
 
-    private Int2ReferenceMap<DexProgramClass> getAllocationsOfInterest(
+    private Int2ReferenceMap<DexType> getAllocationsOfInterest(
         ProgramMethod method, LirCode<Integer> code) {
-      Int2ReferenceMap<DexProgramClass> allocationsOfInterest = new Int2ReferenceOpenHashMap<>();
+      Int2ReferenceMap<DexType> allocationsOfInterest = new Int2ReferenceOpenHashMap<>();
       if (method.getDefinition().isInstanceInitializer()) {
-        DexProgramClass classOfInterest =
-            ensureConstructorsOnClasses.get(method.getHolder().getSuperType());
-        if (classOfInterest != null) {
-          allocationsOfInterest.put(0, classOfInterest);
+        if (ensureConstructorsOnClasses.containsKey(method.getHolder().getSuperType())) {
+          allocationsOfInterest.put(0, method.getHolder().getSuperType());
         }
       }
       for (LirInstructionView view : code) {
         if (view.getOpcode() == LirOpcodes.NEW) {
           DexType type = (DexType) code.getConstantItem(view.getNextConstantOperand());
-          DexProgramClass classOfInterest = ensureConstructorsOnClasses.get(type);
-          if (classOfInterest != null) {
-            allocationsOfInterest.put(view.getValueIndex(code), classOfInterest);
+          if (ensureConstructorsOnClasses.containsKey(type)) {
+            allocationsOfInterest.put(view.getValueIndex(code), type);
           }
         }
       }
@@ -271,7 +299,7 @@
         LirCode<Integer> code,
         LirWriter lirWriter,
         LirInstructionView view,
-        Int2ReferenceMap<DexProgramClass> allocationsOfInterest) {
+        Int2ReferenceMap<DexType> allocationsOfInterest) {
       int opcode = view.getOpcode();
       if (LirOpcodes.isOneByteInstruction(opcode)) {
         lirWriter.writeOneByteInstruction(opcode);
@@ -289,11 +317,13 @@
           firstValue = view.getNextValueOperand();
           numReadOperands++;
           int receiver = code.decodeValueIndex(firstValue, view.getValueIndex(code));
-          DexProgramClass classOfInterest = allocationsOfInterest.get(receiver);
-          if (classOfInterest != null
-              && classOfInterest.getType().isNotIdenticalTo(invokedMethod.getHolderType())
-              && !isForwardingConstructorCall(method, invokedMethod, receiver)) {
-            return new InvokeDirectInfo(invokedMethod, firstValue, classOfInterest);
+          DexType newType = allocationsOfInterest.get(receiver);
+          if (newType != null
+              && isConstructorInlined(newType, invokedMethod)
+              && !isForwardingConstructorCall(method, invokedMethod, receiver)
+              && isSkippingRequiredConstructor(newType, invokedMethod)) {
+            return new InvokeDirectInfo(
+                invokedMethod, firstValue, ensureConstructorsOnClasses.get(newType));
           }
         }
       }
@@ -312,6 +342,10 @@
       return null;
     }
 
+    private boolean isConstructorInlined(DexType newType, DexMethod invokedMethod) {
+      return newType.isNotIdenticalTo(invokedMethod.getHolderType());
+    }
+
     private boolean isForwardingConstructorCall(
         ProgramMethod method, DexMethod invokedMethod, int receiver) {
       assert invokedMethod.isInstanceInitializer(appView.dexItemFactory());
@@ -320,6 +354,14 @@
           && receiver == 0;
     }
 
+    private boolean isSkippingRequiredConstructor(DexType newType, DexMethod invokedMethod) {
+      DexProgramClass requiredConstructorClass = ensureConstructorsOnClasses.get(newType);
+      assert requiredConstructorClass != null;
+      return !appView
+          .appInfo()
+          .isSubtype(invokedMethod.getHolderType(), requiredConstructorClass.getType());
+    }
+
     private StronglyConnectedComponent getStronglyConnectedComponent(DexProgramClass clazz) {
       return stronglyConnectedComponents.get(clazz);
     }
@@ -331,10 +373,10 @@
     private final int firstValue;
     private final DexProgramClass programClass;
 
-    InvokeDirectInfo(DexMethod invokedMethod, int firstValue, DexProgramClass programClass) {
+    InvokeDirectInfo(DexMethod invokedMethod, int firstValue, DexProgramClass newType) {
       this.invokedMethod = invokedMethod;
       this.firstValue = firstValue;
-      this.programClass = programClass;
+      this.programClass = newType;
     }
 
     DexMethod getInvokedMethod() {
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/UndoConstructorInliningWithSubclassesTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/UndoConstructorInliningWithSubclassesTest.java
new file mode 100644
index 0000000..eefd91f
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/UndoConstructorInliningWithSubclassesTest.java
@@ -0,0 +1,111 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.classmerging.horizontal;
+
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NoHorizontalClassMerging;
+import com.android.tools.r8.NoMethodStaticizing;
+import com.android.tools.r8.NoVerticalClassMerging;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.utils.BooleanUtils;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class UndoConstructorInliningWithSubclassesTest extends TestBase {
+
+  @Parameter(0)
+  public boolean neverInlineSubInitializers;
+
+  @Parameter(1)
+  public TestParameters parameters;
+
+  @Parameters(name = "{1}, never inline: {0}")
+  public static List<Object[]> data() {
+    return buildParameters(
+        BooleanUtils.values(), getTestParameters().withAllRuntimesAndApiLevels().build());
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .addHorizontallyMergedClassesInspector(
+            inspector ->
+                inspector.assertIsCompleteMergeGroup(A.class, B.class).assertNoOtherClassesMerged())
+        .addOptionsModification(
+            options -> options.horizontalClassMergerOptions().disableInitialRoundOfClassMerging())
+        // When inlining of ASub.<init> and BSub.<init> is allowed, we test that we correctly deal
+        // with the following code:
+        //   new-instance {v0}, LASub;
+        //   invoke-direct {v0}, Ljava/lang/Object;
+        //
+        // We also test the behavior when ASub.<init> and BSub.<init> are not inlined. In this case,
+        // we end up with the following instruction in ASub.<init>:
+        //   invoke-direct {v0}, Ljava/lang/Object;
+        //
+        // In both cases we need to synthesize a constructor to make sure we do not skip a
+        // constructor on A or B.
+        .applyIf(
+            neverInlineSubInitializers,
+            b -> b.addKeepRules("-neverinline class **$?Sub { void <init>(); }"))
+        .enableInliningAnnotations()
+        .enableNeverClassInliningAnnotations()
+        .enableNoHorizontalClassMergingAnnotations()
+        .enableNoMethodStaticizingAnnotations()
+        .enableNoVerticalClassMergingAnnotations()
+        .setMinApi(parameters)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("A", "B");
+  }
+
+  static class Main {
+
+    public static void main(String[] args) {
+      new ASub().m();
+      new BSub().m();
+    }
+  }
+
+  @NoVerticalClassMerging
+  static class A {
+
+    @NeverInline
+    @NoMethodStaticizing
+    void m() {
+      System.out.println("A");
+    }
+  }
+
+  @NeverClassInline
+  @NoHorizontalClassMerging
+  static class ASub extends A {
+
+    ASub() {}
+  }
+
+  @NoVerticalClassMerging
+  static class B {
+
+    @NeverInline
+    @NoMethodStaticizing
+    void m() {
+      System.out.println("B");
+    }
+  }
+
+  @NeverClassInline
+  @NoHorizontalClassMerging
+  static class BSub extends B {
+
+    BSub() {}
+  }
+}