//===- JumpTableToSwitch.cpp ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;

static cl::opt<unsigned>
    JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden,
                           cl::desc("Only split jump tables with size less or "
                                    "equal than JumpTableSizeThreshold."),
                           cl::init(10));

// TODO: Consider adding a cost model for profitability analysis of this
// transformation. Currently we replace a jump table with a switch if all the
// functions in the jump table are smaller than the provided threshold.
static cl::opt<unsigned> FunctionSizeThreshold(
    "jump-table-to-switch-function-size-threshold", cl::Hidden,
    cl::desc("Only split jump tables containing functions whose sizes are less "
             "or equal than this threshold."),
    cl::init(50));

#define DEBUG_TYPE "jump-table-to-switch"

namespace {
struct JumpTableTy {
  Value *Index;
  SmallVector<Function *, 10> Funcs;
};
} // anonymous namespace

static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
                                                 PointerType *PtrTy) {
  Constant *Ptr = dyn_cast<Constant>(GEP->getPointerOperand());
  if (!Ptr)
    return std::nullopt;

  GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr);
  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
    return std::nullopt;

  Function &F = *GEP->getParent()->getParent();
  const DataLayout &DL = F.getDataLayout();
  const unsigned BitWidth =
      DL.getIndexSizeInBits(GEP->getPointerAddressSpace());
  SmallMapVector<Value *, APInt, 4> VariableOffsets;
  APInt ConstantOffset(BitWidth, 0);
  if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset))
    return std::nullopt;
  if (VariableOffsets.size() != 1)
    return std::nullopt;
  // TODO: consider supporting more general patterns
  if (!ConstantOffset.isZero())
    return std::nullopt;
  APInt StrideBytes = VariableOffsets.front().second;
  const uint64_t JumpTableSizeBytes = DL.getTypeAllocSize(GV->getValueType());
  if (JumpTableSizeBytes % StrideBytes.getZExtValue() != 0)
    return std::nullopt;
  const uint64_t N = JumpTableSizeBytes / StrideBytes.getZExtValue();
  if (N > JumpTableSizeThreshold)
    return std::nullopt;

  JumpTableTy JumpTable;
  JumpTable.Index = VariableOffsets.front().first;
  JumpTable.Funcs.reserve(N);
  for (uint64_t Index = 0; Index < N; ++Index) {
    // ConstantOffset is zero.
    APInt Offset = Index * StrideBytes;
    Constant *C =
        ConstantFoldLoadFromConst(GV->getInitializer(), PtrTy, Offset, DL);
    auto *Func = dyn_cast_or_null<Function>(C);
    if (!Func || Func->isDeclaration() ||
        Func->getInstructionCount() > FunctionSizeThreshold)
      return std::nullopt;
    JumpTable.Funcs.push_back(Func);
  }
  return JumpTable;
}

static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
                                  DomTreeUpdater &DTU,
                                  OptimizationRemarkEmitter &ORE) {
  const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext());

  SmallVector<DominatorTree::UpdateType, 8> DTUpdates;
  BasicBlock *BB = CB->getParent();
  BasicBlock *Tail = SplitBlock(BB, CB, &DTU, nullptr, nullptr,
                                BB->getName() + Twine(".tail"));
  DTUpdates.push_back({DominatorTree::Delete, BB, Tail});
  BB->getTerminator()->eraseFromParent();

  Function &F = *BB->getParent();
  BasicBlock *BBUnreachable = BasicBlock::Create(
      F.getContext(), "default.switch.case.unreachable", &F, Tail);
  IRBuilder<> BuilderUnreachable(BBUnreachable);
  BuilderUnreachable.CreateUnreachable();

  IRBuilder<> Builder(BB);
  SwitchInst *Switch = Builder.CreateSwitch(JT.Index, BBUnreachable);
  DTUpdates.push_back({DominatorTree::Insert, BB, BBUnreachable});

  IRBuilder<> BuilderTail(CB);
  PHINode *PHI =
      IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size());

  for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) {
    BasicBlock *B = BasicBlock::Create(Func->getContext(),
                                       "call." + Twine(Index), &F, Tail);
    DTUpdates.push_back({DominatorTree::Insert, BB, B});
    DTUpdates.push_back({DominatorTree::Insert, B, Tail});

    CallBase *Call = cast<CallBase>(CB->clone());
    Call->setCalledFunction(Func);
    Call->insertInto(B, B->end());
    Switch->addCase(
        cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B);
    BranchInst::Create(Tail, B);
    if (PHI)
      PHI->addIncoming(Call, B);
  }
  DTU.applyUpdates(DTUpdates);
  ORE.emit([&]() {
    return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB)
           << "expanded indirect call into switch";
  });
  if (PHI)
    CB->replaceAllUsesWith(PHI);
  CB->eraseFromParent();
  return Tail;
}

PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
                                             FunctionAnalysisManager &AM) {
  OptimizationRemarkEmitter &ORE =
      AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  DominatorTree *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
  DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
  bool Changed = false;
  for (BasicBlock &BB : make_early_inc_range(F)) {
    BasicBlock *CurrentBB = &BB;
    while (CurrentBB) {
      BasicBlock *SplittedOutTail = nullptr;
      for (Instruction &I : make_early_inc_range(*CurrentBB)) {
        auto *Call = dyn_cast<CallInst>(&I);
        if (!Call || Call->getCalledFunction() || Call->isMustTailCall())
          continue;
        auto *L = dyn_cast<LoadInst>(Call->getCalledOperand());
        // Skip atomic or volatile loads.
        if (!L || !L->isSimple())
          continue;
        auto *GEP = dyn_cast<GetElementPtrInst>(L->getPointerOperand());
        if (!GEP)
          continue;
        auto *PtrTy = dyn_cast<PointerType>(L->getType());
        assert(PtrTy && "call operand must be a pointer");
        std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy);
        if (!JumpTable)
          continue;
        SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE);
        Changed = true;
        break;
      }
      CurrentBB = SplittedOutTail ? SplittedOutTail : nullptr;
    }
  }

  if (!Changed)
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  if (DT)
    PA.preserve<DominatorTreeAnalysis>();
  if (PDT)
    PA.preserve<PostDominatorTreeAnalysis>();
  return PA;
}
