Rewrite references to merged lambda types similar to lens code rewriting
This is needed to avoid type related assertion errors when we start forwarding dynamic type information from field reads that are known to have a more precise type than the static type of the field.
Change-Id: I2cd1dd127b7dec44eb397f585b8865be8ae5b9d3
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index dd3416f..79772d1 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -287,10 +287,11 @@
options.enableTreeShakingOfLibraryMethodOverrides
? new LibraryMethodOverrideAnalysis(appViewWithLiveness)
: null;
- this.lensCodeRewriter = new LensCodeRewriter(appViewWithLiveness, lambdaRewriter);
- this.inliner = new Inliner(appViewWithLiveness, mainDexClasses, lensCodeRewriter);
this.lambdaMerger =
options.enableLambdaMerging ? new LambdaMerger(appViewWithLiveness) : null;
+ this.lensCodeRewriter = new LensCodeRewriter(appViewWithLiveness, lambdaRewriter);
+ this.inliner =
+ new Inliner(appViewWithLiveness, mainDexClasses, lambdaMerger, lensCodeRewriter);
this.outliner = new Outliner(appViewWithLiveness, this);
this.memberValuePropagation =
options.enableValuePropagation ? new MemberValuePropagation(appViewWithLiveness) : null;
@@ -879,7 +880,7 @@
private void collectLambdaMergingCandidates(DexApplication application) {
if (lambdaMerger != null) {
- lambdaMerger.collectGroupCandidates(application, appView.withLiveness());
+ lambdaMerger.collectGroupCandidates(application);
}
}
@@ -1116,6 +1117,11 @@
}
}
+ if (lambdaMerger != null) {
+ lambdaMerger.rewriteCode(method, code);
+ assert code.isConsistentSSA();
+ }
+
if (typeChecker != null && !typeChecker.check(code)) {
assert appView.enableWholeProgramOptimizations();
assert options.testing.allowTypeErrors;
@@ -1351,7 +1357,7 @@
previous = printMethod(code, "IR after twr close resource rewriter (SSA)", previous);
if (lambdaMerger != null) {
- lambdaMerger.processMethodCode(method, code);
+ lambdaMerger.analyzeCode(method, code);
assert code.isConsistentSSA();
}
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java b/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
index fef9538..284ebe7 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
@@ -46,6 +46,7 @@
import com.android.tools.r8.ir.optimize.info.OptimizationFeedbackIgnore;
import com.android.tools.r8.ir.optimize.inliner.NopWhyAreYouNotInliningReporter;
import com.android.tools.r8.ir.optimize.inliner.WhyAreYouNotInliningReporter;
+import com.android.tools.r8.ir.optimize.lambda.LambdaMerger;
import com.android.tools.r8.kotlin.Kotlin;
import com.android.tools.r8.origin.Origin;
import com.android.tools.r8.shaking.AppInfoWithLiveness;
@@ -70,6 +71,7 @@
protected final AppView<AppInfoWithLiveness> appView;
private final Set<DexMethod> blacklist;
+ private final LambdaMerger lambdaMerger;
private final LensCodeRewriter lensCodeRewriter;
final MainDexClasses mainDexClasses;
@@ -82,10 +84,12 @@
public Inliner(
AppView<AppInfoWithLiveness> appView,
MainDexClasses mainDexClasses,
+ LambdaMerger lambdaMerger,
LensCodeRewriter lensCodeRewriter) {
Kotlin.Intrinsics intrinsics = appView.dexItemFactory().kotlin.intrinsics;
this.appView = appView;
this.blacklist = ImmutableSet.of(intrinsics.throwNpe, intrinsics.throwParameterIsNullException);
+ this.lambdaMerger = lambdaMerger;
this.lensCodeRewriter = lensCodeRewriter;
this.mainDexClasses = mainDexClasses;
}
@@ -585,6 +589,7 @@
ValueNumberGenerator generator,
AppView<? extends AppInfoWithSubtyping> appView,
Position callerPosition,
+ LambdaMerger lambdaMerger,
LensCodeRewriter lensCodeRewriter) {
DexItemFactory dexItemFactory = appView.dexItemFactory();
InternalOptions options = appView.options();
@@ -752,6 +757,9 @@
if (!target.isProcessed()) {
lensCodeRewriter.rewrite(code, target);
}
+ if (lambdaMerger != null) {
+ lambdaMerger.rewriteCodeForInlining(target, code, context);
+ }
assert code.isConsistentSSA();
return new InlineeWithReason(code, reason);
}
@@ -910,6 +918,7 @@
code.valueNumberGenerator,
appView,
getPositionForInlining(invoke, context),
+ lambdaMerger,
lensCodeRewriter);
if (strategy.willExceedBudget(
code, invoke, inlinee, block, whyAreYouNotInliningReporter)) {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/lambda/CodeProcessor.java b/src/main/java/com/android/tools/r8/ir/optimize/lambda/CodeProcessor.java
index 5c6cb8f..f10c3fe 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/lambda/CodeProcessor.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/lambda/CodeProcessor.java
@@ -9,6 +9,7 @@
import com.android.tools.r8.graph.DexEncodedMethod;
import com.android.tools.r8.graph.DexField;
import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexMethod;
import com.android.tools.r8.graph.DexType;
import com.android.tools.r8.ir.code.BasicBlock;
import com.android.tools.r8.ir.code.CheckCast;
@@ -147,12 +148,25 @@
public final ListIterator<BasicBlock> blocks;
private InstructionListIterator instructions;
+ // The inlining context (caller), if any.
+ private final DexEncodedMethod context;
+
CodeProcessor(
AppView<AppInfoWithLiveness> appView,
Function<DexType, Strategy> strategyProvider,
LambdaTypeVisitor lambdaChecker,
DexEncodedMethod method,
IRCode code) {
+ this(appView, strategyProvider, lambdaChecker, method, code, null);
+ }
+
+ CodeProcessor(
+ AppView<AppInfoWithLiveness> appView,
+ Function<DexType, Strategy> strategyProvider,
+ LambdaTypeVisitor lambdaChecker,
+ DexEncodedMethod method,
+ IRCode code,
+ DexEncodedMethod context) {
this.appView = appView;
this.strategyProvider = strategyProvider;
this.factory = appView.dexItemFactory();
@@ -161,6 +175,7 @@
this.method = method;
this.code = code;
this.blocks = code.listIterator();
+ this.context = context;
}
public final InstructionListIterator instructions() {
@@ -178,6 +193,19 @@
}
}
+ private boolean shouldRewrite(DexField field) {
+ return shouldRewrite(field.holder);
+ }
+
+ private boolean shouldRewrite(DexMethod method) {
+ return shouldRewrite(method.holder);
+ }
+
+ private boolean shouldRewrite(DexType type) {
+ // Rewrite references to lambda classes if we are outside the class.
+ return type != (context != null ? context : method).method.holder;
+ }
+
@Override
public Void handleInvoke(Invoke invoke) {
if (invoke.isInvokeNewArray()) {
@@ -199,7 +227,7 @@
// Invalidate signature, there still should not be lambda references.
lambdaChecker.accept(invokeMethod.getInvokedMethod().proto);
// Only rewrite references to lambda classes if we are outside the class.
- if (invokeMethod.getInvokedMethod().holder != this.method.method.holder) {
+ if (shouldRewrite(invokeMethod.getInvokedMethod())) {
process(strategy, invokeMethod);
}
return null;
@@ -218,7 +246,7 @@
Strategy strategy = strategyProvider.apply(newInstance.clazz);
if (strategy.isValidNewInstance(this, newInstance)) {
// Only rewrite references to lambda classes if we are outside the class.
- if (newInstance.clazz != this.method.method.holder) {
+ if (shouldRewrite(newInstance.clazz)) {
process(strategy, newInstance);
}
}
@@ -260,7 +288,7 @@
DexField field = instanceGet.getField();
Strategy strategy = strategyProvider.apply(field.holder);
if (strategy.isValidInstanceFieldRead(this, field)) {
- if (field.holder != this.method.method.holder) {
+ if (shouldRewrite(field)) {
// Only rewrite references to lambda classes if we are outside the class.
process(strategy, instanceGet);
}
@@ -279,7 +307,7 @@
DexField field = instancePut.getField();
Strategy strategy = strategyProvider.apply(field.holder);
if (strategy.isValidInstanceFieldWrite(this, field)) {
- if (field.holder != this.method.method.holder) {
+ if (shouldRewrite(field)) {
// Only rewrite references to lambda classes if we are outside the class.
process(strategy, instancePut);
}
@@ -298,7 +326,7 @@
DexField field = staticGet.getField();
Strategy strategy = strategyProvider.apply(field.holder);
if (strategy.isValidStaticFieldRead(this, field)) {
- if (field.holder != this.method.method.holder) {
+ if (shouldRewrite(field)) {
// Only rewrite references to lambda classes if we are outside the class.
process(strategy, staticGet);
}
@@ -314,7 +342,7 @@
DexField field = staticPut.getField();
Strategy strategy = strategyProvider.apply(field.holder);
if (strategy.isValidStaticFieldWrite(this, field)) {
- if (field.holder != this.method.method.holder) {
+ if (shouldRewrite(field)) {
// Only rewrite references to lambda classes if we are outside the class.
process(strategy, staticPut);
}
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/lambda/LambdaMerger.java b/src/main/java/com/android/tools/r8/ir/optimize/lambda/LambdaMerger.java
index a90704a..b226521 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/lambda/LambdaMerger.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/lambda/LambdaMerger.java
@@ -56,7 +56,6 @@
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
-import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -84,6 +83,43 @@
// 5. synthesize group lambda classes.
//
public final class LambdaMerger {
+
+ private abstract static class Mode {
+
+ void rewriteCode(DexEncodedMethod method, IRCode code, DexEncodedMethod context) {}
+
+ void analyzeCode(DexEncodedMethod method, IRCode code) {}
+ }
+
+ private class AnalyzeMode extends Mode {
+
+ @Override
+ void analyzeCode(DexEncodedMethod method, IRCode code) {
+ new AnalysisStrategy(method, code).processCode();
+ }
+ }
+
+ private class ApplyMode extends Mode {
+
+ private final Set<DexType> lambdaGroupsClasses;
+ private final LambdaMergerOptimizationInfoFixer optimizationInfoFixer;
+
+ ApplyMode(
+ Set<DexType> lambdaGroupTypes, LambdaMergerOptimizationInfoFixer optimizationInfoFixer) {
+ this.lambdaGroupsClasses = lambdaGroupTypes;
+ this.optimizationInfoFixer = optimizationInfoFixer;
+ }
+
+ @Override
+ void rewriteCode(DexEncodedMethod method, IRCode code, DexEncodedMethod context) {
+ if (lambdaGroupsClasses.contains(method.method.holder)) {
+ // Don't rewrite the methods that we have synthesized for the lambda group classes.
+ return;
+ }
+ new ApplyStrategy(method, code, context, optimizationInfoFixer).processCode();
+ }
+ }
+
// Maps lambda into a group, only contains lambdas we decided to merge.
// NOTE: needs synchronization.
private final Map<DexType, LambdaGroup> lambdas = new IdentityHashMap<>();
@@ -113,7 +149,7 @@
private final Kotlin kotlin;
private final DiagnosticsHandler reporter;
- private BiFunction<DexEncodedMethod, IRCode, CodeProcessor> strategyFactory = null;
+ private Mode mode;
// Lambda visitor invalidating lambdas it sees.
private final LambdaTypeVisitor lambdaInvalidator;
@@ -153,8 +189,7 @@
// Collect all group candidates and assign unique lambda ids inside each group.
// We do this before methods are being processed to guarantee stable order of
// lambdas inside each group.
- public final void collectGroupCandidates(
- DexApplication app, AppView<AppInfoWithLiveness> appView) {
+ public final void collectGroupCandidates(DexApplication app) {
// Collect lambda groups.
app.classes().stream()
.filter(cls -> !appView.appInfo().isPinned(cls.type))
@@ -188,20 +223,50 @@
// Remove trivial groups.
removeTrivialLambdaGroups();
- assert strategyFactory == null;
- strategyFactory = AnalysisStrategy::new;
+ assert mode == null;
+ mode = new AnalyzeMode();
}
- // Is called by IRConverter::rewriteCode, performs different actions
- // depending on phase:
- // - in ANALYZE phase just analyzes invalid usages of lambda classes
- // inside the method code, invalidated such lambda classes,
- // collects methods that need to be patched.
- // - in APPLY phase patches the code to use lambda group classes, also
- // asserts that there are no more invalid lambda class references.
- public final void processMethodCode(DexEncodedMethod method, IRCode code) {
- if (strategyFactory != null) {
- strategyFactory.apply(method, code).processCode();
+ /**
+ * Is called by IRConverter::rewriteCode. Performs different actions depending on the current
+ * mode.
+ *
+ * <ol>
+ * <li>in ANALYZE mode analyzes invalid usages of lambda classes inside the method code,
+ * invalidated such lambda classes, collects methods that need to be patched.
+ * <li>in APPLY mode does nothing.
+ * </ol>
+ */
+ public final void analyzeCode(DexEncodedMethod method, IRCode code) {
+ if (mode != null) {
+ mode.analyzeCode(method, code);
+ }
+ }
+
+ /**
+ * Is called by IRConverter::rewriteCode. Performs different actions depending on the current
+ * mode.
+ *
+ * <ol>
+ * <li>in ANALYZE mode does nothing.
+ * <li>in APPLY mode patches the code to use lambda group classes, also asserts that there are
+ * no more invalid lambda class references.
+ * </ol>
+ */
+ public final void rewriteCode(DexEncodedMethod method, IRCode code) {
+ if (mode != null) {
+ mode.rewriteCode(method, code, null);
+ }
+ }
+
+ /**
+ * Similar to {@link #rewriteCode(DexEncodedMethod, IRCode)}, but for rewriting code for inlining.
+ * The {@param context} is the caller that {@param method} is being inlined into.
+ */
+ public final void rewriteCodeForInlining(
+ DexEncodedMethod method, IRCode code, DexEncodedMethod context) {
+ if (mode != null) {
+ mode.rewriteCode(method, code, context);
}
}
@@ -239,7 +304,9 @@
feedback.fixupOptimizationInfos(appView, executorService, optimizationInfoFixer);
// Switch to APPLY strategy.
- this.strategyFactory = (method, code) -> new ApplyStrategy(method, code, optimizationInfoFixer);
+ Set<DexType> lambdaGroupTypes =
+ lambdaGroupsClasses.values().stream().map(clazz -> clazz.type).collect(Collectors.toSet());
+ this.mode = new ApplyMode(lambdaGroupTypes, optimizationInfoFixer);
// Add synthesized lambda group classes to the builder.
for (Entry<LambdaGroup, DexProgramClass> entry : lambdaGroupsClasses.entrySet()) {
@@ -274,7 +341,7 @@
// Rewrite lambda class references into lambda group class
// references inside methods from the processing queue.
rewriteLambdaReferences(converter, executorService, feedback);
- this.strategyFactory = null;
+ this.mode = null;
}
private void analyzeReferencesInProgramClasses(
@@ -459,13 +526,15 @@
private ApplyStrategy(
DexEncodedMethod method,
IRCode code,
+ DexEncodedMethod context,
LambdaMergerOptimizationInfoFixer optimizationInfoFixer) {
super(
LambdaMerger.this.appView,
LambdaMerger.this::strategyProvider,
lambdaChecker,
method,
- code);
+ code,
+ context);
this.optimizationInfoFixer = optimizationInfoFixer;
}