blob: f71e82f545586f1fe6baa8e5329a132ecb1128ef [file] [log] [blame]
// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
// for details. All rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.
package com.android.tools.r8.optimize.singlecaller;
import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
import static com.android.tools.r8.utils.MapUtils.ignoreKey;
import com.android.tools.r8.graph.AppView;
import com.android.tools.r8.graph.DexCallSite;
import com.android.tools.r8.graph.DexMethod;
import com.android.tools.r8.graph.DexMethodHandle;
import com.android.tools.r8.graph.DexProgramClass;
import com.android.tools.r8.graph.DexValue;
import com.android.tools.r8.graph.ProgramMethod;
import com.android.tools.r8.ir.desugar.LambdaDescriptor;
import com.android.tools.r8.lightir.LirCode;
import com.android.tools.r8.lightir.LirConstant;
import com.android.tools.r8.lightir.LirInstructionView;
import com.android.tools.r8.lightir.LirOpcodes;
import com.android.tools.r8.shaking.AppInfoWithLiveness;
import com.android.tools.r8.utils.ObjectUtils;
import com.android.tools.r8.utils.ThreadUtils;
import com.android.tools.r8.utils.collections.ProgramMethodMap;
import com.android.tools.r8.utils.collections.ProgramMethodSet;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
/**
 * Computes a map from program methods (callees) to the unique program method that calls them.
 *
 * <p>The scan runs in two phases:
 *
 * <ol>
 *   <li>The LIR constant pool of every method with LIR code is scanned. References to instance
 *       initializers and to private non-static methods become callee-to-caller edges. Method
 *       constants, method handles and call sites all count as references. A callee referenced
 *       more than once is mapped to {@link #MULTIPLE_CALLERS}. Candidates are the callees that
 *       have code, have exactly one caller, and are not their own caller.
 *   <li>The instructions of each candidate's caller are scanned. Candidates not targeted by
 *       exactly one invoke-direct, invoke-direct-itf or invoke-virtual instruction are dropped.
 * </ol>
 */
public class SingleCallerScanner {

  /** Sentinel caller value for callees that have been referenced more than once. */
  private static final ProgramMethod MULTIPLE_CALLERS = ProgramMethod.createSentinel();

  private final AppView<AppInfoWithLiveness> appView;

  SingleCallerScanner(AppView<AppInfoWithLiveness> appView) {
    this.appView = appView;
  }

  /**
   * Returns the methods that have a single caller, mapped to that caller.
   *
   * @param executorService executor used to process classes and callers in parallel
   * @return a concurrent map from callee to its single caller
   * @throws ExecutionException if a parallel task fails
   */
  public ProgramMethodMap<ProgramMethod> getSingleCallerMethods(ExecutorService executorService)
      throws ExecutionException {
    return traceInstructions(traceConstantPools(executorService), executorService);
  }

  /**
   * Phase one: builds the candidate map from the constant pools of all program classes.
   *
   * <p>The returned map is concurrent because {@link #traceInstructions} removes entries from it
   * in parallel.
   */
  private ProgramMethodMap<ProgramMethod> traceConstantPools(ExecutorService executorService)
      throws ExecutionException {
    ProgramMethodMap<ProgramMethod> calleeToCaller = ProgramMethodMap.createConcurrent();
    ThreadUtils.processItems(
        appView.appInfo().classes(),
        programClass -> recordCallEdges(programClass, calleeToCaller),
        appView.options().getThreadingModule(),
        executorService);

    ProgramMethodMap<ProgramMethod> candidates = ProgramMethodMap.createConcurrent();
    calleeToCaller.forEach(
        (callee, caller) -> {
          if (!callee.getDefinition().hasCode()) {
            return;
          }
          // Must be checked before the structural comparison below, which is therefore never
          // evaluated against the sentinel.
          if (ObjectUtils.identical(caller, MULTIPLE_CALLERS)) {
            return;
          }
          // Self-recursive methods are not candidates.
          if (callee.isStructurallyEqualTo(caller)) {
            return;
          }
          candidates.put(callee, caller);
        });
    return candidates;
  }

