Ensure subclass instantiations does not skip required constructors
Change-Id: I1f72dde344fc3cd75926fe2399441419dc79ebb6
diff --git a/src/main/java/com/android/tools/r8/graph/DefaultInstanceInitializerCode.java b/src/main/java/com/android/tools/r8/graph/DefaultInstanceInitializerCode.java
index d2f05df..2ce1db6 100644
--- a/src/main/java/com/android/tools/r8/graph/DefaultInstanceInitializerCode.java
+++ b/src/main/java/com/android/tools/r8/graph/DefaultInstanceInitializerCode.java
@@ -91,7 +91,7 @@
DexEncodedMethod definition = method.getDefinition();
assert definition.getCode().isDefaultInstanceInitializerCode();
if (appView.testing().isSupportedLirPhase()) {
- method.setCode(get().toLirCode(appView, method), appView);
+ method.setCode(get().toLirCode(appView, method, superType), appView);
} else {
assert appView.testing().isPreLirPhase();
method.setCode(get().toCfCode(method, appView.dexItemFactory(), superType), appView);
@@ -378,12 +378,11 @@
return new CfCode(method.getHolderType(), getMaxStack(), getMaxLocals(method), instructions);
}
- public LirCode<?> toLirCode(AppView<?> appView, ProgramMethod method) {
+ public LirCode<?> toLirCode(AppView<?> appView, ProgramMethod method, DexType supertype) {
TypeElement receiverType =
method.getHolder().getType().toTypeElement(appView, Nullability.definitelyNotNull());
Value receiver = new Value(0, receiverType, null);
- DexMethod invokedMethod =
- appView.dexItemFactory().createInstanceInitializer(method.getHolder().getSuperType());
+ DexMethod invokedMethod = appView.dexItemFactory().createInstanceInitializer(supertype);
LirEncodingStrategy<Value, Integer> strategy =
LirStrategy.getDefaultStrategy().getEncodingStrategy();
strategy.defineValue(receiver, 0);
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/UndoConstructorInlining.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/UndoConstructorInlining.java
index 915259c..f10c18b 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/UndoConstructorInlining.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/UndoConstructorInlining.java
@@ -40,6 +40,7 @@
import com.android.tools.r8.utils.ObjectUtils;
import com.android.tools.r8.utils.ThreadUtils;
import com.android.tools.r8.utils.Timing;
+import com.android.tools.r8.utils.WorkList;
import it.unimi.dsi.fastutil.ints.Int2ReferenceMap;
import it.unimi.dsi.fastutil.ints.Int2ReferenceOpenHashMap;
import it.unimi.dsi.fastutil.objects.Reference2IntMap;
@@ -106,6 +107,9 @@
return;
}
+ // Extend the mapping to include subclasses.
+ ensureConstructorsOnSubclasses(ensureConstructorsOnClasses);
+
// Create a mapping from program classes to their strongly connected program component. When we
// need to synthesize a constructor on a class C we lock on the strongly connected component of
// C to ensure thread safety.
@@ -116,6 +120,25 @@
appView.dexItemFactory().clearTypeElementsCache();
}
+ private void ensureConstructorsOnSubclasses(
+ Map<DexType, DexProgramClass> ensureConstructorsOnClasses) {
+ // Perform a top-down traversal from each merge class and record that instantiations of
+ // subclasses of the merge class must not skip any constructors on the merge class.
+ Map<DexType, DexProgramClass> ensureConstructorsOnSubclasses = new IdentityHashMap<>();
+ for (DexProgramClass mergeClass : ensureConstructorsOnClasses.values()) {
+ WorkList.newIdentityWorkList(immediateSubtypingInfo.getSubclasses(mergeClass))
+ .process(
+ (current, worklist) -> {
+ if (ensureConstructorsOnClasses.containsKey(current.getType())) {
+ return;
+ }
+ ensureConstructorsOnSubclasses.put(current.getType(), mergeClass);
+ worklist.addIfNotSeen(immediateSubtypingInfo.getSubclasses(current));
+ });
+ }
+ ensureConstructorsOnClasses.putAll(ensureConstructorsOnSubclasses);
+ }
+
private Map<DexProgramClass, StronglyConnectedComponent> computeStronglyConnectedComponents() {
List<Set<DexProgramClass>> stronglyConnectedComponents =
new ProgramClassesBidirectedGraph(appView, immediateSubtypingInfo)
@@ -158,14 +181,15 @@
}
private void processClass(DexProgramClass clazz) {
- clazz.forEachProgramMethodMatching(this::filterMethod, this::processMethod);
+ clazz.forEachProgramMethodMatching(
+ method -> filterMethod(clazz, method), this::processMethod);
}
- private boolean filterMethod(DexEncodedMethod method) {
+ private boolean filterMethod(DexProgramClass clazz, DexEncodedMethod method) {
return method.hasCode()
&& method.getCode().isLirCode()
&& mayInstantiateClassOfInterest(
- method.getCode().asLirCode(), ensureConstructorsOnClasses);
+ clazz, method, method.getCode().asLirCode(), ensureConstructorsOnClasses);
}
private void processMethod(ProgramMethod method) {
@@ -177,7 +201,15 @@
}
private boolean mayInstantiateClassOfInterest(
- LirCode<Integer> code, Map<DexType, DexProgramClass> ensureConstructorsOnClasses) {
+ DexProgramClass clazz,
+ DexEncodedMethod method,
+ LirCode<Integer> code,
+ Map<DexType, DexProgramClass> ensureConstructorsOnClasses) {
+ // Treat constructors as allocations of the super class.
+ if (method.isInstanceInitializer()
+ && ensureConstructorsOnClasses.containsKey(clazz.getSuperType())) {
+ return true;
+ }
return ArrayUtils.any(
code.getConstantPool(),
constant ->
@@ -187,8 +219,7 @@
private LirCode<Integer> rewriteLir(ProgramMethod method, LirCode<Integer> code) {
// Create a mapping from new-instance value index -> new-instance type (limited to the types
// of interest).
- Int2ReferenceMap<DexProgramClass> allocationsOfInterest =
- getAllocationsOfInterest(method, code);
+ Int2ReferenceMap<DexType> allocationsOfInterest = getAllocationsOfInterest(method, code);
if (allocationsOfInterest.isEmpty()) {
return code;
}
@@ -237,22 +268,19 @@
byteWriter.toByteArray());
}
- private Int2ReferenceMap<DexProgramClass> getAllocationsOfInterest(
+ private Int2ReferenceMap<DexType> getAllocationsOfInterest(
ProgramMethod method, LirCode<Integer> code) {
- Int2ReferenceMap<DexProgramClass> allocationsOfInterest = new Int2ReferenceOpenHashMap<>();
+ Int2ReferenceMap<DexType> allocationsOfInterest = new Int2ReferenceOpenHashMap<>();
if (method.getDefinition().isInstanceInitializer()) {
- DexProgramClass classOfInterest =
- ensureConstructorsOnClasses.get(method.getHolder().getSuperType());
- if (classOfInterest != null) {
- allocationsOfInterest.put(0, classOfInterest);
+ if (ensureConstructorsOnClasses.containsKey(method.getHolder().getSuperType())) {
+ allocationsOfInterest.put(0, method.getHolder().getSuperType());
}
}
for (LirInstructionView view : code) {
if (view.getOpcode() == LirOpcodes.NEW) {
DexType type = (DexType) code.getConstantItem(view.getNextConstantOperand());
- DexProgramClass classOfInterest = ensureConstructorsOnClasses.get(type);
- if (classOfInterest != null) {
- allocationsOfInterest.put(view.getValueIndex(code), classOfInterest);
+ if (ensureConstructorsOnClasses.containsKey(type)) {
+ allocationsOfInterest.put(view.getValueIndex(code), type);
}
}
}
@@ -271,7 +299,7 @@
LirCode<Integer> code,
LirWriter lirWriter,
LirInstructionView view,
- Int2ReferenceMap<DexProgramClass> allocationsOfInterest) {
+ Int2ReferenceMap<DexType> allocationsOfInterest) {
int opcode = view.getOpcode();
if (LirOpcodes.isOneByteInstruction(opcode)) {
lirWriter.writeOneByteInstruction(opcode);
@@ -289,11 +317,13 @@
firstValue = view.getNextValueOperand();
numReadOperands++;
int receiver = code.decodeValueIndex(firstValue, view.getValueIndex(code));
- DexProgramClass classOfInterest = allocationsOfInterest.get(receiver);
- if (classOfInterest != null
- && classOfInterest.getType().isNotIdenticalTo(invokedMethod.getHolderType())
- && !isForwardingConstructorCall(method, invokedMethod, receiver)) {
- return new InvokeDirectInfo(invokedMethod, firstValue, classOfInterest);
+ DexType newType = allocationsOfInterest.get(receiver);
+ if (newType != null
+ && isConstructorInlined(newType, invokedMethod)
+ && !isForwardingConstructorCall(method, invokedMethod, receiver)
+ && isSkippingRequiredConstructor(newType, invokedMethod)) {
+ return new InvokeDirectInfo(
+ invokedMethod, firstValue, ensureConstructorsOnClasses.get(newType));
}
}
}
@@ -312,6 +342,10 @@
return null;
}
+ private boolean isConstructorInlined(DexType newType, DexMethod invokedMethod) {
+ return newType.isNotIdenticalTo(invokedMethod.getHolderType());
+ }
+
private boolean isForwardingConstructorCall(
ProgramMethod method, DexMethod invokedMethod, int receiver) {
assert invokedMethod.isInstanceInitializer(appView.dexItemFactory());
@@ -320,6 +354,14 @@
&& receiver == 0;
}
+ private boolean isSkippingRequiredConstructor(DexType newType, DexMethod invokedMethod) {
+ DexProgramClass requiredConstructorClass = ensureConstructorsOnClasses.get(newType);
+ assert requiredConstructorClass != null;
+ return !appView
+ .appInfo()
+ .isSubtype(invokedMethod.getHolderType(), requiredConstructorClass.getType());
+ }
+
private StronglyConnectedComponent getStronglyConnectedComponent(DexProgramClass clazz) {
return stronglyConnectedComponents.get(clazz);
}
@@ -331,10 +373,10 @@
private final int firstValue;
private final DexProgramClass programClass;
- InvokeDirectInfo(DexMethod invokedMethod, int firstValue, DexProgramClass programClass) {
+ InvokeDirectInfo(DexMethod invokedMethod, int firstValue, DexProgramClass newType) {
this.invokedMethod = invokedMethod;
this.firstValue = firstValue;
- this.programClass = programClass;
+ this.programClass = newType;
}
DexMethod getInvokedMethod() {
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/UndoConstructorInliningWithSubclassesTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/UndoConstructorInliningWithSubclassesTest.java
new file mode 100644
index 0000000..eefd91f
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/UndoConstructorInliningWithSubclassesTest.java
@@ -0,0 +1,111 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.classmerging.horizontal;
+
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NoHorizontalClassMerging;
+import com.android.tools.r8.NoMethodStaticizing;
+import com.android.tools.r8.NoVerticalClassMerging;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.utils.BooleanUtils;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class UndoConstructorInliningWithSubclassesTest extends TestBase {
+
+ @Parameter(0)
+ public boolean neverInlineSubInitializers;
+
+ @Parameter(1)
+ public TestParameters parameters;
+
+ @Parameters(name = "{1}, never inline: {0}")
+ public static List<Object[]> data() {
+ return buildParameters(
+ BooleanUtils.values(), getTestParameters().withAllRuntimesAndApiLevels().build());
+ }
+
+ @Test
+ public void test() throws Exception {
+ testForR8(parameters.getBackend())
+ .addInnerClasses(getClass())
+ .addKeepMainRule(Main.class)
+ .addHorizontallyMergedClassesInspector(
+ inspector ->
+ inspector.assertIsCompleteMergeGroup(A.class, B.class).assertNoOtherClassesMerged())
+ .addOptionsModification(
+ options -> options.horizontalClassMergerOptions().disableInitialRoundOfClassMerging())
+ // When inlining of ASub.<init> and BSub.<init> is allowed, we test that we correctly deal
+ // with the following code:
+ // new-instance {v0}, LASub;
+ // invoke-direct {v0}, Ljava/lang/Object;
+ //
+ // We also test the behavior when ASub.<init> and BSub.<init> are not inlined. In this case,
+ // we end up with the following instruction in ASub.<init>:
+ // invoke-direct {v0}, Ljava/lang/Object;
+ //
+ // In both cases we need to synthesize a constructor to make sure we do not skip a
+ // constructor on A or B.
+ .applyIf(
+ neverInlineSubInitializers,
+ b -> b.addKeepRules("-neverinline class **$?Sub { void <init>(); }"))
+ .enableInliningAnnotations()
+ .enableNeverClassInliningAnnotations()
+ .enableNoHorizontalClassMergingAnnotations()
+ .enableNoMethodStaticizingAnnotations()
+ .enableNoVerticalClassMergingAnnotations()
+ .setMinApi(parameters)
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutputLines("A", "B");
+ }
+
+ static class Main {
+
+ public static void main(String[] args) {
+ new ASub().m();
+ new BSub().m();
+ }
+ }
+
+ @NoVerticalClassMerging
+ static class A {
+
+ @NeverInline
+ @NoMethodStaticizing
+ void m() {
+ System.out.println("A");
+ }
+ }
+
+ @NeverClassInline
+ @NoHorizontalClassMerging
+ static class ASub extends A {
+
+ ASub() {}
+ }
+
+ @NoVerticalClassMerging
+ static class B {
+
+ @NeverInline
+ @NoMethodStaticizing
+ void m() {
+ System.out.println("B");
+ }
+ }
+
+ @NeverClassInline
+ @NoHorizontalClassMerging
+ static class BSub extends B {
+
+ BSub() {}
+ }
+}