Introduce a new single caller inlining pass
Change-Id: I3fa1905ab55946bf168f2538ea72961b074e51a2
diff --git a/src/main/java/com/android/tools/r8/graph/DexClassAndMember.java b/src/main/java/com/android/tools/r8/graph/DexClassAndMember.java
index 44f6520..9d56643 100644
--- a/src/main/java/com/android/tools/r8/graph/DexClassAndMember.java
+++ b/src/main/java/com/android/tools/r8/graph/DexClassAndMember.java
@@ -14,6 +14,12 @@
private final DexClass holder;
private final D definition;
+ // To allow creation of sentinels.
+ DexClassAndMember() {
+ this.holder = null;
+ this.definition = null;
+ }
+
@SuppressWarnings("ReferenceEquality")
public DexClassAndMember(DexClass holder, D definition) {
assert holder != null;
diff --git a/src/main/java/com/android/tools/r8/graph/DexClassAndMethod.java b/src/main/java/com/android/tools/r8/graph/DexClassAndMethod.java
index f46fc2b..46e1cb9 100644
--- a/src/main/java/com/android/tools/r8/graph/DexClassAndMethod.java
+++ b/src/main/java/com/android/tools/r8/graph/DexClassAndMethod.java
@@ -10,6 +10,9 @@
public abstract class DexClassAndMethod extends DexClassAndMember<DexEncodedMethod, DexMethod>
implements LookupMethodTarget {
+ // To allow creation of sentinels.
+ DexClassAndMethod() {}
+
DexClassAndMethod(DexClass holder, DexEncodedMethod method) {
super(holder, method);
assert holder.isClasspathClass() == (this instanceof ClasspathMethod);
diff --git a/src/main/java/com/android/tools/r8/graph/ProgramMethod.java b/src/main/java/com/android/tools/r8/graph/ProgramMethod.java
index 0c85af8..03758b2 100644
--- a/src/main/java/com/android/tools/r8/graph/ProgramMethod.java
+++ b/src/main/java/com/android/tools/r8/graph/ProgramMethod.java
@@ -24,10 +24,16 @@
public final class ProgramMethod extends DexClassAndMethod
implements ProgramMember<DexEncodedMethod, DexMethod> {
+ private ProgramMethod() {}
+
public ProgramMethod(DexProgramClass holder, DexEncodedMethod method) {
super(holder, method);
}
+ public static ProgramMethod createSentinel() {
+ return new ProgramMethod();
+ }
+
public IRCode buildIR(AppView<?> appView) {
return buildIR(appView, MethodConversionOptions.forLirPhase(appView));
}
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 6b3bc28..c59bea5 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -83,7 +83,6 @@
import com.android.tools.r8.optimize.argumentpropagation.ArgumentPropagatorIROptimizer;
import com.android.tools.r8.position.MethodPosition;
import com.android.tools.r8.shaking.AppInfoWithLiveness;
-import com.android.tools.r8.shaking.KeepMethodInfo;
import com.android.tools.r8.shaking.LibraryMethodOverrideAnalysis;
import com.android.tools.r8.utils.Action;
import com.android.tools.r8.utils.DescriptorUtils;
@@ -1080,27 +1079,9 @@
public void markProcessed(IRCode code, OptimizationFeedback feedback) {
// After all the optimizations have take place, we compute whether method should be inlined.
- ProgramMethod method = code.context();
ConstraintWithTarget state =
- shouldComputeInliningConstraint(method)
- ? inliner.computeInliningConstraint(code)
- : ConstraintWithTarget.NEVER;
- feedback.markProcessed(method.getDefinition(), state);
- }
-
- private boolean shouldComputeInliningConstraint(ProgramMethod method) {
- if (!options.inlinerOptions().enableInlining || inliner == null) {
- return false;
- }
- DexEncodedMethod definition = method.getDefinition();
- if (definition.isClassInitializer() || method.getOrComputeReachabilitySensitive(appView)) {
- return false;
- }
- KeepMethodInfo keepInfo = appView.getKeepInfo(method);
- if (!keepInfo.isInliningAllowed(options) && !keepInfo.isClassInliningAllowed(options)) {
- return false;
- }
- return true;
+ inliner != null ? inliner.computeInliningConstraint(code) : ConstraintWithTarget.NEVER;
+ feedback.markProcessed(code.context().getDefinition(), state);
}
public void printPhase(String phase) {
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRFinalizer.java b/src/main/java/com/android/tools/r8/ir/conversion/IRFinalizer.java
index c0c6808..cc7fabb 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRFinalizer.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRFinalizer.java
@@ -8,17 +8,14 @@
import com.android.tools.r8.graph.Code;
import com.android.tools.r8.graph.bytecodemetadata.BytecodeMetadataProvider;
import com.android.tools.r8.ir.code.IRCode;
-import com.android.tools.r8.ir.optimize.DeadCodeRemover;
import com.android.tools.r8.utils.Timing;
public abstract class IRFinalizer<C extends Code> {
protected final AppView<?> appView;
- protected final DeadCodeRemover deadCodeRemover;
- public IRFinalizer(AppView<?> appView, DeadCodeRemover deadCodeRemover) {
+ public IRFinalizer(AppView<?> appView) {
this.appView = appView;
- this.deadCodeRemover = deadCodeRemover;
}
public abstract C finalizeCode(
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRToCfFinalizer.java b/src/main/java/com/android/tools/r8/ir/conversion/IRToCfFinalizer.java
index 0f3a0aa..cdb1be7 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRToCfFinalizer.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRToCfFinalizer.java
@@ -14,8 +14,11 @@
public class IRToCfFinalizer extends IRFinalizer<CfCode> {
+ private final DeadCodeRemover deadCodeRemover;
+
public IRToCfFinalizer(AppView<?> appView, DeadCodeRemover deadCodeRemover) {
- super(appView, deadCodeRemover);
+ super(appView);
+ this.deadCodeRemover = deadCodeRemover;
}
@Override
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRToDexFinalizer.java b/src/main/java/com/android/tools/r8/ir/conversion/IRToDexFinalizer.java
index c5e44d0..0798942 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRToDexFinalizer.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRToDexFinalizer.java
@@ -23,10 +23,13 @@
public class IRToDexFinalizer extends IRFinalizer<DexCode> {
private static final int PEEPHOLE_OPTIMIZATION_PASSES = 2;
+
+ private final DeadCodeRemover deadCodeRemover;
private final InternalOptions options;
public IRToDexFinalizer(AppView<?> appView, DeadCodeRemover deadCodeRemover) {
- super(appView, deadCodeRemover);
+ super(appView);
+ this.deadCodeRemover = deadCodeRemover;
this.options = appView.options();
}
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRToLirFinalizer.java b/src/main/java/com/android/tools/r8/ir/conversion/IRToLirFinalizer.java
index d981b26..899bc3a 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRToLirFinalizer.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRToLirFinalizer.java
@@ -7,7 +7,6 @@
import com.android.tools.r8.graph.AppView;
import com.android.tools.r8.graph.bytecodemetadata.BytecodeMetadataProvider;
import com.android.tools.r8.ir.code.IRCode;
-import com.android.tools.r8.ir.optimize.DeadCodeRemover;
import com.android.tools.r8.lightir.IR2LirConverter;
import com.android.tools.r8.lightir.LirCode;
import com.android.tools.r8.lightir.LirStrategy;
@@ -15,14 +14,13 @@
public class IRToLirFinalizer extends IRFinalizer<LirCode<Integer>> {
- public IRToLirFinalizer(AppView<?> appView, DeadCodeRemover deadCodeRemover) {
- super(appView, deadCodeRemover);
+ public IRToLirFinalizer(AppView<?> appView) {
+ super(appView);
}
@Override
public LirCode<Integer> finalizeCode(
IRCode code, BytecodeMetadataProvider bytecodeMetadataProvider, Timing timing) {
- assert deadCodeRemover.verifyNoDeadCode(code);
timing.begin("Finalize LIR code");
LirCode<Integer> lirCode =
IR2LirConverter.translate(
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/MethodConversionOptions.java b/src/main/java/com/android/tools/r8/ir/conversion/MethodConversionOptions.java
index 7732f47..24f7613 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/MethodConversionOptions.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/MethodConversionOptions.java
@@ -48,7 +48,7 @@
public IRFinalizer<?> getFinalizer(DeadCodeRemover deadCodeRemover, AppView<?> appView) {
if (isGeneratingLir()) {
- return new IRToLirFinalizer(appView, deadCodeRemover);
+ return new IRToLirFinalizer(appView);
}
if (isGeneratingClassFiles()) {
return new IRToCfFinalizer(appView, deadCodeRemover);
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/OneTimeMethodProcessor.java b/src/main/java/com/android/tools/r8/ir/conversion/OneTimeMethodProcessor.java
index b75e95f..fb545c6 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/OneTimeMethodProcessor.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/OneTimeMethodProcessor.java
@@ -7,6 +7,7 @@
import com.android.tools.r8.contexts.CompilationContext.ProcessorContext;
import com.android.tools.r8.graph.AppView;
import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.conversion.callgraph.CallSiteInformation;
import com.android.tools.r8.threading.ThreadingModule;
import com.android.tools.r8.utils.ThreadUtils;
import com.android.tools.r8.utils.ThreadUtils.WorkLoad;
@@ -69,6 +70,17 @@
return new OneTimeMethodProcessor(eventConsumer, processorContext, methodsToProcess);
}
+ private CallSiteInformation callSiteInformation;
+
+ @Override
+ public CallSiteInformation getCallSiteInformation() {
+ return callSiteInformation != null ? callSiteInformation : super.getCallSiteInformation();
+ }
+
+ public void setCallSiteInformation(CallSiteInformation callSiteInformation) {
+ this.callSiteInformation = callSiteInformation;
+ }
+
@Override
public MethodProcessorEventConsumer getEventConsumer() {
return eventConsumer;
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/callgraph/CycleEliminator.java b/src/main/java/com/android/tools/r8/ir/conversion/callgraph/CycleEliminator.java
index 52dd9bc..0d720fe 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/callgraph/CycleEliminator.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/callgraph/CycleEliminator.java
@@ -21,30 +21,30 @@
import java.util.function.BiConsumer;
import java.util.function.Predicate;
-public class CycleEliminator {
+public class CycleEliminator<N extends CycleEliminatorNode<N>> {
public static final String CYCLIC_FORCE_INLINING_MESSAGE =
"Unable to satisfy force inlining constraints due to cyclic force inlining";
- private static class CallEdge {
+ private static class CallEdge<N extends CycleEliminatorNode<N>> {
- private final Node caller;
- private final Node callee;
+ private final N caller;
+ private final N callee;
- CallEdge(Node caller, Node callee) {
+ CallEdge(N caller, N callee) {
this.caller = caller;
this.callee = callee;
}
}
- static class StackEntryInfo {
+ static class StackEntryInfo<N extends CycleEliminatorNode<N>> {
final int index;
- final Node predecessor;
+ final N predecessor;
boolean processed;
- StackEntryInfo(int index, Node predecessor) {
+ StackEntryInfo(int index, N predecessor) {
this.index = index;
this.predecessor = predecessor;
}
@@ -68,42 +68,42 @@
}
// DFS stack.
- private Deque<Node> stack = new ArrayDeque<>();
+ private Deque<N> stack = new ArrayDeque<>();
// Nodes on the DFS stack.
- private Map<Node, StackEntryInfo> stackEntryInfo = new IdentityHashMap<>();
+ private Map<N, StackEntryInfo<N>> stackEntryInfo = new IdentityHashMap<>();
// Subset of the DFS stack, where the nodes on the stack are class initializers.
//
// This stack is used to efficiently compute if there is a class initializer on the stack.
- private Deque<Node> clinitCallStack = new ArrayDeque<>();
+ private Deque<N> clinitCallStack = new ArrayDeque<>();
// Subset of the DFS stack, where the nodes on the stack satisfy that the edge from the
// predecessor to the node itself is a field read edge.
//
// This stack is used to efficiently compute if there is a field read edge inside a cycle when
// a cycle is found.
- private Deque<Node> writerStack = new ArrayDeque<>();
+ private Deque<N> writerStack = new ArrayDeque<>();
// Set of nodes that have been visited entirely.
- private Set<Node> marked = Sets.newIdentityHashSet();
+ private Set<N> marked = Sets.newIdentityHashSet();
// Call edges that should be removed when the caller has been processed. These are not removed
// directly since that would lead to ConcurrentModificationExceptions.
- private Map<Node, Set<Node>> calleesToBeRemoved = new IdentityHashMap<>();
+ private Map<N, Set<N>> calleesToBeRemoved = new IdentityHashMap<>();
// Field read edges that should be removed when the reader has been processed. These are not
// removed directly since that would lead to ConcurrentModificationExceptions.
- private Map<Node, Set<Node>> writersToBeRemoved = new IdentityHashMap<>();
+ private Map<N, Set<N>> writersToBeRemoved = new IdentityHashMap<>();
// Mapping from callee to the set of callers that were removed from the callee.
private Map<DexEncodedMethod, ProgramMethodSet> removedCallEdges = new IdentityHashMap<>();
// Set of nodes from which cycle elimination must be rerun to ensure that all cycles will be
// removed.
- private LinkedHashSet<Node> revisit = new LinkedHashSet<>();
+ private LinkedHashSet<N> revisit = new LinkedHashSet<>();
- public CycleEliminationResult breakCycles(Collection<Node> roots) {
+ public CycleEliminationResult breakCycles(Collection<N> roots) {
// Break cycles in this call graph by removing edges causing cycles. We do this in a fixpoint
// because the algorithm does not guarantee that all cycles will be removed from the graph
// when we remove an edge in the middle of a cycle that contains another cycle.
@@ -139,12 +139,12 @@
removedCallEdges = new IdentityHashMap<>();
}
- private static class WorkItem {
+ private static class WorkItem<N extends CycleEliminatorNode<N>> {
boolean isNode() {
return false;
}
- NodeWorkItem asNode() {
+ NodeWorkItem<N> asNode() {
return null;
}
@@ -152,15 +152,15 @@
return false;
}
- IteratorWorkItem asIterator() {
+ IteratorWorkItem<N> asIterator() {
return null;
}
}
- private static class NodeWorkItem extends WorkItem {
- private final Node node;
+ private static class NodeWorkItem<N extends CycleEliminatorNode<N>> extends WorkItem<N> {
+ private final N node;
- NodeWorkItem(Node node) {
+ NodeWorkItem(N node) {
this.node = node;
}
@@ -170,16 +170,16 @@
}
@Override
- NodeWorkItem asNode() {
+ NodeWorkItem<N> asNode() {
return this;
}
}
- private static class IteratorWorkItem extends WorkItem {
- private final Node callerOrReader;
- private final Iterator<Node> calleesAndWriters;
+ private static class IteratorWorkItem<N extends CycleEliminatorNode<N>> extends WorkItem<N> {
+ private final N callerOrReader;
+ private final Iterator<N> calleesAndWriters;
- IteratorWorkItem(Node callerOrReader, Iterator<Node> calleesAndWriters) {
+ IteratorWorkItem(N callerOrReader, Iterator<N> calleesAndWriters) {
this.callerOrReader = callerOrReader;
this.calleesAndWriters = calleesAndWriters;
}
@@ -190,51 +190,51 @@
}
@Override
- IteratorWorkItem asIterator() {
+ IteratorWorkItem<N> asIterator() {
return this;
}
}
- private void traverse(Collection<Node> roots) {
- Deque<WorkItem> workItems = new ArrayDeque<>(roots.size());
- for (Node node : roots) {
- workItems.addLast(new NodeWorkItem(node));
+ private void traverse(Collection<N> roots) {
+ Deque<WorkItem<N>> workItems = new ArrayDeque<>(roots.size());
+ for (N node : roots) {
+ workItems.addLast(new NodeWorkItem<>(node));
}
while (!workItems.isEmpty()) {
- WorkItem workItem = workItems.removeFirst();
+ WorkItem<N> workItem = workItems.removeFirst();
if (workItem.isNode()) {
- Node node = workItem.asNode().node;
+ N node = workItem.asNode().node;
if (marked.contains(node)) {
// Already visited all nodes that can be reached from this node.
continue;
}
- Node predecessor = stack.isEmpty() ? null : stack.peek();
+ N predecessor = stack.isEmpty() ? null : stack.peek();
push(node, predecessor);
// The callees and writers must be sorted before calling traverse recursively.
// This ensures that cycles are broken the same way across multiple compilations.
- Iterator<Node> calleesAndWriterIterator =
+ Iterator<N> calleesAndWriterIterator =
Iterators.concat(
node.getCalleesWithDeterministicOrder().iterator(),
node.getWritersWithDeterministicOrder().iterator());
- workItems.addFirst(new IteratorWorkItem(node, calleesAndWriterIterator));
+ workItems.addFirst(new IteratorWorkItem<>(node, calleesAndWriterIterator));
} else {
assert workItem.isIterator();
- IteratorWorkItem iteratorWorkItem = workItem.asIterator();
- Node newCallerOrReader =
+ IteratorWorkItem<N> iteratorWorkItem = workItem.asIterator();
+ N newCallerOrReader =
iterateCalleesAndWriters(
iteratorWorkItem.calleesAndWriters, iteratorWorkItem.callerOrReader);
if (newCallerOrReader != null) {
// We did not finish the work on this iterator, so add it again.
workItems.addFirst(iteratorWorkItem);
- workItems.addFirst(new NodeWorkItem(newCallerOrReader));
+ workItems.addFirst(new NodeWorkItem<>(newCallerOrReader));
} else {
assert !iteratorWorkItem.calleesAndWriters.hasNext();
pop(iteratorWorkItem.callerOrReader);
marked.add(iteratorWorkItem.callerOrReader);
- Collection<Node> calleesToBeRemovedFromCaller =
+ Collection<N> calleesToBeRemovedFromCaller =
calleesToBeRemoved.remove(iteratorWorkItem.callerOrReader);
if (calleesToBeRemovedFromCaller != null) {
calleesToBeRemovedFromCaller.forEach(
@@ -244,7 +244,7 @@
});
}
- Collection<Node> writersToBeRemovedFromReader =
+ Collection<N> writersToBeRemovedFromReader =
writersToBeRemoved.remove(iteratorWorkItem.callerOrReader);
if (writersToBeRemovedFromReader != null) {
writersToBeRemovedFromReader.forEach(
@@ -255,11 +255,10 @@
}
}
- private Node iterateCalleesAndWriters(
- Iterator<Node> calleeOrWriterIterator, Node callerOrReader) {
+ private N iterateCalleesAndWriters(Iterator<N> calleeOrWriterIterator, N callerOrReader) {
while (calleeOrWriterIterator.hasNext()) {
- Node calleeOrWriter = calleeOrWriterIterator.next();
- StackEntryInfo calleeOrWriterStackEntryInfo = stackEntryInfo.get(calleeOrWriter);
+ N calleeOrWriter = calleeOrWriterIterator.next();
+ StackEntryInfo<N> calleeOrWriterStackEntryInfo = stackEntryInfo.get(calleeOrWriter);
boolean foundCycle = calleeOrWriterStackEntryInfo != null;
if (!foundCycle) {
return calleeOrWriter;
@@ -315,11 +314,11 @@
// The call edge cannot be removed due to force inlining. Find another call edge in the
// cycle that can safely be removed instead.
- LinkedList<Node> cycle = extractCycle(calleeOrWriter);
+ LinkedList<N> cycle = extractCycle(calleeOrWriter);
// Break the cycle by finding an edge that can be removed without breaking force
// inlining. If that is not possible, this call fails with a compilation error.
- CallEdge edge = findCallEdgeForRemoval(cycle);
+ CallEdge<N> edge = findCallEdgeForRemoval(cycle);
// The edge will be null if this cycle has already been eliminated as a result of
// another cycle elimination.
@@ -337,10 +336,10 @@
return null;
}
- private void push(Node node, Node predecessor) {
+ private void push(N node, N predecessor) {
stack.push(node);
assert !stackEntryInfo.containsKey(node);
- stackEntryInfo.put(node, new StackEntryInfo(stack.size() - 1, predecessor));
+ stackEntryInfo.put(node, new StackEntryInfo<>(stack.size() - 1, predecessor));
if (predecessor != null) {
if (node.getMethod().isClassInitializer() && node.hasCaller(predecessor)) {
clinitCallStack.push(node);
@@ -350,8 +349,8 @@
}
}
- private void pop(Node node) {
- Node popped = stack.pop();
+ private void pop(N node) {
+ N popped = stack.pop();
assert popped == node;
assert stackEntryInfo.containsKey(node);
stackEntryInfo.remove(node);
@@ -363,20 +362,20 @@
}
}
- private void removeCallEdge(Node caller, Node callee) {
+ private void removeCallEdge(N caller, N callee) {
calleesToBeRemoved.computeIfAbsent(caller, ignore -> Sets.newIdentityHashSet()).add(callee);
}
- private void removeFieldReadEdge(Node reader, Node writer) {
+ private void removeFieldReadEdge(N reader, N writer) {
writersToBeRemoved.computeIfAbsent(reader, ignore -> Sets.newIdentityHashSet()).add(writer);
}
private boolean removeIncomingEdgeOnStack(
- Node target,
- Node currentCalleeOrWriter,
- StackEntryInfo currentCalleeOrWriterStackEntryInfo,
- BiConsumer<Node, Node> edgeRemover) {
- StackEntryInfo targetStackEntryInfo = stackEntryInfo.get(target);
+ N target,
+ N currentCalleeOrWriter,
+ StackEntryInfo<N> currentCalleeOrWriterStackEntryInfo,
+ BiConsumer<N, N> edgeRemover) {
+ StackEntryInfo<N> targetStackEntryInfo = stackEntryInfo.get(target);
boolean cycleContainsTarget =
targetStackEntryInfo.index > currentCalleeOrWriterStackEntryInfo.index;
if (cycleContainsTarget) {
@@ -395,8 +394,8 @@
// TODO(b/270398965): Replace LinkedList.
@SuppressWarnings("JdkObsolete")
- private LinkedList<Node> extractCycle(Node entry) {
- LinkedList<Node> cycle = new LinkedList<>();
+ private LinkedList<N> extractCycle(N entry) {
+ LinkedList<N> cycle = new LinkedList<>();
do {
assert !stack.isEmpty();
cycle.add(stack.pop());
@@ -404,16 +403,16 @@
return cycle;
}
- private boolean verifyCycleSatisfies(Node entry, Predicate<LinkedList<Node>> predicate) {
- LinkedList<Node> cycle = extractCycle(entry);
+ private boolean verifyCycleSatisfies(N entry, Predicate<LinkedList<N>> predicate) {
+ LinkedList<N> cycle = extractCycle(entry);
assert predicate.test(cycle);
recoverStack(cycle);
return true;
}
- private CallEdge findCallEdgeForRemoval(LinkedList<Node> extractedCycle) {
- Node callee = extractedCycle.getLast();
- for (Node caller : extractedCycle) {
+ private CallEdge<N> findCallEdgeForRemoval(LinkedList<N> extractedCycle) {
+ N callee = extractedCycle.getLast();
+ for (N caller : extractedCycle) {
if (caller.hasWriter(callee)) {
// Not a call edge.
assert !caller.hasCallee(callee);
@@ -427,28 +426,29 @@
return null;
}
if (callEdgeRemovalIsSafe(caller, callee)) {
- return new CallEdge(caller, callee);
+ return new CallEdge<N>(caller, callee);
}
callee = caller;
}
throw new CompilationError(CYCLIC_FORCE_INLINING_MESSAGE);
}
- private static boolean callEdgeRemovalIsSafe(Node callerOrReader, Node calleeOrWriter) {
+ private static <N extends CycleEliminatorNode<N>> boolean callEdgeRemovalIsSafe(
+ N callerOrReader, N calleeOrWriter) {
// All call edges where the callee is a method that should be force inlined must be kept,
// to guarantee that the IR converter will process the callee before the caller.
assert calleeOrWriter.hasCaller(callerOrReader);
return !calleeOrWriter.getMethod().getOptimizationInfo().forceInline();
}
- private void recordCallEdgeRemoval(Node caller, Node callee) {
+ private void recordCallEdgeRemoval(N caller, N callee) {
removedCallEdges
.computeIfAbsent(callee.getMethod(), ignore -> ProgramMethodSet.create(2))
.add(caller.getProgramMethod());
}
- private void recoverStack(LinkedList<Node> extractedCycle) {
- Iterator<Node> descendingIt = extractedCycle.descendingIterator();
+ private void recoverStack(LinkedList<N> extractedCycle) {
+ Iterator<N> descendingIt = extractedCycle.descendingIterator();
while (descendingIt.hasNext()) {
stack.push(descendingIt.next());
}
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/callgraph/CycleEliminatorNode.java b/src/main/java/com/android/tools/r8/ir/conversion/callgraph/CycleEliminatorNode.java
new file mode 100644
index 0000000..2483672
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/conversion/callgraph/CycleEliminatorNode.java
@@ -0,0 +1,31 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.conversion.callgraph;
+
+import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.ProgramMethod;
+import java.util.Set;
+
+public interface CycleEliminatorNode<N extends CycleEliminatorNode<N>> {
+
+ DexEncodedMethod getMethod();
+
+ ProgramMethod getProgramMethod();
+
+ Set<N> getCalleesWithDeterministicOrder();
+
+ Set<N> getWritersWithDeterministicOrder();
+
+ boolean hasCallee(N callee);
+
+ boolean hasCaller(N caller);
+
+ void removeCaller(N caller);
+
+ boolean hasReader(N reader);
+
+ void removeReader(N reader);
+
+ boolean hasWriter(N writer);
+}
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/callgraph/Node.java b/src/main/java/com/android/tools/r8/ir/conversion/callgraph/Node.java
index 7752ad9..e515764 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/callgraph/Node.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/callgraph/Node.java
@@ -8,7 +8,7 @@
import java.util.Set;
import java.util.TreeSet;
-public class Node extends NodeBase<Node> implements Comparable<Node> {
+public class Node extends NodeBase<Node> implements Comparable<Node>, CycleEliminatorNode<Node> {
public static Node[] EMPTY_ARRAY = {};
@@ -88,6 +88,7 @@
}
}
+ @Override
public void removeCaller(Node caller) {
boolean callersChanged = callers.remove(caller);
assert callersChanged;
@@ -96,6 +97,7 @@
assert !hasReader(caller);
}
+ @Override
public void removeReader(Node reader) {
boolean readersChanged = readers.remove(reader);
assert readersChanged;
@@ -134,6 +136,7 @@
return callers;
}
+ @Override
public Set<Node> getCalleesWithDeterministicOrder() {
return callees;
}
@@ -142,6 +145,7 @@
return readers;
}
+ @Override
public Set<Node> getWritersWithDeterministicOrder() {
return writers;
}
@@ -150,18 +154,22 @@
return numberOfCallSites;
}
+ @Override
public boolean hasCallee(Node method) {
return callees.contains(method);
}
+ @Override
public boolean hasCaller(Node method) {
return callers.contains(method);
}
+ @Override
public boolean hasReader(Node method) {
return readers.contains(method);
}
+ @Override
public boolean hasWriter(Node method) {
return writers.contains(method);
}
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java b/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java
index af606fb..92da768 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java
@@ -62,7 +62,7 @@
import java.util.Optional;
import java.util.Set;
-public final class DefaultInliningOracle implements InliningOracle {
+public class DefaultInliningOracle implements InliningOracle {
private final AppView<AppInfoWithLiveness> appView;
private final InternalOptions options;
@@ -73,7 +73,7 @@
private final InliningReasonStrategy reasonStrategy;
private int instructionAllowance;
- DefaultInliningOracle(
+ public DefaultInliningOracle(
AppView<AppInfoWithLiveness> appView,
InliningReasonStrategy inliningReasonStrategy,
ProgramMethod method,
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java b/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
index 2a1fc09..5a16e64 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/Inliner.java
@@ -48,6 +48,7 @@
import com.android.tools.r8.ir.conversion.IRConverter;
import com.android.tools.r8.ir.conversion.LensCodeRewriter;
import com.android.tools.r8.ir.conversion.MethodProcessor;
+import com.android.tools.r8.ir.conversion.OneTimeMethodProcessor;
import com.android.tools.r8.ir.conversion.PostMethodProcessor;
import com.android.tools.r8.ir.optimize.SimpleDominatingEffectAnalysis.SimpleEffectAnalysisResult;
import com.android.tools.r8.ir.optimize.info.OptimizationFeedback;
@@ -58,6 +59,7 @@
import com.android.tools.r8.ir.optimize.inliner.WhyAreYouNotInliningReporter;
import com.android.tools.r8.ir.optimize.membervaluepropagation.R8MemberValuePropagation;
import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.shaking.KeepMethodInfo;
import com.android.tools.r8.shaking.MainDexInfo;
import com.android.tools.r8.utils.ConsumerUtils;
import com.android.tools.r8.utils.InternalOptions;
@@ -106,6 +108,14 @@
private final AvailableApiExceptions availableApiExceptions;
+ public Inliner(AppView<AppInfoWithLiveness> appView) {
+ this(appView, null);
+ }
+
+ public Inliner(AppView<AppInfoWithLiveness> appView, IRConverter converter) {
+ this(appView, converter, null);
+ }
+
public Inliner(
AppView<AppInfoWithLiveness> appView,
IRConverter converter,
@@ -123,6 +133,11 @@
: null;
}
+ public WhyAreYouNotInliningReporter createWhyAreYouNotInliningReporter(
+ ProgramMethod singleTarget, ProgramMethod context) {
+ return WhyAreYouNotInliningReporter.createFor(singleTarget, appView, context);
+ }
+
public LensCodeRewriter getLensCodeRewriter() {
return lensCodeRewriter;
}
@@ -139,6 +154,20 @@
@SuppressWarnings("ReferenceEquality")
public ConstraintWithTarget computeInliningConstraint(IRCode code) {
+ InternalOptions options = appView.options();
+ if (!options.inlinerOptions().enableInlining) {
+ return ConstraintWithTarget.NEVER;
+ }
+ ProgramMethod method = code.context();
+ DexEncodedMethod definition = method.getDefinition();
+ if (definition.isClassInitializer() || method.getOrComputeReachabilitySensitive(appView)) {
+ return ConstraintWithTarget.NEVER;
+ }
+ KeepMethodInfo keepInfo = appView.getKeepInfo(method);
+ if (!keepInfo.isInliningAllowed(options) && !keepInfo.isClassInliningAllowed(options)) {
+ return ConstraintWithTarget.NEVER;
+ }
+
if (containsPotentialCatchHandlerVerificationError(code)) {
return ConstraintWithTarget.NEVER;
}
@@ -1030,7 +1059,8 @@
// TODO(b/156853206): Should not duplicate resolution.
ProgramMethod singleTarget = oracle.lookupSingleTarget(invoke, context);
if (singleTarget == null) {
- WhyAreYouNotInliningReporter.handleInvokeWithUnknownTarget(invoke, appView, context);
+ WhyAreYouNotInliningReporter.handleInvokeWithUnknownTarget(
+ this, invoke, appView, context);
continue;
}
@@ -1039,7 +1069,7 @@
WhyAreYouNotInliningReporter whyAreYouNotInliningReporter =
singleTargetOracle.isForcedInliningOracle()
? NopWhyAreYouNotInliningReporter.getInstance()
- : WhyAreYouNotInliningReporter.createFor(singleTarget, appView, context);
+ : createWhyAreYouNotInliningReporter(singleTarget, context);
InlineResult inlineResult =
singleTargetOracle.computeInlining(
code,
@@ -1095,15 +1125,17 @@
appView, code, inlinee, blockIterator, blocksToRemove, action.getDowncastClass());
if (methodProcessor.getCallSiteInformation().hasSingleCallSite(singleTarget, context)) {
- assert converter.isInWave();
feedback.markInlinedIntoSingleCallSite(singleTargetMethod);
- if (singleCallerInlinedMethodsInWave.isEmpty()) {
- converter.addWaveDoneAction(this::onWaveDone);
+ if (!(methodProcessor instanceof OneTimeMethodProcessor)) {
+ assert converter.isInWave();
+ if (singleCallerInlinedMethodsInWave.isEmpty()) {
+ converter.addWaveDoneAction(this::onWaveDone);
+ }
+ singleCallerInlinedMethodsInWave
+ .computeIfAbsent(
+ singleTarget.getHolder(), ignoreKey(ProgramMethodMap::createConcurrent))
+ .put(singleTarget, context);
}
- singleCallerInlinedMethodsInWave
- .computeIfAbsent(
- singleTarget.getHolder(), ignoreKey(ProgramMethodMap::createConcurrent))
- .put(singleTarget, context);
}
classInitializationAnalysis.notifyCodeHasChanged();
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/inliner/WhyAreYouNotInliningReporter.java b/src/main/java/com/android/tools/r8/ir/optimize/inliner/WhyAreYouNotInliningReporter.java
index 159128c..b333817 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/inliner/WhyAreYouNotInliningReporter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/inliner/WhyAreYouNotInliningReporter.java
@@ -11,6 +11,7 @@
import com.android.tools.r8.ir.code.Instruction;
import com.android.tools.r8.ir.code.InvokeDirect;
import com.android.tools.r8.ir.code.InvokeMethod;
+import com.android.tools.r8.ir.optimize.Inliner;
import com.android.tools.r8.ir.optimize.Inliner.Reason;
import com.android.tools.r8.shaking.AppInfoWithLiveness;
import com.android.tools.r8.utils.collections.ProgramMethodSet;
@@ -27,7 +28,10 @@
}
public static void handleInvokeWithUnknownTarget(
- InvokeMethod invoke, AppView<AppInfoWithLiveness> appView, ProgramMethod context) {
+ Inliner inliner,
+ InvokeMethod invoke,
+ AppView<AppInfoWithLiveness> appView,
+ ProgramMethod context) {
if (appView.appInfo().hasNoWhyAreYouNotInliningMethods()) {
return;
}
@@ -41,7 +45,7 @@
}
for (ProgramMethod possibleTarget : possibleProgramTargets) {
- createFor(possibleTarget, appView, context).reportUnknownTarget();
+ inliner.createWhyAreYouNotInliningReporter(possibleTarget, context).reportUnknownTarget();
}
}
diff --git a/src/main/java/com/android/tools/r8/lightir/LirLensCodeRewriter.java b/src/main/java/com/android/tools/r8/lightir/LirLensCodeRewriter.java
index a12f88c..3746ba5 100644
--- a/src/main/java/com/android/tools/r8/lightir/LirLensCodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/lightir/LirLensCodeRewriter.java
@@ -351,10 +351,9 @@
MethodConversionOptions.forLirPhase(appView).disableStringSwitchConversion());
AffectedValues affectedValues = code.removeUnreachableBlocks();
affectedValues.narrowingWithAssumeRemoval(appView, code);
- DeadCodeRemover deadCodeRemover = new DeadCodeRemover(appView);
- deadCodeRemover.run(code, Timing.empty());
+ new DeadCodeRemover(appView).run(code, Timing.empty());
LirCode<Integer> result =
- new IRToLirFinalizer(appView, deadCodeRemover)
+ new IRToLirFinalizer(appView)
.finalizeCode(code, BytecodeMetadataProvider.empty(), Timing.empty());
return (LirCode<EV>) result;
}
@@ -370,9 +369,7 @@
// MethodProcessor argument is only used by unboxing lenses.
MethodProcessor methodProcessor = null;
new LensCodeRewriter(appView).rewrite(code, context, methodProcessor);
- DeadCodeRemover deadCodeRemover = new DeadCodeRemover(appView);
- deadCodeRemover.run(code, Timing.empty());
- IRToLirFinalizer finalizer = new IRToLirFinalizer(appView, deadCodeRemover);
+ IRToLirFinalizer finalizer = new IRToLirFinalizer(appView);
LirCode<?> rewritten =
finalizer.finalizeCode(code, BytecodeMetadataProvider.empty(), Timing.empty());
return (LirCode<EV>) rewritten;
diff --git a/src/main/java/com/android/tools/r8/optimize/singlecaller/SingleCallerInliner.java b/src/main/java/com/android/tools/r8/optimize/singlecaller/SingleCallerInliner.java
new file mode 100644
index 0000000..1105f7c
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/singlecaller/SingleCallerInliner.java
@@ -0,0 +1,228 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.optimize.singlecaller;
+
+import static com.android.tools.r8.ir.optimize.info.OptimizationFeedback.getSimpleFeedback;
+
+import com.android.tools.r8.errors.Unreachable;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.MethodResolutionResult.SingleResolutionResult;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.graph.PrunedItems;
+import com.android.tools.r8.graph.bytecodemetadata.BytecodeMetadataProvider;
+import com.android.tools.r8.ir.analysis.ClassInitializationAnalysis;
+import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.InvokeMethod;
+import com.android.tools.r8.ir.conversion.IRToLirFinalizer;
+import com.android.tools.r8.ir.conversion.MethodConversionOptions;
+import com.android.tools.r8.ir.conversion.MethodProcessor;
+import com.android.tools.r8.ir.conversion.MethodProcessorEventConsumer;
+import com.android.tools.r8.ir.conversion.OneTimeMethodProcessor;
+import com.android.tools.r8.ir.conversion.callgraph.CallSiteInformation;
+import com.android.tools.r8.ir.optimize.CodeRewriter;
+import com.android.tools.r8.ir.optimize.DefaultInliningOracle;
+import com.android.tools.r8.ir.optimize.Inliner;
+import com.android.tools.r8.ir.optimize.inliner.InliningIRProvider;
+import com.android.tools.r8.ir.optimize.inliner.InliningReasonStrategy;
+import com.android.tools.r8.ir.optimize.inliner.NopWhyAreYouNotInliningReporter;
+import com.android.tools.r8.ir.optimize.inliner.WhyAreYouNotInliningReporter;
+import com.android.tools.r8.lightir.LirCode;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.InternalOptions;
+import com.android.tools.r8.utils.ThreadUtils;
+import com.android.tools.r8.utils.Timing;
+import com.android.tools.r8.utils.collections.ProgramMethodMap;
+import com.android.tools.r8.utils.collections.ProgramMethodSet;
+import java.util.Deque;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+
+public class SingleCallerInliner {
+
+ private final AppView<AppInfoWithLiveness> appView;
+
+ public SingleCallerInliner(AppView<AppInfoWithLiveness> appView) {
+ this.appView = appView;
+ }
+
+ public void runIfNecessary(ExecutorService executorService, Timing timing)
+ throws ExecutionException {
+ if (shouldRun()) {
+ timing.begin("SingleCallerInliner");
+ run(executorService);
+ timing.end();
+ }
+ }
+
+ private boolean shouldRun() {
+ InternalOptions options = appView.options();
+ return !options.debug
+ && !options.intermediate
+ && options.isOptimizing()
+ && options.isShrinking();
+ }
+
+ public void run(ExecutorService executorService) throws ExecutionException {
+ ProgramMethodMap<ProgramMethod> singleCallerMethods =
+ new SingleCallerScanner(appView).getSingleCallerMethods(executorService);
+ Inliner inliner = new SingleCallerInlinerImpl(appView, singleCallerMethods);
+ processCallees(inliner, singleCallerMethods, executorService);
+ performInlining(inliner, singleCallerMethods, executorService);
+ pruneItems(singleCallerMethods, executorService);
+ }
+
+ private void processCallees(
+ Inliner inliner,
+ ProgramMethodMap<ProgramMethod> singleCallerMethods,
+ ExecutorService executorService)
+ throws ExecutionException {
+ ThreadUtils.processItems(
+ singleCallerMethods.streamKeys()::forEach,
+ callee -> {
+ IRCode code = callee.buildIR(appView);
+ getSimpleFeedback()
+ .markProcessed(callee.getDefinition(), inliner.computeInliningConstraint(code));
+ // TODO(b/325199754): Do not tamper with the code API level.
+ if (callee.getDefinition().getApiLevelForCode().isNotSetApiLevel()) {
+ callee
+ .getDefinition()
+ .setApiLevelForCode(
+ appView.apiLevelCompute().computeInitialMinApiLevel(appView.options()));
+ }
+ },
+ appView.options().getThreadingModule(),
+ executorService);
+ }
+
+ private void performInlining(
+ Inliner inliner,
+ ProgramMethodMap<ProgramMethod> singleCallerMethods,
+ ExecutorService executorService)
+ throws ExecutionException {
+ CallSiteInformation callSiteInformation = createCallSiteInformation(singleCallerMethods);
+ Deque<ProgramMethodSet> waves = SingleCallerWaves.buildWaves(appView, singleCallerMethods);
+ while (!waves.isEmpty()) {
+ ProgramMethodSet wave = waves.removeFirst();
+ OneTimeMethodProcessor methodProcessor =
+ OneTimeMethodProcessor.create(wave, MethodProcessorEventConsumer.empty(), appView);
+ methodProcessor.setCallSiteInformation(callSiteInformation);
+ methodProcessor.forEachWaveWithExtension(
+ (method, methodProcessingContext) -> {
+ // TODO(b/325199754): Do not tamper with the code API level.
+ if (method.getDefinition().getApiLevelForCode().isNotSetApiLevel()) {
+ method
+ .getDefinition()
+ .setApiLevelForCode(
+ appView.apiLevelCompute().computeInitialMinApiLevel(appView.options()));
+ }
+ IRCode code =
+ method.buildIR(
+ appView,
+ MethodConversionOptions.forLirPhase(appView).disableStringSwitchConversion());
+ inliner.performInlining(
+ method, code, getSimpleFeedback(), methodProcessor, Timing.empty());
+ CodeRewriter.removeAssumeInstructions(appView, code);
+ LirCode<Integer> lirCode =
+ new IRToLirFinalizer(appView)
+ .finalizeCode(code, BytecodeMetadataProvider.empty(), Timing.empty());
+ method.setCode(lirCode, appView);
+ },
+ appView.options().getThreadingModule(),
+ executorService);
+ }
+ }
+
+ private CallSiteInformation createCallSiteInformation(
+ ProgramMethodMap<ProgramMethod> singleCallerMethods) {
+ return new CallSiteInformation() {
+ @Override
+ public boolean hasSingleCallSite(ProgramMethod method, ProgramMethod context) {
+ return singleCallerMethods.containsKey(method);
+ }
+
+ @Override
+ public boolean hasSingleCallSite(ProgramMethod method) {
+ return singleCallerMethods.containsKey(method);
+ }
+
+ @Override
+ public boolean isMultiCallerInlineCandidate(ProgramMethod method) {
+ return false;
+ }
+
+ @Override
+ public void unsetCallSiteInformation(ProgramMethod method) {
+ throw new Unreachable();
+ }
+ };
+ }
+
+ private void pruneItems(
+ ProgramMethodMap<ProgramMethod> singleCallerMethods, ExecutorService executorService)
+ throws ExecutionException {
+ PrunedItems.Builder prunedItemsBuilder = PrunedItems.builder().setPrunedApp(appView.app());
+ singleCallerMethods.forEach(
+ (callee, caller) -> {
+ if (callee.getOptimizationInfo().hasBeenInlinedIntoSingleCallSite()) {
+ prunedItemsBuilder.addFullyInlinedMethod(callee.getReference(), caller);
+ callee.getHolder().removeMethod(callee.getReference());
+ }
+ });
+ PrunedItems prunedItems = prunedItemsBuilder.build();
+ appView.pruneItems(prunedItems, executorService, Timing.empty());
+ appView.appInfo().getMethodAccessInfoCollection().withoutPrunedItems(prunedItems);
+ }
+
+ private static class SingleCallerInlinerImpl extends Inliner {
+
+ private final ProgramMethodMap<ProgramMethod> singleCallerMethods;
+
+ SingleCallerInlinerImpl(
+ AppView<AppInfoWithLiveness> appView, ProgramMethodMap<ProgramMethod> singleCallerMethods) {
+ super(appView);
+ this.singleCallerMethods = singleCallerMethods;
+ }
+
+ @Override
+ public DefaultInliningOracle createDefaultOracle(
+ ProgramMethod method,
+ MethodProcessor methodProcessor,
+ int inliningInstructionAllowance,
+ InliningReasonStrategy inliningReasonStrategy) {
+ return new DefaultInliningOracle(
+ appView, inliningReasonStrategy, method, methodProcessor, inliningInstructionAllowance) {
+
+ @Override
+ public InlineResult computeInlining(
+ IRCode code,
+ InvokeMethod invoke,
+ SingleResolutionResult<?> resolutionResult,
+ ProgramMethod singleTarget,
+ ProgramMethod context,
+ ClassInitializationAnalysis classInitializationAnalysis,
+ InliningIRProvider inliningIRProvider,
+ WhyAreYouNotInliningReporter whyAreYouNotInliningReporter) {
+ if (!singleCallerMethods.containsKey(singleTarget)) {
+ return null;
+ }
+ return super.computeInlining(
+ code,
+ invoke,
+ resolutionResult,
+ singleTarget,
+ context,
+ classInitializationAnalysis,
+ inliningIRProvider,
+ whyAreYouNotInliningReporter);
+ }
+ };
+ }
+
+ @Override
+ public WhyAreYouNotInliningReporter createWhyAreYouNotInliningReporter(
+ ProgramMethod singleTarget, ProgramMethod context) {
+ return NopWhyAreYouNotInliningReporter.getInstance();
+ }
+ }
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/singlecaller/SingleCallerInlinerCallGraph.java b/src/main/java/com/android/tools/r8/optimize/singlecaller/SingleCallerInlinerCallGraph.java
new file mode 100644
index 0000000..03342cf
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/singlecaller/SingleCallerInlinerCallGraph.java
@@ -0,0 +1,171 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.optimize.singlecaller;
+
+import com.android.tools.r8.errors.Unreachable;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.conversion.callgraph.CallGraphBase;
+import com.android.tools.r8.ir.conversion.callgraph.CallGraphBuilderBase;
+import com.android.tools.r8.ir.conversion.callgraph.CycleEliminator;
+import com.android.tools.r8.ir.conversion.callgraph.CycleEliminatorNode;
+import com.android.tools.r8.ir.conversion.callgraph.NodeBase;
+import com.android.tools.r8.optimize.singlecaller.SingleCallerInlinerCallGraph.Node;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.collections.ProgramMethodMap;
+import com.android.tools.r8.utils.collections.ProgramMethodSet;
+import com.google.common.collect.Sets;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeSet;
+
+public class SingleCallerInlinerCallGraph extends CallGraphBase<Node> {
+
+ public SingleCallerInlinerCallGraph(Map<DexMethod, Node> nodes) {
+ super(nodes);
+ }
+
+ public static Builder builder(AppView<AppInfoWithLiveness> appView) {
+ return new Builder(appView);
+ }
+
+ public ProgramMethodSet extractLeaves() {
+ ProgramMethodSet result = ProgramMethodSet.create();
+ Set<Node> removed = Sets.newIdentityHashSet();
+ Iterator<Node> nodeIterator = getNodes().iterator();
+ while (nodeIterator.hasNext()) {
+ Node node = nodeIterator.next();
+ if (node.isLeaf()) {
+ result.add(node.getProgramMethod());
+ nodeIterator.remove();
+ removed.add(node);
+ }
+ }
+ removed.forEach(Node::cleanForRemoval);
+ return result;
+ }
+
+ public static class Builder extends CallGraphBuilderBase<Node> {
+
+ public Builder(AppView<AppInfoWithLiveness> appView) {
+ super(appView);
+ }
+
+ public Builder populateGraph(ProgramMethodMap<ProgramMethod> singleCallerMethods) {
+ singleCallerMethods.forEach(
+ (callee, caller) -> getOrCreateNode(callee).addCaller(getOrCreateNode(caller)));
+ return this;
+ }
+
+ public Builder eliminateCycles() {
+ // Sort the nodes for deterministic cycle elimination.
+ Set<Node> nodesWithDeterministicOrder = Sets.newTreeSet(nodes.values());
+ new CycleEliminator<Node>().breakCycles(nodesWithDeterministicOrder);
+ return this;
+ }
+
+ public SingleCallerInlinerCallGraph build() {
+ return new SingleCallerInlinerCallGraph(nodes);
+ }
+
+ @Override
+ protected Node createNode(ProgramMethod method) {
+ return new Node(method);
+ }
+ }
+
+ public static class Node extends NodeBase<Node>
+ implements Comparable<Node>, CycleEliminatorNode<Node> {
+
+ private Node caller = null;
+ private final Set<Node> callees = new TreeSet<>();
+
+ public Node(ProgramMethod method) {
+ super(method);
+ }
+
+ public void addCaller(Node caller) {
+ assert this.caller == null;
+ if (this == caller) {
+ return;
+ }
+ this.caller = caller;
+ caller.callees.add(this);
+ }
+
+ public void cleanForRemoval() {
+ assert callees.isEmpty();
+ if (caller != null) {
+ caller.callees.remove(this);
+ caller = null;
+ }
+ }
+
+ public boolean isLeaf() {
+ return callees.isEmpty();
+ }
+
+ @Override
+ public void addCallerConcurrently(Node caller, boolean likelySpuriousCallEdge) {
+ throw new Unreachable();
+ }
+
+ @Override
+ public void addReaderConcurrently(Node reader) {
+ throw new Unreachable();
+ }
+
+ @Override
+ public int compareTo(Node node) {
+ return getProgramMethod().getReference().compareTo(node.getProgramMethod().getReference());
+ }
+
+ @Override
+ public Set<Node> getCalleesWithDeterministicOrder() {
+ return callees;
+ }
+
+ @Override
+ public Set<Node> getWritersWithDeterministicOrder() {
+ return Collections.emptySet();
+ }
+
+ @Override
+ public boolean hasCallee(Node callee) {
+ return callees.contains(callee);
+ }
+
+ @Override
+ public boolean hasCaller(Node caller) {
+ return this.caller != null && this.caller == caller;
+ }
+
+ @Override
+ public void removeCaller(Node caller) {
+ assert this.caller != null;
+ assert this.caller == caller;
+ boolean changed = caller.callees.remove(this);
+ assert changed;
+ this.caller = null;
+ }
+
+ @Override
+ public boolean hasReader(Node reader) {
+ return false;
+ }
+
+ @Override
+ public void removeReader(Node reader) {
+ throw new Unreachable();
+ }
+
+ @Override
+ public boolean hasWriter(Node writer) {
+ return false;
+ }
+ }
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/singlecaller/SingleCallerScanner.java b/src/main/java/com/android/tools/r8/optimize/singlecaller/SingleCallerScanner.java
new file mode 100644
index 0000000..c0d5b6a
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/singlecaller/SingleCallerScanner.java
@@ -0,0 +1,208 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.optimize.singlecaller;
+
+import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
+import static com.android.tools.r8.utils.MapUtils.ignoreKey;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexCallSite;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexMethodHandle;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexValue;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.desugar.LambdaDescriptor;
+import com.android.tools.r8.lightir.LirCode;
+import com.android.tools.r8.lightir.LirConstant;
+import com.android.tools.r8.lightir.LirInstructionView;
+import com.android.tools.r8.lightir.LirOpcodes;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.ThreadUtils;
+import com.android.tools.r8.utils.collections.ProgramMethodMap;
+import com.android.tools.r8.utils.collections.ProgramMethodSet;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+
+public class SingleCallerScanner {
+
+ private static final ProgramMethod MULTIPLE_CALLERS = ProgramMethod.createSentinel();
+
+ private final AppView<AppInfoWithLiveness> appView;
+
+ SingleCallerScanner(AppView<AppInfoWithLiveness> appView) {
+ this.appView = appView;
+ }
+
+ public ProgramMethodMap<ProgramMethod> getSingleCallerMethods(ExecutorService executorService)
+ throws ExecutionException {
+ ProgramMethodMap<ProgramMethod> singleCallerMethodCandidates =
+ traceConstantPools(executorService);
+ return traceInstructions(singleCallerMethodCandidates, executorService);
+ }
+
+ private ProgramMethodMap<ProgramMethod> traceConstantPools(ExecutorService executorService)
+ throws ExecutionException {
+ ProgramMethodMap<ProgramMethod> traceResult = ProgramMethodMap.createConcurrent();
+ ThreadUtils.processItems(
+ appView.appInfo().classes(),
+ clazz -> recordCallEdges(clazz, traceResult),
+ appView.options().getThreadingModule(),
+ executorService);
+ ProgramMethodMap<ProgramMethod> singleCallerMethodCandidates =
+ ProgramMethodMap.createConcurrent();
+ traceResult.forEach(
+ (callee, caller) -> {
+ if (callee.getDefinition().hasCode()
+ && caller != MULTIPLE_CALLERS
+ && !callee.isStructurallyEqualTo(caller)) {
+ singleCallerMethodCandidates.put(callee, caller);
+ }
+ });
+ return singleCallerMethodCandidates;
+ }
+
+ private void recordCallEdges(
+ DexProgramClass clazz, ProgramMethodMap<ProgramMethod> singleCallerMethods) {
+ clazz.forEachProgramMethodMatching(
+ method -> method.hasCode() && method.getCode().isLirCode(),
+ method -> recordCallEdges(method, singleCallerMethods));
+ }
+
+ private void recordCallEdges(
+ ProgramMethod method, ProgramMethodMap<ProgramMethod> singleCallerMethods) {
+ ProgramMethodMap<ProgramMethod> threadLocalSingleCallerMethods = ProgramMethodMap.create();
+ LirCode<Integer> code = method.getDefinition().getCode().asLirCode();
+ for (LirConstant constant : code.getConstantPool()) {
+ if (constant instanceof DexCallSite) {
+ traceCallSiteConstant(method, (DexCallSite) constant, threadLocalSingleCallerMethods);
+ } else if (constant instanceof DexMethodHandle) {
+ traceMethodHandleConstant(
+ method, (DexMethodHandle) constant, threadLocalSingleCallerMethods);
+ } else if (constant instanceof DexMethod) {
+ traceMethodConstant(method, (DexMethod) constant, threadLocalSingleCallerMethods);
+ }
+ }
+ threadLocalSingleCallerMethods.forEach(
+ (callee, caller) -> {
+ if (caller == MULTIPLE_CALLERS) {
+ singleCallerMethods.put(callee, MULTIPLE_CALLERS);
+ } else {
+ recordCallEdge(caller, callee, singleCallerMethods);
+ }
+ });
+ }
+
+ private void traceCallSiteConstant(
+ ProgramMethod method,
+ DexCallSite callSite,
+ ProgramMethodMap<ProgramMethod> threadLocalSingleCallerMethods) {
+ LambdaDescriptor descriptor =
+ LambdaDescriptor.tryInfer(callSite, appView, appView.appInfo(), method);
+ if (descriptor != null) {
+ traceMethodHandleConstant(method, descriptor.implHandle, threadLocalSingleCallerMethods);
+ } else {
+ traceMethodHandleConstant(method, callSite.bootstrapMethod, threadLocalSingleCallerMethods);
+ for (DexValue bootstrapArg : callSite.getBootstrapArgs()) {
+ if (bootstrapArg.isDexValueMethodHandle()) {
+ traceMethodHandleConstant(
+ method,
+ bootstrapArg.asDexValueMethodHandle().getValue(),
+ threadLocalSingleCallerMethods);
+ }
+ }
+ }
+ }
+
+ private void traceMethodHandleConstant(
+ ProgramMethod method,
+ DexMethodHandle methodHandle,
+ ProgramMethodMap<ProgramMethod> threadLocalSingleCallerMethods) {
+ if (!methodHandle.isMethodHandle()) {
+ return;
+ }
+ traceMethodConstant(method, methodHandle.asMethod(), threadLocalSingleCallerMethods);
+ }
+
+ private void traceMethodConstant(
+ ProgramMethod method,
+ DexMethod referencedMethod,
+ ProgramMethodMap<ProgramMethod> threadLocalSingleCallerMethods) {
+ if (referencedMethod.getHolderType().isArrayType()) {
+ return;
+ }
+ if (referencedMethod.isInstanceInitializer(appView.dexItemFactory())) {
+ ProgramMethod referencedProgramMethod =
+ appView
+ .appInfo()
+ .unsafeResolveMethodDueToDexFormat(referencedMethod)
+ .getResolvedProgramMethod();
+ if (referencedProgramMethod != null) {
+ recordCallEdge(method, referencedProgramMethod, threadLocalSingleCallerMethods);
+ }
+ } else {
+ DexProgramClass referencedProgramMethodHolder =
+ asProgramClassOrNull(
+ appView
+ .appInfo()
+ .definitionForWithoutExistenceAssert(referencedMethod.getHolderType()));
+ ProgramMethod referencedProgramMethod =
+ referencedMethod.lookupOnProgramClass(referencedProgramMethodHolder);
+ if (referencedProgramMethod != null
+ && referencedProgramMethod.getAccessFlags().isPrivate()
+ && !referencedProgramMethod.getAccessFlags().isStatic()) {
+ recordCallEdge(method, referencedProgramMethod, threadLocalSingleCallerMethods);
+ }
+ }
+ }
+
+ private void recordCallEdge(
+ ProgramMethod caller, ProgramMethod callee, ProgramMethodMap<ProgramMethod> callers) {
+ callers.compute(
+ callee, (ignore, existingCallers) -> existingCallers != null ? MULTIPLE_CALLERS : caller);
+ }
+
+ private ProgramMethodMap<ProgramMethod> traceInstructions(
+ ProgramMethodMap<ProgramMethod> singleCallerMethodCandidates, ExecutorService executorService)
+ throws ExecutionException {
+ ProgramMethodMap<ProgramMethodSet> callersToCallees = ProgramMethodMap.create();
+ singleCallerMethodCandidates.forEach(
+ (callee, caller) ->
+ callersToCallees
+ .computeIfAbsent(caller, ignoreKey(ProgramMethodSet::create))
+ .add(callee));
+ ThreadUtils.processItems(
+ callersToCallees.streamKeys()::forEach,
+ caller -> {
+ ProgramMethodSet callees = callersToCallees.get(caller);
+ LirCode<Integer> code = caller.getDefinition().getCode().asLirCode();
+ ProgramMethodMap<Integer> counters = ProgramMethodMap.create();
+ for (LirInstructionView view : code) {
+ int opcode = view.getOpcode();
+ if (opcode != LirOpcodes.INVOKEDIRECT && opcode != LirOpcodes.INVOKEDIRECT_ITF) {
+ continue;
+ }
+ DexMethod invokedMethod =
+ (DexMethod) code.getConstantItem(view.getNextConstantOperand());
+ ProgramMethod resolvedMethod =
+ appView
+ .appInfo()
+ .resolveMethod(invokedMethod, opcode == LirOpcodes.INVOKEDIRECT_ITF)
+ .getResolvedProgramMethod();
+ if (resolvedMethod != null && callees.contains(resolvedMethod)) {
+ counters.put(resolvedMethod, counters.getOrDefault(resolvedMethod, 0) + 1);
+ }
+ }
+ counters.forEach(
+ (callee, counter) -> {
+ if (counter > 1) {
+ singleCallerMethodCandidates.remove(callee);
+ }
+ });
+ },
+ appView.options().getThreadingModule(),
+ executorService);
+ return singleCallerMethodCandidates;
+ }
+}
diff --git a/src/main/java/com/android/tools/r8/optimize/singlecaller/SingleCallerWaves.java b/src/main/java/com/android/tools/r8/optimize/singlecaller/SingleCallerWaves.java
new file mode 100644
index 0000000..8cddbdb
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/singlecaller/SingleCallerWaves.java
@@ -0,0 +1,31 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.optimize.singlecaller;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.collections.ProgramMethodMap;
+import com.android.tools.r8.utils.collections.ProgramMethodSet;
+import java.util.ArrayDeque;
+import java.util.Deque;
+
+public class SingleCallerWaves {
+
+ public static Deque<ProgramMethodSet> buildWaves(
+ AppView<AppInfoWithLiveness> appView, ProgramMethodMap<ProgramMethod> singleCallerMethods) {
+ SingleCallerInlinerCallGraph callGraph =
+ SingleCallerInlinerCallGraph.builder(appView)
+ .populateGraph(singleCallerMethods)
+ .eliminateCycles()
+ .build();
+ Deque<ProgramMethodSet> waves = new ArrayDeque<>();
+ // Intentionally drop first round of leaves as they are the callees.
+ callGraph.extractLeaves();
+ while (!callGraph.isEmpty()) {
+ waves.addLast(callGraph.extractLeaves());
+ }
+ return waves;
+ }
+}
diff --git a/src/test/java/com/android/tools/r8/lightir/TrivialPhiLirRegression299417534Test.java b/src/test/java/com/android/tools/r8/lightir/TrivialPhiLirRegression299417534Test.java
index 780d31b..6a1d3b1 100644
--- a/src/test/java/com/android/tools/r8/lightir/TrivialPhiLirRegression299417534Test.java
+++ b/src/test/java/com/android/tools/r8/lightir/TrivialPhiLirRegression299417534Test.java
@@ -17,7 +17,6 @@
import com.android.tools.r8.ir.code.Phi;
import com.android.tools.r8.ir.code.Value;
import com.android.tools.r8.ir.conversion.IRToLirFinalizer;
-import com.android.tools.r8.ir.optimize.DeadCodeRemover;
import com.android.tools.r8.utils.BooleanBox;
import com.android.tools.r8.utils.Timing;
import java.util.List;
@@ -76,9 +75,8 @@
// Finalize the IR via LIR and rebuild it again.
Timing timing = Timing.empty();
- DeadCodeRemover deadCodeRemover = new DeadCodeRemover(appView);
BytecodeMetadataProvider metadataProvider = BytecodeMetadataProvider.empty();
- IRToLirFinalizer finalizer = new IRToLirFinalizer(appView, deadCodeRemover);
+ IRToLirFinalizer finalizer = new IRToLirFinalizer(appView);
LirCode<Integer> lirCode =
finalizer.finalizeCode(code, metadataProvider, timing);
code.context().setCode(lirCode, appView);