Add support for StringSwitch in inlining constraint analysis
Bug: b/331337747
Change-Id: Ie6b742e336db7f2dd48b5f00a920354f24b4081a
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/inlining/ConstSimpleInliningConstraint.java b/src/main/java/com/android/tools/r8/ir/analysis/inlining/ConstSimpleInliningConstraint.java
new file mode 100644
index 0000000..4ca4fa3
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/analysis/inlining/ConstSimpleInliningConstraint.java
@@ -0,0 +1,61 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.analysis.inlining;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.proto.ArgumentInfoCollection;
+import com.android.tools.r8.graph.proto.RemovedArgumentInfo;
+import com.android.tools.r8.ir.analysis.value.SingleValue;
+import com.android.tools.r8.ir.code.Instruction;
+import com.android.tools.r8.ir.code.InvokeMethod;
+import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+
+public class ConstSimpleInliningConstraint extends SimpleInliningArgumentConstraint {
+
+ private ConstSimpleInliningConstraint(int argumentIndex) {
+ super(argumentIndex);
+ }
+
+ static ConstSimpleInliningConstraint create(
+ int argumentIndex, SimpleInliningConstraintFactory witness) {
+ assert witness != null;
+ return new ConstSimpleInliningConstraint(argumentIndex);
+ }
+
+ @Override
+ public boolean isSatisfied(InvokeMethod invoke) {
+ Value argumentRoot = invoke.getArgument(getArgumentIndex()).getAliasedValue();
+ return argumentRoot.isDefinedByInstructionSatisfying(Instruction::isConstInstruction);
+ }
+
+ @Override
+ public SimpleInliningConstraint fixupAfterParametersChanged(
+ AppView<AppInfoWithLiveness> appView,
+ ArgumentInfoCollection changes,
+ SimpleInliningConstraintFactory factory) {
+ if (changes.isArgumentRemoved(getArgumentIndex())) {
+ RemovedArgumentInfo removedArgumentInfo =
+ changes.getArgumentInfo(getArgumentIndex()).asRemovedArgumentInfo();
+ if (!removedArgumentInfo.hasSingleValue()) {
+ // We should never have constraints for unused arguments.
+ assert false;
+ return NeverSimpleInliningConstraint.getInstance();
+ }
+ SingleValue singleValue = removedArgumentInfo.getSingleValue();
+ return singleValue.isSingleConstValue()
+ ? AlwaysSimpleInliningConstraint.getInstance()
+ : NeverSimpleInliningConstraint.getInstance();
+ } else {
+ assert !changes.hasArgumentInfo(getArgumentIndex());
+ }
+ return withArgumentIndex(changes.getNewArgumentIndex(getArgumentIndex()), factory);
+ }
+
+ @Override
+ SimpleInliningArgumentConstraint withArgumentIndex(
+ int argumentIndex, SimpleInliningConstraintFactory factory) {
+ return factory.createConstConstraint(argumentIndex);
+ }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintAnalysis.java b/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintAnalysis.java
index 49d5674..36d43a0 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintAnalysis.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintAnalysis.java
@@ -7,17 +7,23 @@
import static com.android.tools.r8.ir.code.Opcodes.GOTO;
import static com.android.tools.r8.ir.code.Opcodes.IF;
import static com.android.tools.r8.ir.code.Opcodes.RETURN;
+import static com.android.tools.r8.ir.code.Opcodes.STRING_SWITCH;
import static com.android.tools.r8.ir.code.Opcodes.THROW;
import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexItemFactory;
import com.android.tools.r8.graph.DexType;
import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.code.Argument;
import com.android.tools.r8.ir.code.BasicBlock;
import com.android.tools.r8.ir.code.IRCode;
import com.android.tools.r8.ir.code.If;
import com.android.tools.r8.ir.code.IfType;
import com.android.tools.r8.ir.code.Instruction;
import com.android.tools.r8.ir.code.InstructionIterator;
+import com.android.tools.r8.ir.code.InvokeVirtual;
+import com.android.tools.r8.ir.code.JumpInstruction;
+import com.android.tools.r8.ir.code.StringSwitch;
import com.android.tools.r8.ir.code.Value;
import com.android.tools.r8.shaking.AppInfoWithLiveness;
import com.android.tools.r8.utils.InternalOptions;
@@ -40,7 +46,8 @@
*/
public class SimpleInliningConstraintAnalysis {
- private final SimpleInliningConstraintFactory factory;
+ private final SimpleInliningConstraintFactory constraintFactory;
+ private final DexItemFactory dexItemFactory;
private final ProgramMethod method;
private final InternalOptions options;
private final int simpleInliningConstraintThreshold;
@@ -49,7 +56,8 @@
public SimpleInliningConstraintAnalysis(
AppView<AppInfoWithLiveness> appView, ProgramMethod method) {
- this.factory = appView.simpleInliningConstraintFactory();
+ this.constraintFactory = appView.simpleInliningConstraintFactory();
+ this.dexItemFactory = appView.dexItemFactory();
this.method = method;
this.options = appView.options();
this.simpleInliningConstraintThreshold = appView.options().simpleInliningConstraintThreshold;
@@ -87,11 +95,18 @@
// Move the instruction iterator forward to the block's jump instruction, while incrementing the
// instruction depth of the depth-first traversal.
Instruction instruction = instructionIterator.next();
+ SimpleInliningConstraint blockConstraint = AlwaysSimpleInliningConstraint.getInstance();
while (!instruction.isJumpInstruction()) {
assert !instruction.isArgument();
assert !instruction.isDebugInstruction();
- if (!instruction.isAssume()) {
- instructionDepth += 1;
+ SimpleInliningConstraint instructionConstraint =
+ computeConstraintForInstructionNotToMaterialize(instruction);
+ if (instructionConstraint.isAlways()) {
+ assert instruction.isAssume();
+ } else if (instructionConstraint.isNever()) {
+ instructionDepth++;
+ } else {
+ blockConstraint = blockConstraint.meet(instructionConstraint);
}
instruction = instructionIterator.next();
}
@@ -102,8 +117,36 @@
return NeverSimpleInliningConstraint.getInstance();
}
- // Analyze the jump instruction.
- // TODO(b/132600418): Extend to switch and throw instructions.
+ SimpleInliningConstraint jumpConstraint =
+ computeConstraintForJumpInstruction(instruction.asJumpInstruction(), instructionDepth);
+ return blockConstraint.meet(jumpConstraint);
+ }
+
+ private SimpleInliningConstraint computeConstraintForInstructionNotToMaterialize(
+ Instruction instruction) {
+ if (instruction.isAssume()) {
+ return AlwaysSimpleInliningConstraint.getInstance();
+ }
+ if (instruction.isInvokeVirtual()) {
+ InvokeVirtual invoke = instruction.asInvokeVirtual();
+ if (invoke.getInvokedMethod().isIdenticalTo(dexItemFactory.objectMembers.getClass)
+ && invoke.hasUnusedOutValue()) {
+ Value receiver = invoke.getReceiver();
+ if (receiver.getType().isDefinitelyNotNull()) {
+ return AlwaysSimpleInliningConstraint.getInstance();
+ }
+ Value receiverRoot = receiver.getAliasedValue();
+ if (receiverRoot.isDefinedByInstructionSatisfying(Instruction::isArgument)) {
+ Argument argument = receiverRoot.getDefinition().asArgument();
+ return constraintFactory.createNotEqualToNullConstraint(argument.getIndex());
+ }
+ }
+ }
+ return NeverSimpleInliningConstraint.getInstance();
+ }
+
+ private SimpleInliningConstraint computeConstraintForJumpInstruction(
+ JumpInstruction instruction, int instructionDepth) {
switch (instruction.opcode()) {
case IF:
If ifInstruction = instruction.asIf();
@@ -154,8 +197,26 @@
case RETURN:
return AlwaysSimpleInliningConstraint.getInstance();
+ case STRING_SWITCH:
+ // Require that all cases including the default case are simple. In that case we can
+ // guarantee simpleness by requiring that the switch value is constant.
+ StringSwitch stringSwitch = instruction.asStringSwitch();
+ Value valueRoot = stringSwitch.value().getAliasedValue();
+ if (!valueRoot.isDefinedByInstructionSatisfying(Instruction::isArgument)) {
+ return NeverSimpleInliningConstraint.getInstance();
+ }
+ for (BasicBlock successor : stringSwitch.getBlock().getNormalSuccessors()) {
+ SimpleInliningConstraint successorConstraint =
+ analyzeInstructionsInBlock(successor, instructionDepth);
+ if (!successorConstraint.isAlways()) {
+ return NeverSimpleInliningConstraint.getInstance();
+ }
+ }
+ Argument argument = valueRoot.getDefinition().asArgument();
+ return constraintFactory.createConstConstraint(argument.getIndex());
+
case THROW:
- return block.hasCatchHandlers()
+ return instruction.getBlock().hasCatchHandlers()
? NeverSimpleInliningConstraint.getInstance()
: AlwaysSimpleInliningConstraint.getInstance();
@@ -174,15 +235,16 @@
case EQ:
if (isZeroTest) {
if (argumentType.isReferenceType()) {
- return factory.createEqualToNullConstraint(argumentIndex);
+ return constraintFactory.createEqualToNullConstraint(argumentIndex);
}
if (argumentType.isBooleanType()) {
- return factory.createEqualToFalseConstraint(argumentIndex);
+ return constraintFactory.createEqualToFalseConstraint(argumentIndex);
}
} else if (argumentType.isPrimitiveType()) {
OptionalLong rawValue = getRawNumberValue(otherOperand);
if (rawValue.isPresent()) {
- return factory.createEqualToNumberConstraint(argumentIndex, rawValue.getAsLong());
+ return constraintFactory.createEqualToNumberConstraint(
+ argumentIndex, rawValue.getAsLong());
}
}
return NeverSimpleInliningConstraint.getInstance();
@@ -190,15 +252,16 @@
case NE:
if (isZeroTest) {
if (argumentType.isReferenceType()) {
- return factory.createNotEqualToNullConstraint(argumentIndex);
+ return constraintFactory.createNotEqualToNullConstraint(argumentIndex);
}
if (argumentType.isBooleanType()) {
- return factory.createEqualToTrueConstraint(argumentIndex);
+ return constraintFactory.createEqualToTrueConstraint(argumentIndex);
}
} else if (argumentType.isPrimitiveType()) {
OptionalLong rawValue = getRawNumberValue(otherOperand);
if (rawValue.isPresent()) {
- return factory.createNotEqualToNumberConstraint(argumentIndex, rawValue.getAsLong());
+ return constraintFactory.createNotEqualToNumberConstraint(
+ argumentIndex, rawValue.getAsLong());
}
}
return NeverSimpleInliningConstraint.getInstance();
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintFactory.java b/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintFactory.java
index 5b60ccd..39a8164 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintFactory.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintFactory.java
@@ -8,6 +8,7 @@
import static com.android.tools.r8.ir.analysis.type.Nullability.definitelyNull;
import com.android.tools.r8.ir.analysis.type.Nullability;
+import com.android.tools.r8.utils.ArrayUtils;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
@@ -15,16 +16,29 @@
public class SimpleInliningConstraintFactory {
// Immutable argument constraints for low argument indices to avoid overhead of ConcurrentHashMap.
+ private final ConstSimpleInliningConstraint[] lowConstConstraints =
+ ArrayUtils.initialize(
+ new ConstSimpleInliningConstraint[5], i -> ConstSimpleInliningConstraint.create(i, this));
private final EqualToBooleanSimpleInliningConstraint[] lowEqualToFalseConstraints =
- new EqualToBooleanSimpleInliningConstraint[5];
+ ArrayUtils.initialize(
+ new EqualToBooleanSimpleInliningConstraint[5],
+ i -> EqualToBooleanSimpleInliningConstraint.create(i, false, this));
private final EqualToBooleanSimpleInliningConstraint[] lowEqualToTrueConstraints =
- new EqualToBooleanSimpleInliningConstraint[5];
+ ArrayUtils.initialize(
+ new EqualToBooleanSimpleInliningConstraint[5],
+ i -> EqualToBooleanSimpleInliningConstraint.create(i, true, this));
private final NullSimpleInliningConstraint[] lowNotEqualToNullConstraints =
- new NullSimpleInliningConstraint[5];
+ ArrayUtils.initialize(
+ new NullSimpleInliningConstraint[5],
+ i -> NullSimpleInliningConstraint.create(i, definitelyNotNull(), this));
private final NullSimpleInliningConstraint[] lowEqualToNullConstraints =
- new NullSimpleInliningConstraint[5];
+ ArrayUtils.initialize(
+ new NullSimpleInliningConstraint[5],
+ i -> NullSimpleInliningConstraint.create(i, definitelyNull(), this));
// Argument constraints for high argument indices.
+ private final Map<Integer, ConstSimpleInliningConstraint> highConstConstraints =
+ new ConcurrentHashMap<>();
private final Map<Integer, EqualToBooleanSimpleInliningConstraint> highEqualToFalseConstraints =
new ConcurrentHashMap<>();
private final Map<Integer, EqualToBooleanSimpleInliningConstraint> highEqualToTrueConstraints =
@@ -34,20 +48,12 @@
private final Map<Integer, NullSimpleInliningConstraint> highEqualToNullConstraints =
new ConcurrentHashMap<>();
- public SimpleInliningConstraintFactory() {
- for (int i = 0; i < lowEqualToFalseConstraints.length; i++) {
- lowEqualToFalseConstraints[i] = EqualToBooleanSimpleInliningConstraint.create(i, false, this);
- }
- for (int i = 0; i < lowEqualToTrueConstraints.length; i++) {
- lowEqualToTrueConstraints[i] = EqualToBooleanSimpleInliningConstraint.create(i, true, this);
- }
- for (int i = 0; i < lowNotEqualToNullConstraints.length; i++) {
- lowNotEqualToNullConstraints[i] =
- NullSimpleInliningConstraint.create(i, definitelyNotNull(), this);
- }
- for (int i = 0; i < lowEqualToNullConstraints.length; i++) {
- lowEqualToNullConstraints[i] = NullSimpleInliningConstraint.create(i, definitelyNull(), this);
- }
+ public ConstSimpleInliningConstraint createConstConstraint(int argumentIndex) {
+ return createArgumentConstraint(
+ argumentIndex,
+ lowConstConstraints,
+ highConstConstraints,
+ () -> ConstSimpleInliningConstraint.create(argumentIndex, this));
}
public EqualToBooleanSimpleInliningConstraint createEqualToFalseConstraint(int argumentIndex) {
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 9f18e74..bd261a3 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -120,7 +120,6 @@
protected final IdentifierNameStringMarker identifierNameStringMarker;
private final Devirtualizer devirtualizer;
protected final CovariantReturnTypeAnnotationTransformer covariantReturnTypeAnnotationTransformer;
- private final StringSwitchRemover stringSwitchRemover;
private final TypeChecker typeChecker;
protected EnumUnboxer enumUnboxer;
protected final NumberUnboxer numberUnboxer;
@@ -202,7 +201,6 @@
this.identifierNameStringMarker = null;
this.devirtualizer = null;
this.typeChecker = null;
- this.stringSwitchRemover = null;
this.methodOptimizationInfoCollector = null;
this.enumUnboxer = EnumUnboxer.empty();
this.numberUnboxer = NumberUnboxer.empty();
@@ -275,10 +273,6 @@
this.enumUnboxer = EnumUnboxer.empty();
this.numberUnboxer = NumberUnboxer.empty();
}
- this.stringSwitchRemover =
- options.isStringSwitchConversionEnabled()
- ? new StringSwitchRemover(appView, identifierNameStringMarker)
- : null;
}
public IRConverter(AppInfo appInfo) {
@@ -773,11 +767,10 @@
.run(code, methodProcessor, methodProcessingContext, timing);
}
- if (code.getConversionOptions().isStringSwitchConversionEnabled()) {
- // Remove string switches prior to canonicalization to ensure that the constants that are
- // being introduced will be canonicalized if possible.
- stringSwitchRemover.run(code, methodProcessor, methodProcessingContext, timing);
- }
+ // Remove string switches prior to canonicalization to ensure that the constants that are
+ // being introduced will be canonicalized if possible.
+ new StringSwitchRemover(appView, identifierNameStringMarker)
+ .run(code, methodProcessor, methodProcessingContext, timing);
// TODO(mkroghj) Test if shorten live ranges is worth it.
if (options.isGeneratingDex()) {
@@ -967,9 +960,7 @@
IRCode code, OptimizationFeedback feedback, Timing timing) {
if (!code.getConversionOptions().isGeneratingLir()) {
new FilledNewArrayRewriter(appView).run(code, timing);
- }
- if (stringSwitchRemover != null) {
- stringSwitchRemover.run(code, timing);
+ new StringSwitchRemover(appView, identifierNameStringMarker).run(code, timing);
}
code.removeRedundantBlocks();
deadCodeRemover.run(code, timing);
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/LirConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/LirConverter.java
index 6da063b..846d0c7 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/LirConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/LirConverter.java
@@ -143,11 +143,12 @@
new CodeRewriterPassCollection(
new AdaptClassStringsRewriter(appView),
new ConstResourceNumberRemover(appView),
+ // Must run before DexItemBasedConstStringRemover.
+ new StringSwitchRemover(appView),
new DexItemBasedConstStringRemover(appView),
new InitClassRemover(appView),
new RecordInvokeDynamicInvokeCustomRewriter(appView),
- new FilledNewArrayRewriter(appView),
- new StringSwitchRemover(appView));
+ new FilledNewArrayRewriter(appView));
ThreadUtils.processItems(
appView.appInfo().classes(),
clazz ->
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/StringSwitchRemover.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/StringSwitchRemover.java
index f35f820..74fa808 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/StringSwitchRemover.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/StringSwitchRemover.java
@@ -84,7 +84,7 @@
while (blockIterator.hasNext()) {
BasicBlock block = blockIterator.next();
StringSwitch theSwitch = block.exit().asStringSwitch();
- if (theSwitch != null) {
+ if (theSwitch != null && shouldBeRemoved(code, theSwitch)) {
try {
SingleStringSwitchRemover remover;
if (theSwitch.numberOfKeys() < appView.options().minimumStringSwitchSize
@@ -132,7 +132,7 @@
BasicBlock block = blockIterator.next();
for (BasicBlock predecessor : block.getNormalPredecessors()) {
StringSwitch exit = predecessor.exit().asStringSwitch();
- if (exit != null) {
+ if (exit != null && shouldBeRemoved(code, exit)) {
hasStringSwitch = true;
if (block == exit.fallthroughBlock()) {
// After the elimination of this string-switch instruction, there will be two
@@ -164,6 +164,15 @@
return hasStringSwitch;
}
+ private boolean shouldBeRemoved(IRCode code, StringSwitch theSwitch) {
+ // We only support retaining StringSwitch instructions in LIR. However, even when compiling to
+ // LIR, we (currently) need to remove StringSwitch instructions where the keys may be class
+ // names, so that these DexItemBasedConstStrings are correctly lens code rewritten.
+ // (Note this could be avoided by introducing a separate DexItemBasedStringSwitch instruction.)
+ return !code.getConversionOptions().isGeneratingLir()
+ || isClassNameValue(theSwitch.value(), dexItemFactory);
+ }
+
@Override
protected boolean shouldRewriteCode(IRCode code, MethodProcessor methodProcessor) {
return code.metadata().mayHaveStringSwitch();
diff --git a/src/test/java/com/android/tools/r8/ir/conversion/DexItemBasedStringSwitchTest.java b/src/test/java/com/android/tools/r8/ir/conversion/DexItemBasedStringSwitchTest.java
new file mode 100644
index 0000000..6b92f2a
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/conversion/DexItemBasedStringSwitchTest.java
@@ -0,0 +1,63 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.conversion;
+
+import com.android.tools.r8.NoHorizontalClassMerging;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class DexItemBasedStringSwitchTest extends TestBase {
+
+ @Parameter(0)
+ public TestParameters parameters;
+
+ @Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withAllRuntimesAndApiLevels().build();
+ }
+
+ @Test
+ public void test() throws Exception {
+ testForR8(parameters.getBackend())
+ .addInnerClasses(getClass())
+ .addKeepMainRule(Main.class)
+ .addKeepRules("-repackageclasses")
+ .enableNoHorizontalClassMergingAnnotations()
+ .setMinApi(parameters)
+ .compile()
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutputLines("A");
+ }
+
+ static class Main {
+
+ public static void main(String[] args) {
+ Object o = System.currentTimeMillis() > 0 ? new A() : new B();
+ switch (o.getClass().getName()) {
+ case "com.android.tools.r8.ir.conversion.DexItemBasedStringSwitchTest$A":
+ System.out.println("A");
+ return;
+ case "com.android.tools.r8.ir.conversion.DexItemBasedStringSwitchTest$B":
+ System.out.println("B");
+ break;
+ default:
+ System.out.println("Neither");
+ break;
+ }
+ }
+ }
+
+ @NoHorizontalClassMerging
+ static class A {}
+
+ @NoHorizontalClassMerging
+ static class B {}
+}