  /** Records the call edges of every method of {@code clazz} whose code is LIR. */
  private void recordCallEdges(
      DexProgramClass clazz, ProgramMethodMap<ProgramMethod> singleCallerMethods) {
    clazz.forEachProgramMethodMatching(
        definition -> definition.hasCode() && definition.getCode().isLirCode(),
        programMethod -> recordCallEdges(programMethod, singleCallerMethods));
  }

  /**
   * Records the call edges originating from the constant pool of {@code method}.
   *
   * <p>Edges are first gathered in a method-local map and then merged into the shared map.
   * A callee referenced more than once within this method is already {@link #MULTIPLE_CALLERS}
   * locally, and is stored as such.
   */
  private void recordCallEdges(
      ProgramMethod method, ProgramMethodMap<ProgramMethod> singleCallerMethods) {
    ProgramMethodMap<ProgramMethod> localEdges = ProgramMethodMap.create();
    LirCode<Integer> lirCode = method.getDefinition().getCode().asLirCode();
    for (LirConstant constant : lirCode.getConstantPool()) {
      if (constant instanceof DexCallSite) {
        traceCallSiteConstant(method, (DexCallSite) constant, localEdges);
        continue;
      }
      if (constant instanceof DexMethodHandle) {
        traceMethodHandleConstant(method, (DexMethodHandle) constant, localEdges);
        continue;
      }
      if (constant instanceof DexMethod) {
        traceMethodConstant(method, (DexMethod) constant, localEdges);
      }
    }
    localEdges.forEach(
        (callee, caller) -> {
          if (ObjectUtils.notIdentical(caller, MULTIPLE_CALLERS)) {
            recordCallEdge(caller, callee, singleCallerMethods);
          } else {
            singleCallerMethods.put(callee, MULTIPLE_CALLERS);
          }
        });
  }

  /**
   * Traces the methods referenced by {@code callSite}.
   *
   * <p>If the call site is inferred as a lambda, only its implementation handle is traced.
   * Otherwise both the bootstrap method and every method-handle bootstrap argument are traced.
   */
  private void traceCallSiteConstant(
      ProgramMethod method,
      DexCallSite callSite,
      ProgramMethodMap<ProgramMethod> threadLocalSingleCallerMethods) {
    LambdaDescriptor descriptor =
        LambdaDescriptor.tryInfer(callSite, appView, appView.appInfo(), method);
    if (descriptor != null) {
      traceMethodHandleConstant(method, descriptor.implHandle, threadLocalSingleCallerMethods);
      return;
    }
    traceMethodHandleConstant(method, callSite.bootstrapMethod, threadLocalSingleCallerMethods);
    for (DexValue bootstrapArg : callSite.getBootstrapArgs()) {
      if (!bootstrapArg.isDexValueMethodHandle()) {
        continue;
      }
      DexMethodHandle argumentHandle = bootstrapArg.asDexValueMethodHandle().getValue();
      traceMethodHandleConstant(method, argumentHandle, threadLocalSingleCallerMethods);
    }
  }

  /** Traces the method behind {@code methodHandle}; handles that are not method handles are ignored. */
  private void traceMethodHandleConstant(
      ProgramMethod method,
      DexMethodHandle methodHandle,
      ProgramMethodMap<ProgramMethod> threadLocalSingleCallerMethods) {
    if (methodHandle.isMethodHandle()) {
      traceMethodConstant(method, methodHandle.asMethod(), threadLocalSingleCallerMethods);
    }
  }

  /**
   * Records an edge from {@code method} to {@code referencedMethod} when the reference is either
   * an instance initializer resolving to a program method, or a private non-static program
   * method. References whose holder is an array type are ignored.
   */
  private void traceMethodConstant(
      ProgramMethod method,
      DexMethod referencedMethod,
      ProgramMethodMap<ProgramMethod> threadLocalSingleCallerMethods) {
    if (referencedMethod.getHolderType().isArrayType()) {
      return;
    }
    ProgramMethod target;
    if (referencedMethod.isInstanceInitializer(appView.dexItemFactory())) {
      target =
          appView
              .appInfo()
              .unsafeResolveMethodDueToDexFormat(referencedMethod)
              .getResolvedProgramMethod();
    } else {
      target = lookupPrivateInstanceMethod(referencedMethod);
    }
    if (target != null) {
      recordCallEdge(method, target, threadLocalSingleCallerMethods);
    }
  }

