// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
// for details. All rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.
package com.android.tools.r8.optimize.singlecaller;

import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
import static com.android.tools.r8.utils.MapUtils.ignoreKey;

import com.android.tools.r8.graph.AppView;
import com.android.tools.r8.graph.DexCallSite;
import com.android.tools.r8.graph.DexMethod;
import com.android.tools.r8.graph.DexMethodHandle;
import com.android.tools.r8.graph.DexProgramClass;
import com.android.tools.r8.graph.DexValue;
import com.android.tools.r8.graph.ProgramMethod;
import com.android.tools.r8.ir.desugar.LambdaDescriptor;
import com.android.tools.r8.lightir.LirCode;
import com.android.tools.r8.lightir.LirConstant;
import com.android.tools.r8.lightir.LirInstructionView;
import com.android.tools.r8.lightir.LirOpcodes;
import com.android.tools.r8.shaking.AppInfoWithLiveness;
import com.android.tools.r8.utils.ObjectUtils;
import com.android.tools.r8.utils.ThreadUtils;
import com.android.tools.r8.utils.collections.ProgramMethodMap;
import com.android.tools.r8.utils.collections.ProgramMethodSet;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;

public class SingleCallerScanner {

  private static final ProgramMethod MULTIPLE_CALLERS = ProgramMethod.createSentinel();

  private final AppView<AppInfoWithLiveness> appView;

  SingleCallerScanner(AppView<AppInfoWithLiveness> appView) {
    this.appView = appView;
  }

  public ProgramMethodMap<ProgramMethod> getSingleCallerMethods(ExecutorService executorService)
      throws ExecutionException {
    ProgramMethodMap<ProgramMethod> singleCallerMethodCandidates =
        traceConstantPools(executorService);
    return traceInstructions(singleCallerMethodCandidates, executorService);
  }

  private ProgramMethodMap<ProgramMethod> traceConstantPools(ExecutorService executorService)
      throws ExecutionException {
    ProgramMethodMap<ProgramMethod> traceResult = ProgramMethodMap.createConcurrent();
    ThreadUtils.processItems(
        appView.appInfo().classes(),
        clazz -> recordCallEdges(clazz, traceResult),
        appView.options().getThreadingModule(),
        executorService);
    ProgramMethodMap<ProgramMethod> singleCallerMethodCandidates =
        ProgramMethodMap.createConcurrent();
    traceResult.forEach(
        (callee, caller) -> {
          if (callee.getDefinition().hasCode()
              && ObjectUtils.notIdentical(caller, MULTIPLE_CALLERS)
              && !callee.isStructurallyEqualTo(caller)) {
            singleCallerMethodCandidates.put(callee, caller);
          }
        });
    return singleCallerMethodCandidates;
  }

  private void recordCallEdges(
      DexProgramClass clazz, ProgramMethodMap<ProgramMethod> singleCallerMethods) {
    clazz.forEachProgramMethodMatching(
        method -> method.hasCode() && method.getCode().isLirCode(),
        method -> recordCallEdges(method, singleCallerMethods));
  }

  private void recordCallEdges(
      ProgramMethod method, ProgramMethodMap<ProgramMethod> singleCallerMethods) {
    ProgramMethodMap<ProgramMethod> threadLocalSingleCallerMethods = ProgramMethodMap.create();
    LirCode<Integer> code = method.getDefinition().getCode().asLirCode();
    for (LirConstant constant : code.getConstantPool()) {
      if (constant instanceof DexCallSite) {
        traceCallSiteConstant(method, (DexCallSite) constant, threadLocalSingleCallerMethods);
      } else if (constant instanceof DexMethodHandle) {
        traceMethodHandleConstant(
            method, (DexMethodHandle) constant, threadLocalSingleCallerMethods);
      } else if (constant instanceof DexMethod) {
        traceMethodConstant(method, (DexMethod) constant, threadLocalSingleCallerMethods);
      }
    }
    threadLocalSingleCallerMethods.forEach(
        (callee, caller) -> {
          if (ObjectUtils.identical(caller, MULTIPLE_CALLERS)) {
            singleCallerMethods.put(callee, MULTIPLE_CALLERS);
          } else {
            recordCallEdge(caller, callee, singleCallerMethods);
          }
        });
  }

