Optimize bitwise operations in nested @Composable functions
Bug: b/302483644
Change-Id: I8b4329b014fd6f8e4a3fe261dca960512f14739b
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/constant/SparseConditionalConstantPropagation.java b/src/main/java/com/android/tools/r8/ir/analysis/constant/SparseConditionalConstantPropagation.java
index b3b0924..38b433a 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/constant/SparseConditionalConstantPropagation.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/constant/SparseConditionalConstantPropagation.java
@@ -57,9 +57,13 @@
return true;
}
+ public Map<Value, AbstractValue> analyze(IRCode code) {
+ return new SparseConditionalConstantPropagationOnCode(code).analyze().mapping;
+ }
+
@Override
protected CodeRewriterResult rewriteCode(IRCode code) {
- return new SparseConditionalConstantPropagationOnCode(code).run();
+ return new SparseConditionalConstantPropagationOnCode(code).analyze().run();
}
private class SparseConditionalConstantPropagationOnCode {
@@ -86,7 +90,7 @@
visitedBlocks = new BitSet(maxBlockNumber);
}
- protected CodeRewriterResult run() {
+ public SparseConditionalConstantPropagationOnCode analyze() {
BasicBlock firstBlock = code.entryBlock();
visitInstructions(firstBlock);
@@ -113,6 +117,10 @@
}
}
}
+ return this;
+ }
+
+ protected CodeRewriterResult run() {
boolean hasChanged = rewriteConstants();
return CodeRewriterResult.hasChanged(hasChanged);
}
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index b47b900..bdc3ec5 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -928,6 +928,10 @@
appView.withArgumentPropagator(
argumentPropagator -> argumentPropagator.scan(method, code, methodProcessor, timing));
+ if (methodProcessor.isComposeMethodProcessor()) {
+ methodProcessor.asComposeMethodProcessor().scan(method, code, timing);
+ }
+
if (methodProcessor.isPrimaryMethodProcessor()) {
enumUnboxer.analyzeEnums(code, methodProcessor);
}
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java b/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java
index 15dc3d2..a80055a 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/MethodProcessor.java
@@ -5,9 +5,18 @@
import com.android.tools.r8.graph.ProgramMethod;
import com.android.tools.r8.ir.conversion.callgraph.CallSiteInformation;
+import com.android.tools.r8.optimize.compose.ComposeMethodProcessor;
public abstract class MethodProcessor {
+ public boolean isComposeMethodProcessor() {
+ return false;
+ }
+
+ public ComposeMethodProcessor asComposeMethodProcessor() {
+ return null;
+ }
+
public boolean isPrimaryMethodProcessor() {
return false;
}
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/PrimaryMethodProcessor.java b/src/main/java/com/android/tools/r8/ir/conversion/PrimaryMethodProcessor.java
index 8d33d03..cd02f61 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/PrimaryMethodProcessor.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/PrimaryMethodProcessor.java
@@ -98,7 +98,6 @@
InternalOptions options = appView.options();
Deque<ProgramMethodSet> waves = new ArrayDeque<>();
Collection<Node> nodes = callGraph.getNodes();
- int waveCount = 1;
while (!nodes.isEmpty()) {
ProgramMethodSet wave = callGraph.extractLeaves();
waves.addLast(wave);
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java
index 1a4a2ec..db30629 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/PrimaryR8IRConverter.java
@@ -25,6 +25,7 @@
import com.android.tools.r8.lightir.LirCode;
import com.android.tools.r8.optimize.MemberRebindingIdentityLens;
import com.android.tools.r8.optimize.argumentpropagation.ArgumentPropagator;
+import com.android.tools.r8.optimize.compose.ComposableOptimizationPass;
import com.android.tools.r8.shaking.AppInfoWithLiveness;
import com.android.tools.r8.utils.ThreadUtils;
import com.android.tools.r8.utils.Timing;
@@ -226,6 +227,8 @@
identifierNameStringMarker.decoupleIdentifierNameStringsInFields(executorService);
}
+ ComposableOptimizationPass.run(appView, this, timing);
+
// Assure that no more optimization feedback left after post processing.
assert feedback.noUpdatesLeft();
return appView.appInfo().app();
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/ArgumentPropagatorCodeScannerForComposableFunctions.java b/src/main/java/com/android/tools/r8/optimize/compose/ArgumentPropagatorCodeScannerForComposableFunctions.java
new file mode 100644
index 0000000..5d0b693
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/compose/ArgumentPropagatorCodeScannerForComposableFunctions.java
@@ -0,0 +1,45 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.optimize.compose;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.code.AbstractValueSupplier;
+import com.android.tools.r8.ir.code.InvokeMethod;
+import com.android.tools.r8.optimize.argumentpropagation.ArgumentPropagatorCodeScanner;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodParameter;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.Timing;
+
+public class ArgumentPropagatorCodeScannerForComposableFunctions
+ extends ArgumentPropagatorCodeScanner {
+
+ private final ComposableCallGraph callGraph;
+
+ public ArgumentPropagatorCodeScannerForComposableFunctions(
+ AppView<AppInfoWithLiveness> appView, ComposableCallGraph callGraph) {
+ super(appView);
+ this.callGraph = callGraph;
+ }
+
+ @Override
+ protected void addTemporaryMethodState(
+ InvokeMethod invoke,
+ ProgramMethod resolvedMethod,
+ AbstractValueSupplier abstractValueSupplier,
+ ProgramMethod context,
+ Timing timing) {
+ ComposableCallGraphNode node = callGraph.getNodes().get(resolvedMethod);
+ if (node != null && node.isComposable()) {
+ super.addTemporaryMethodState(invoke, resolvedMethod, abstractValueSupplier, context, timing);
+ }
+ }
+
+ @Override
+ protected boolean isMethodParameterAlreadyUnknown(
+ MethodParameter methodParameter, ProgramMethod method) {
+ // We haven't defined the virtual root mapping, so we can't tell.
+ return false;
+ }
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/ArgumentPropagatorComposeModeling.java b/src/main/java/com/android/tools/r8/optimize/compose/ArgumentPropagatorComposeModeling.java
index c4fd469..6ffeb6d 100644
--- a/src/main/java/com/android/tools/r8/optimize/compose/ArgumentPropagatorComposeModeling.java
+++ b/src/main/java/com/android/tools/r8/optimize/compose/ArgumentPropagatorComposeModeling.java
@@ -4,8 +4,10 @@
package com.android.tools.r8.optimize.compose;
import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexItemFactory;
import com.android.tools.r8.graph.DexMethod;
import com.android.tools.r8.graph.DexString;
+import com.android.tools.r8.graph.DexType;
import com.android.tools.r8.graph.ProgramMethod;
import com.android.tools.r8.graph.lens.GraphLens;
import com.android.tools.r8.ir.code.InstanceGet;
@@ -19,16 +21,31 @@
import com.android.tools.r8.shaking.AppInfoWithLiveness;
import com.android.tools.r8.utils.BitUtils;
import com.android.tools.r8.utils.BooleanUtils;
+import com.google.common.collect.Iterables;
public class ArgumentPropagatorComposeModeling {
private final AppView<AppInfoWithLiveness> appView;
- private final ComposeReferences composeReferences;
+ private final ComposeReferences rewrittenComposeReferences;
+
+ private final DexType rewrittenFunction2Type;
+ private final DexString invokeName;
public ArgumentPropagatorComposeModeling(AppView<AppInfoWithLiveness> appView) {
assert appView.testing().modelUnknownChangedAndDefaultArgumentsToComposableFunctions;
this.appView = appView;
- this.composeReferences = appView.getComposeReferences();
+ this.rewrittenComposeReferences =
+ appView
+ .getComposeReferences()
+ .rewrittenWithLens(appView.graphLens(), GraphLens.getIdentityLens());
+ DexItemFactory dexItemFactory = appView.dexItemFactory();
+ this.rewrittenFunction2Type =
+ appView
+ .graphLens()
+ .lookupType(
+ dexItemFactory.createType("Lkotlin/jvm/functions/Function2;"),
+ GraphLens.getIdentityLens());
+ this.invokeName = dexItemFactory.createString("invoke");
}
/**
@@ -67,20 +84,29 @@
* }
* </pre>
*/
- // TODO(b/302483644): Only apply modeling when the context is recognized as being a restart
- // lambda.
public ParameterState modelParameterStateForChangedOrDefaultArgumentToComposableFunction(
InvokeMethod invoke,
ProgramMethod singleTarget,
int argumentIndex,
Value argument,
ProgramMethod context) {
+ // TODO(b/302483644): Add some robust way of detecting restart lambda contexts.
+ if (!context.getHolder().getInterfaces().contains(rewrittenFunction2Type)
+ || !invoke.getPosition().getOutermostCaller().getMethod().getName().isEqualTo(invokeName)
+ || Iterables.isEmpty(
+ context
+ .getHolder()
+ .instanceFields(
+ f -> f.getName().isIdenticalTo(rewrittenComposeReferences.changedFieldName)))) {
+ return null;
+ }
+
// First check if this is an invoke to a @Composable function.
if (singleTarget == null
|| !singleTarget
.getDefinition()
.annotations()
- .hasAnnotation(composeReferences.composableType)) {
+ .hasAnnotation(rewrittenComposeReferences.composableType)) {
return null;
}
@@ -109,7 +135,7 @@
invokedMethod.getArity() - 2 - BooleanUtils.intValue(hasDefaultParameter);
if (!invokedMethod
.getParameter(composerParameterIndex)
- .isIdenticalTo(composeReferences.composerType)) {
+ .isIdenticalTo(rewrittenComposeReferences.composerType)) {
return null;
}
@@ -128,13 +154,9 @@
// We generally expect this argument to be defined by a call to updateChangedFlags().
if (argument.isDefinedByInstructionSatisfying(Instruction::isInvokeStatic)) {
InvokeStatic invokeStatic = argument.getDefinition().asInvokeStatic();
- DexMethod maybeUpdateChangedFlagsMethod =
- appView
- .graphLens()
- .getOriginalMethodSignature(
- invokeStatic.getInvokedMethod(), GraphLens.getIdentityLens());
+ DexMethod maybeUpdateChangedFlagsMethod = invokeStatic.getInvokedMethod();
if (!maybeUpdateChangedFlagsMethod.isIdenticalTo(
- composeReferences.updatedChangedFlagsMethod)) {
+ rewrittenComposeReferences.updatedChangedFlagsMethod)) {
return null;
}
// Assume the call does not impact the $$changed capture and strip the call.
@@ -160,10 +182,10 @@
.createDefiniteBitsNumberValue(
BitUtils.ALL_BITS_SET_MASK, BitUtils.ALL_BITS_SET_MASK << 1));
}
- expectedFieldName = composeReferences.changedFieldName;
+ expectedFieldName = rewrittenComposeReferences.changedFieldName;
} else {
// We are looking at an argument to the $$default parameter of the @Composable function.
- expectedFieldName = composeReferences.defaultFieldName;
+ expectedFieldName = rewrittenComposeReferences.defaultFieldName;
}
// At this point we expect that the restart lambda is reading either this.$$changed or
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/ComposableCallGraph.java b/src/main/java/com/android/tools/r8/optimize/compose/ComposableCallGraph.java
new file mode 100644
index 0000000..0cbf138
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/compose/ComposableCallGraph.java
@@ -0,0 +1,170 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.optimize.compose;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.Code;
+import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.graph.UseRegistry;
+import com.android.tools.r8.graph.lens.GraphLens;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.collections.ProgramMethodMap;
+import java.util.function.Consumer;
+
+/**
+ * A partial call graph that stores call edges to @Composable functions. By processing all the call
+ * sites of a given @Composable function we can reapply arguent propagation for the @Composable
+ * function.
+ */
+public class ComposableCallGraph {
+
+ private final ProgramMethodMap<ComposableCallGraphNode> nodes;
+
+ public ComposableCallGraph(ProgramMethodMap<ComposableCallGraphNode> nodes) {
+ this.nodes = nodes;
+ }
+
+ public static Builder builder(AppView<AppInfoWithLiveness> appView) {
+ return new Builder(appView);
+ }
+
+ public static ComposableCallGraph empty() {
+ return new ComposableCallGraph(ProgramMethodMap.empty());
+ }
+
+ public void forEachNode(Consumer<ComposableCallGraphNode> consumer) {
+ nodes.forEachValue(consumer);
+ }
+
+ public ProgramMethodMap<ComposableCallGraphNode> getNodes() {
+ return nodes;
+ }
+
+ public boolean isEmpty() {
+ return nodes.isEmpty();
+ }
+
+ public static class Builder {
+
+ private final AppView<AppInfoWithLiveness> appView;
+ private final ProgramMethodMap<ComposableCallGraphNode> nodes = ProgramMethodMap.create();
+
+ Builder(AppView<AppInfoWithLiveness> appView) {
+ this.appView = appView;
+ }
+
+ public ComposableCallGraph build() {
+ createCallGraphNodesForComposableFunctions();
+ if (!nodes.isEmpty()) {
+ addCallEdgesToComposableFunctions();
+ }
+ return new ComposableCallGraph(nodes);
+ }
+
+ private void createCallGraphNodesForComposableFunctions() {
+ ComposeReferences rewrittenComposeReferences =
+ appView
+ .getComposeReferences()
+ .rewrittenWithLens(appView.graphLens(), GraphLens.getIdentityLens());
+ for (DexProgramClass clazz : appView.appInfo().classes()) {
+ clazz.forEachProgramDirectMethodMatching(
+ method -> method.annotations().hasAnnotation(rewrittenComposeReferences.composableType),
+ method -> {
+ // TODO(b/302483644): Don't include kept @Composable functions, since we can't
+ // optimize them anyway.
+ assert method.getAccessFlags().isStatic();
+ nodes.put(method, new ComposableCallGraphNode(method, true));
+ });
+ }
+ }
+
+ // TODO(b/302483644): Parallelize identification of @Composable call sites.
+ private void addCallEdgesToComposableFunctions() {
+ // Code is fully rewritten so no need to lens rewrite in registry.
+ assert appView.codeLens() == appView.graphLens();
+
+ for (DexProgramClass clazz : appView.appInfo().classes()) {
+ clazz.forEachProgramMethodMatching(
+ DexEncodedMethod::hasCode,
+ method -> {
+ Code code = method.getDefinition().getCode();
+
+ // TODO(b/302483644): Leverage LIR code constant pool for efficient checking.
+ // TODO(b/302483644): Maybe remove the possibility of CF/DEX at this point.
+ assert code.isLirCode()
+ || code.isCfCode()
+ || code.isDexCode()
+ || code.isDefaultInstanceInitializerCode()
+ || code.isThrowNullCode();
+
+ code.registerCodeReferences(
+ method,
+ new UseRegistry<>(appView, method) {
+
+ private final AppView<AppInfoWithLiveness> appViewWithLiveness =
+ appView.withLiveness();
+
+ @Override
+ public void registerInvokeStatic(DexMethod method) {
+ ProgramMethod resolvedMethod =
+ appViewWithLiveness
+ .appInfo()
+ .unsafeResolveMethodDueToDexFormat(method)
+ .getResolvedProgramMethod();
+ if (resolvedMethod == null) {
+ return;
+ }
+
+ ComposableCallGraphNode callee = nodes.get(resolvedMethod);
+ if (callee == null || !callee.isComposable()) {
+ // Only record calls to Composable functions.
+ return;
+ }
+
+ ComposableCallGraphNode caller =
+ nodes.computeIfAbsent(
+ getContext(), context -> new ComposableCallGraphNode(context, false));
+ callee.addCaller(caller);
+ }
+
+ @Override
+ public void registerInitClass(DexType type) {}
+
+ @Override
+ public void registerInvokeDirect(DexMethod method) {}
+
+ @Override
+ public void registerInvokeInterface(DexMethod method) {}
+
+ @Override
+ public void registerInvokeSuper(DexMethod method) {}
+
+ @Override
+ public void registerInvokeVirtual(DexMethod method) {}
+
+ @Override
+ public void registerInstanceFieldRead(DexField field) {}
+
+ @Override
+ public void registerInstanceFieldWrite(DexField field) {}
+
+ @Override
+ public void registerStaticFieldRead(DexField field) {}
+
+ @Override
+ public void registerStaticFieldWrite(DexField field) {}
+
+ @Override
+ public void registerTypeReference(DexType type) {}
+ });
+ });
+ }
+ }
+ }
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/ComposableCallGraphNode.java b/src/main/java/com/android/tools/r8/optimize/compose/ComposableCallGraphNode.java
new file mode 100644
index 0000000..13ec14db
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/compose/ComposableCallGraphNode.java
@@ -0,0 +1,53 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.optimize.compose;
+
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.utils.SetUtils;
+import java.util.Set;
+import java.util.function.Consumer;
+
+public class ComposableCallGraphNode {
+
+ private final ProgramMethod method;
+ private final boolean isComposable;
+
+ private final Set<ComposableCallGraphNode> callers = SetUtils.newIdentityHashSet();
+ private final Set<ComposableCallGraphNode> callees = SetUtils.newIdentityHashSet();
+
+ ComposableCallGraphNode(ProgramMethod method, boolean isComposable) {
+ this.method = method;
+ this.isComposable = isComposable;
+ }
+
+ public void addCaller(ComposableCallGraphNode caller) {
+ callers.add(caller);
+ caller.callees.add(this);
+ }
+
+ public void forEachComposableCallee(Consumer<ComposableCallGraphNode> consumer) {
+ for (ComposableCallGraphNode callee : callees) {
+ if (callee.isComposable()) {
+ consumer.accept(callee);
+ }
+ }
+ }
+
+ public Set<ComposableCallGraphNode> getCallers() {
+ return callers;
+ }
+
+ public ProgramMethod getMethod() {
+ return method;
+ }
+
+ public boolean isComposable() {
+ return isComposable;
+ }
+
+ @Override
+ public String toString() {
+ return method.toString();
+ }
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/ComposableOptimizationPass.java b/src/main/java/com/android/tools/r8/optimize/compose/ComposableOptimizationPass.java
new file mode 100644
index 0000000..244dbb4
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/compose/ComposableOptimizationPass.java
@@ -0,0 +1,116 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.optimize.compose;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.ir.conversion.PrimaryR8IRConverter;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.InternalOptions;
+import com.android.tools.r8.utils.InternalOptions.TestingOptions;
+import com.android.tools.r8.utils.SetUtils;
+import com.android.tools.r8.utils.Timing;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Sets;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
+
+public class ComposableOptimizationPass {
+
+ private final AppView<AppInfoWithLiveness> appView;
+ private final PrimaryR8IRConverter converter;
+
+ private ComposableOptimizationPass(
+ AppView<AppInfoWithLiveness> appView, PrimaryR8IRConverter converter) {
+ this.appView = appView;
+ this.converter = converter;
+ }
+
+ public static void run(
+ AppView<AppInfoWithLiveness> appView, PrimaryR8IRConverter converter, Timing timing) {
+ InternalOptions options = appView.options();
+ if (!options.isOptimizing() || !options.isShrinking()) {
+ return;
+ }
+ TestingOptions testingOptions = options.getTestingOptions();
+ if (!testingOptions.enableComposableOptimizationPass
+ || !testingOptions.modelUnknownChangedAndDefaultArgumentsToComposableFunctions) {
+ return;
+ }
+ timing.time(
+ "ComposableOptimizationPass",
+ () -> new ComposableOptimizationPass(appView, converter).processWaves());
+ }
+
+ void processWaves() {
+ ComposableCallGraph callGraph = ComposableCallGraph.builder(appView).build();
+ ComposeMethodProcessor methodProcessor =
+ new ComposeMethodProcessor(appView, callGraph, converter);
+ Set<ComposableCallGraphNode> wave = createInitialWave(callGraph);
+ while (!wave.isEmpty()) {
+ Set<ComposableCallGraphNode> optimizedComposableFunctions = methodProcessor.processWave(wave);
+ wave = createNextWave(methodProcessor, optimizedComposableFunctions);
+ }
+ }
+
+ // TODO(b/302483644): Should we skip root @Composable functions that don't have any nested
+ // @Composable functions (?).
+ private Set<ComposableCallGraphNode> computeComposableRoots(ComposableCallGraph callGraph) {
+ Set<ComposableCallGraphNode> composableRoots = Sets.newIdentityHashSet();
+ callGraph.forEachNode(
+ node -> {
+ if (!node.isComposable()
+ || Iterables.any(node.getCallers(), ComposableCallGraphNode::isComposable)) {
+ // This is not a @Composable root.
+ return;
+ }
+ if (node.getCallers().isEmpty()) {
+ // Don't include root @Composable functions that are never called. These are either kept
+ // or will be removed in tree shaking.
+ return;
+ }
+ composableRoots.add(node);
+ });
+ return composableRoots;
+ }
+
+ private Set<ComposableCallGraphNode> createInitialWave(ComposableCallGraph callGraph) {
+ Set<ComposableCallGraphNode> wave = Sets.newIdentityHashSet();
+ Set<ComposableCallGraphNode> composableRoots = computeComposableRoots(callGraph);
+ composableRoots.forEach(composableRoot -> wave.addAll(composableRoot.getCallers()));
+ return wave;
+ }
+
+ // TODO(b/302483644): Consider repeatedly extracting the roots from the graph similar to the way
+ // we extract leaves in the primary optimization pass.
+ private static Set<ComposableCallGraphNode> createNextWave(
+ ComposeMethodProcessor methodProcessor,
+ Set<ComposableCallGraphNode> optimizedComposableFunctions) {
+ Set<ComposableCallGraphNode> nextWave =
+ SetUtils.newIdentityHashSet(optimizedComposableFunctions);
+
+ // If the new wave contains two @Composable functions where one calls the other, then defer the
+ // processing of the callee to a later wave, to ensure that we have seen all of its callers
+ // before processing the callee.
+ List<ComposableCallGraphNode> deferredComposableFunctions = new ArrayList<>();
+ nextWave.forEach(
+ node -> {
+ if (SetUtils.containsAnyOf(nextWave, node.getCallers())) {
+ deferredComposableFunctions.add(node);
+ }
+ });
+ deferredComposableFunctions.forEach(nextWave::remove);
+
+ // To optimize the @Composable functions that are called from the @Composable functions of the
+ // next wave in the wave after that, we need to include their callers in the next wave as well.
+ Set<ComposableCallGraphNode> callersOfCalledComposableFunctions = Sets.newIdentityHashSet();
+ nextWave.forEach(
+ node ->
+ node.forEachComposableCallee(
+ callee -> callersOfCalledComposableFunctions.addAll(callee.getCallers())));
+ nextWave.addAll(callersOfCalledComposableFunctions);
+ nextWave.removeIf(methodProcessor::isProcessed);
+ return nextWave;
+ }
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/ComposeMethodProcessor.java b/src/main/java/com/android/tools/r8/optimize/compose/ComposeMethodProcessor.java
new file mode 100644
index 0000000..8ee6237
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/compose/ComposeMethodProcessor.java
@@ -0,0 +1,171 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.optimize.compose;
+
+import com.android.tools.r8.contexts.CompilationContext.ProcessorContext;
+import com.android.tools.r8.errors.Unreachable;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.analysis.constant.SparseConditionalConstantPropagation;
+import com.android.tools.r8.ir.analysis.value.AbstractValue;
+import com.android.tools.r8.ir.code.AbstractValueSupplier;
+import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.ir.conversion.MethodConversionOptions;
+import com.android.tools.r8.ir.conversion.MethodProcessor;
+import com.android.tools.r8.ir.conversion.MethodProcessorEventConsumer;
+import com.android.tools.r8.ir.conversion.PrimaryR8IRConverter;
+import com.android.tools.r8.ir.conversion.callgraph.CallSiteInformation;
+import com.android.tools.r8.ir.optimize.info.OptimizationFeedback;
+import com.android.tools.r8.optimize.argumentpropagation.ArgumentPropagatorCodeScanner;
+import com.android.tools.r8.optimize.argumentpropagation.ArgumentPropagatorOptimizationInfoPopulator;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteMonomorphicMethodState;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.ConcreteParameterState;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.MethodState;
+import com.android.tools.r8.optimize.argumentpropagation.codescanner.ParameterState;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.LazyBox;
+import com.android.tools.r8.utils.Timing;
+import com.android.tools.r8.utils.collections.ProgramMethodSet;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Sets;
+import java.util.Map;
+import java.util.Set;
+
+public class ComposeMethodProcessor extends MethodProcessor {
+
+ private final AppView<AppInfoWithLiveness> appView;
+ private final ArgumentPropagatorCodeScanner codeScanner;
+ private final PrimaryR8IRConverter converter;
+
+ private final Set<ComposableCallGraphNode> processed = Sets.newIdentityHashSet();
+
+ public ComposeMethodProcessor(
+ AppView<AppInfoWithLiveness> appView,
+ ComposableCallGraph callGraph,
+ PrimaryR8IRConverter converter) {
+ this.appView = appView;
+ this.codeScanner = new ArgumentPropagatorCodeScannerForComposableFunctions(appView, callGraph);
+ this.converter = converter;
+ }
+
+ // TODO(b/302483644): Process wave concurrently.
+ public Set<ComposableCallGraphNode> processWave(Set<ComposableCallGraphNode> wave) {
+ ProcessorContext processorContext = appView.createProcessorContext();
+ wave.forEach(
+ node -> {
+ assert !processed.contains(node);
+ converter.processDesugaredMethod(
+ node.getMethod(),
+ OptimizationFeedback.getIgnoreFeedback(),
+ this,
+ processorContext.createMethodProcessingContext(node.getMethod()),
+ MethodConversionOptions.forLirPhase(appView));
+ });
+ processed.addAll(wave);
+ return optimizeComposableFunctionsCalledFromWave(wave);
+ }
+
+ private Set<ComposableCallGraphNode> optimizeComposableFunctionsCalledFromWave(
+ Set<ComposableCallGraphNode> wave) {
+ ArgumentPropagatorOptimizationInfoPopulator optimizationInfoPopulator =
+ new ArgumentPropagatorOptimizationInfoPopulator(appView, null, null, null);
+ Set<ComposableCallGraphNode> optimizedComposableFunctions = Sets.newIdentityHashSet();
+ wave.forEach(
+ node ->
+ node.forEachComposableCallee(
+ callee -> {
+ if (Iterables.all(callee.getCallers(), this::isProcessed)) {
+ optimizationInfoPopulator.setOptimizationInfo(
+ callee.getMethod(), ProgramMethodSet.empty(), getMethodState(callee));
+ // TODO(b/302483644): Only enqueue this callee if its optimization info changed.
+ optimizedComposableFunctions.add(callee);
+ }
+ }));
+ return optimizedComposableFunctions;
+ }
+
+ private MethodState getMethodState(ComposableCallGraphNode node) {
+ assert processed.containsAll(node.getCallers());
+ MethodState methodState = codeScanner.getMethodStates().get(node.getMethod());
+ return widenMethodState(methodState);
+ }
+
+ /**
+ * If a parameter state of the current method state encodes that it is greater than (lattice wise)
+ * than another parameter in the program, then widen the parameter state to unknown. This is
+ * needed since we are not guaranteed to have seen all possible call sites of the callers of this
+ * method.
+ */
+ private MethodState widenMethodState(MethodState methodState) {
+ assert !methodState.isBottom();
+ assert !methodState.isPolymorphic();
+ if (methodState.isMonomorphic()) {
+ ConcreteMonomorphicMethodState monomorphicMethodState = methodState.asMonomorphic();
+ for (int i = 0; i < monomorphicMethodState.size(); i++) {
+ if (monomorphicMethodState.getParameterState(i).isConcrete()) {
+ ConcreteParameterState concreteParameterState =
+ monomorphicMethodState.getParameterState(i).asConcrete();
+ if (concreteParameterState.hasInParameters()) {
+ monomorphicMethodState.setParameterState(i, ParameterState.unknown());
+ }
+ }
+ }
+ } else {
+ assert methodState.isUnknown();
+ }
+ return methodState;
+ }
+
+ public void scan(ProgramMethod method, IRCode code, Timing timing) {
+ LazyBox<Map<Value, AbstractValue>> abstractValues =
+ new LazyBox<>(() -> new SparseConditionalConstantPropagation(appView).analyze(code));
+ AbstractValueSupplier abstractValueSupplier =
+ value -> {
+ AbstractValue abstractValue = abstractValues.computeIfAbsent().get(value);
+ assert abstractValue != null;
+ return abstractValue;
+ };
+ codeScanner.scan(method, code, abstractValueSupplier, timing);
+ }
+
+ public boolean isProcessed(ComposableCallGraphNode node) {
+ return processed.contains(node);
+ }
+
+ @Override
+ public CallSiteInformation getCallSiteInformation() {
+ return CallSiteInformation.empty();
+ }
+
+ @Override
+ public MethodProcessorEventConsumer getEventConsumer() {
+ throw new Unreachable();
+ }
+
+ @Override
+ public boolean isComposeMethodProcessor() {
+ return true;
+ }
+
+ @Override
+ public ComposeMethodProcessor asComposeMethodProcessor() {
+ return this;
+ }
+
+ @Override
+ public boolean isProcessedConcurrently(ProgramMethod method) {
+ return false;
+ }
+
+ @Override
+ public void scheduleDesugaredMethodForProcessing(ProgramMethod method) {
+ throw new Unreachable();
+ }
+
+ @Override
+ public boolean shouldApplyCodeRewritings(ProgramMethod method) {
+ return false;
+ }
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/compose/ComposeReferences.java b/src/main/java/com/android/tools/r8/optimize/compose/ComposeReferences.java
index 7a489a7..e441188 100644
--- a/src/main/java/com/android/tools/r8/optimize/compose/ComposeReferences.java
+++ b/src/main/java/com/android/tools/r8/optimize/compose/ComposeReferences.java
@@ -7,6 +7,7 @@
import com.android.tools.r8.graph.DexMethod;
import com.android.tools.r8.graph.DexString;
import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.lens.GraphLens;
public class ComposeReferences {
@@ -31,4 +32,26 @@
factory.createProto(factory.intType, factory.intType),
"updateChangedFlags");
}
+
+ public ComposeReferences(
+ DexString changedFieldName,
+ DexString defaultFieldName,
+ DexType composableType,
+ DexType composerType,
+ DexMethod updatedChangedFlagsMethod) {
+ this.changedFieldName = changedFieldName;
+ this.defaultFieldName = defaultFieldName;
+ this.composableType = composableType;
+ this.composerType = composerType;
+ this.updatedChangedFlagsMethod = updatedChangedFlagsMethod;
+ }
+
+ public ComposeReferences rewrittenWithLens(GraphLens graphLens, GraphLens codeLens) {
+ return new ComposeReferences(
+ changedFieldName,
+ defaultFieldName,
+ graphLens.lookupClassType(composableType, codeLens),
+ graphLens.lookupClassType(composerType, codeLens),
+ graphLens.getRenamedMethodSignature(updatedChangedFlagsMethod, codeLens));
+ }
}
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index ffae35b..17e1f0a 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -2410,6 +2410,9 @@
System.getProperty("com.android.tools.r8.disableMarkingClassesFinal") != null;
public boolean testEnableTestAssertions = false;
public boolean keepMetadataInR8IfNotRewritten = true;
+ public boolean enableComposableOptimizationPass =
+ SystemPropertyUtils.parseSystemPropertyForDevelopmentOrDefault(
+ "com.android.tools.r8.enableComposableOptimizationPass", false);
public boolean modelUnknownChangedAndDefaultArgumentsToComposableFunctions =
SystemPropertyUtils.parseSystemPropertyForDevelopmentOrDefault(
"com.android.tools.r8.modelUnknownChangedAndDefaultArgumentsToComposableFunctions",
diff --git a/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMemberMap.java b/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMemberMap.java
index bbc27a2..eb747e4 100644
--- a/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMemberMap.java
+++ b/src/main/java/com/android/tools/r8/utils/collections/DexClassAndMemberMap.java
@@ -12,6 +12,7 @@
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.BiPredicate;
+import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
@@ -47,6 +48,10 @@
backing.forEach((wrapper, value) -> consumer.accept(wrapper.get(), value));
}
+ public void forEachValue(Consumer<V> consumer) {
+ backing.values().forEach(consumer);
+ }
+
public V get(K member) {
return backing.get(wrap(member));
}