  /**
   * Looks up {@code reference} directly on its program holder class.
   *
   * @return the definition if it exists and is private and non-static, otherwise null
   */
  private ProgramMethod lookupPrivateInstanceMethod(DexMethod reference) {
    DexProgramClass holder =
        asProgramClassOrNull(
            appView.appInfo().definitionForWithoutExistenceAssert(reference.getHolderType()));
    ProgramMethod definition = reference.lookupOnProgramClass(holder);
    if (definition == null) {
      return null;
    }
    if (!definition.getAccessFlags().isPrivate() || definition.getAccessFlags().isStatic()) {
      return null;
    }
    return definition;
  }

  /**
   * Maps {@code callee} to {@code caller} on the first edge. Any further edge maps it to
   * {@link #MULTIPLE_CALLERS}.
   */
  private void recordCallEdge(
      ProgramMethod caller, ProgramMethod callee, ProgramMethodMap<ProgramMethod> callers) {
    callers.compute(
        callee, (key, previousCaller) -> previousCaller == null ? caller : MULTIPLE_CALLERS);
  }

  /**
   * Phase two: removes from {@code singleCallerMethodCandidates} every callee that is not the
   * target of exactly one tracked invoke instruction in its caller's code.
   *
   * @return {@code singleCallerMethodCandidates}, after pruning in place
   */
  private ProgramMethodMap<ProgramMethod> traceInstructions(
      ProgramMethodMap<ProgramMethod> singleCallerMethodCandidates, ExecutorService executorService)
      throws ExecutionException {
    // Group callees by caller so that each caller's code is scanned only once.
    ProgramMethodMap<ProgramMethodSet> calleesByCaller = ProgramMethodMap.create();
    singleCallerMethodCandidates.forEach(
        (callee, caller) -> {
          ProgramMethodSet calleesOfCaller =
              calleesByCaller.computeIfAbsent(caller, ignoreKey(ProgramMethodSet::create));
          calleesOfCaller.add(callee);
        });
    ThreadUtils.processItems(
        calleesByCaller.streamKeys()::forEach,
        caller -> {
          ProgramMethodSet callees = calleesByCaller.get(caller);
          ProgramMethodMap<Integer> invokeCounts = countInvokes(caller, callees);
          callees.forEach(
              callee -> {
                // Counts are at least 1 when present, so this keeps exactly-one-invoke callees.
                if (invokeCounts.getOrDefault(callee, 0) != 1) {
                  singleCallerMethodCandidates.remove(callee);
                }
              });
        },
        appView.options().getThreadingModule(),
        executorService);
    return singleCallerMethodCandidates;
  }

  /**
   * Counts how many invoke-direct, invoke-direct-itf and invoke-virtual instructions in the LIR
   * code of {@code caller} resolve to each method of {@code callees}.
   *
   * <p>Callees that are never invoked are absent from the result.
   */
  private ProgramMethodMap<Integer> countInvokes(ProgramMethod caller, ProgramMethodSet callees) {
    LirCode<Integer> lirCode = caller.getDefinition().getCode().asLirCode();
    ProgramMethodMap<Integer> invokeCounts = ProgramMethodMap.create();
    for (LirInstructionView instruction : lirCode) {
      int opcode = instruction.getOpcode();
      boolean isInterfaceInvoke = opcode == LirOpcodes.INVOKEDIRECT_ITF;
      boolean isTrackedInvoke =
          opcode == LirOpcodes.INVOKEDIRECT
              || isInterfaceInvoke
              // JDK 17 generates invokevirtual to private methods.
              || opcode == LirOpcodes.INVOKEVIRTUAL;
      if (!isTrackedInvoke) {
        continue;
      }
      DexMethod invokedMethod =
          (DexMethod) lirCode.getConstantItem(instruction.getNextConstantOperand());
      ProgramMethod resolvedMethod =
          appView.appInfo().resolveMethod(invokedMethod, isInterfaceInvoke).getResolvedProgramMethod();
      if (resolvedMethod != null && callees.contains(resolvedMethod)) {
        invokeCounts.put(resolvedMethod, invokeCounts.getOrDefault(resolvedMethod, 0) + 1);
      }
    }
    return invokeCounts;
  }
}