  private void traceCallSiteConstant(
      ProgramMethod method,
      DexCallSite callSite,
      ProgramMethodMap<ProgramMethod> threadLocalSingleCallerMethods) {
    LambdaDescriptor descriptor =
        LambdaDescriptor.tryInfer(callSite, appView, appView.appInfo(), method);
    if (descriptor != null) {
      traceMethodHandleConstant(method, descriptor.implHandle, threadLocalSingleCallerMethods);
    } else {
      traceMethodHandleConstant(method, callSite.bootstrapMethod, threadLocalSingleCallerMethods);
      for (DexValue bootstrapArg : callSite.getBootstrapArgs()) {
        if (bootstrapArg.isDexValueMethodHandle()) {
          traceMethodHandleConstant(
              method,
              bootstrapArg.asDexValueMethodHandle().getValue(),
              threadLocalSingleCallerMethods);
        }
      }
    }
  }

  private void traceMethodHandleConstant(
      ProgramMethod method,
      DexMethodHandle methodHandle,
      ProgramMethodMap<ProgramMethod> threadLocalSingleCallerMethods) {
    if (!methodHandle.isMethodHandle()) {
      return;
    }
    traceMethodConstant(method, methodHandle.asMethod(), threadLocalSingleCallerMethods);
  }

  private void traceMethodConstant(
      ProgramMethod method,
      DexMethod referencedMethod,
      ProgramMethodMap<ProgramMethod> threadLocalSingleCallerMethods) {
    if (referencedMethod.getHolderType().isArrayType()) {
      return;
    }
    if (referencedMethod.isInstanceInitializer(appView.dexItemFactory())) {
      ProgramMethod referencedProgramMethod =
          appView
              .appInfo()
              .unsafeResolveMethodDueToDexFormat(referencedMethod)
              .getResolvedProgramMethod();
      if (referencedProgramMethod != null) {
        recordCallEdge(method, referencedProgramMethod, threadLocalSingleCallerMethods);
      }
    } else {
      DexProgramClass referencedProgramMethodHolder =
          asProgramClassOrNull(
              appView
                  .appInfo()
                  .definitionForWithoutExistenceAssert(referencedMethod.getHolderType()));
      ProgramMethod referencedProgramMethod =
          referencedMethod.lookupOnProgramClass(referencedProgramMethodHolder);
      if (referencedProgramMethod != null
          && referencedProgramMethod.getAccessFlags().isPrivate()
          && !referencedProgramMethod.getAccessFlags().isStatic()) {
        recordCallEdge(method, referencedProgramMethod, threadLocalSingleCallerMethods);
      }
    }
  }

  private void recordCallEdge(
      ProgramMethod caller, ProgramMethod callee, ProgramMethodMap<ProgramMethod> callers) {
    callers.compute(
        callee, (ignore, existingCallers) -> existingCallers != null ? MULTIPLE_CALLERS : caller);
  }

  private ProgramMethodMap<ProgramMethod> traceInstructions(
      ProgramMethodMap<ProgramMethod> singleCallerMethodCandidates, ExecutorService executorService)
      throws ExecutionException {
    ProgramMethodMap<ProgramMethodSet> callersToCallees = ProgramMethodMap.create();
    singleCallerMethodCandidates.forEach(
        (callee, caller) ->
            callersToCallees
                .computeIfAbsent(caller, ignoreKey(ProgramMethodSet::create))
                .add(callee));
    ThreadUtils.processItems(
        callersToCallees.streamKeys()::forEach,
        caller -> {
          ProgramMethodSet callees = callersToCallees.get(caller);
          LirCode<Integer> code = caller.getDefinition().getCode().asLirCode();
          ProgramMethodMap<Integer> counters = ProgramMethodMap.create();
          for (LirInstructionView view : code) {
            int opcode = view.getOpcode();
            if (opcode != LirOpcodes.INVOKEDIRECT
                && opcode != LirOpcodes.INVOKEDIRECT_ITF
                // JDK 17 generates invokevirtual to private methods.
                && opcode != LirOpcodes.INVOKEVIRTUAL) {
              continue;
            }
            DexMethod invokedMethod =
                (DexMethod) code.getConstantItem(view.getNextConstantOperand());
            ProgramMethod resolvedMethod =
                appView
                    .appInfo()
                    .resolveMethod(invokedMethod, opcode == LirOpcodes.INVOKEDIRECT_ITF)
                    .getResolvedProgramMethod();
            if (resolvedMethod != null && callees.contains(resolvedMethod)) {
              counters.put(resolvedMethod, counters.getOrDefault(resolvedMethod, 0) + 1);
            }
          }
          callees.forEach(
              (callee) -> {
                if (!counters.containsKey(callee) || counters.get(callee) > 1) {
                  singleCallerMethodCandidates.remove(callee);
                }
              });
        },
        appView.options().getThreadingModule(),
        executorService);
    return singleCallerMethodCandidates;
  }
}
