//===- X86RegisterBankInfo.cpp -----------------------------------*- C++ -*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file implements the targeting of the RegisterBankInfo class for X86.
/// \todo This should be generated by TableGen.
//===----------------------------------------------------------------------===//

#include "X86RegisterBankInfo.h"
#include "X86InstrInfo.h"
#include "X86Subtarget.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/RegisterBank.h"
#include "llvm/CodeGen/RegisterBankInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/IR/IntrinsicsX86.h"

#define GET_TARGET_REGBANK_IMPL
#include "X86GenRegisterBank.inc"

using namespace llvm;
// This file will be TableGen'ed at some point.
#define GET_TARGET_REGBANK_INFO_IMPL
#include "X86GenRegisterBankInfo.def"

/// Constructs the X86 register bank info.
///
/// The body performs no setup of its own; it only sanity-checks the GPR
/// register bank through assertions (which compile away in no-assert builds).
///
/// \param TRI Target register info, consulted to look up the GR64 class.
X86RegisterBankInfo::X86RegisterBankInfo(const TargetRegisterInfo &TRI) {
  // Looking the bank up by ID must hand back the very object declared in the
  // X86 namespace.
  const RegisterBank &GPRBank = getRegBank(X86::GPRRegBankID);
  (void)GPRBank;
  assert(&GPRBank == &X86::GPRRegBank && "Incorrect RegBanks inizalization.");

  // GR64 (and therefore everything below it) has to be covered by the GPR
  // bank, and that bank must report a maximum size of 64 bits.
  assert(GPRBank.covers(*TRI.getRegClass(X86::GR64RegClassID)) &&
         "Subclass not added?");
  assert(getMaximumSize(GPRBank.getID()) == 64 &&
         "GPRs should hold up to 64-bit");
}

/// Returns the register bank that holds registers of class \p RC.
///
/// \p RC is matched against a fixed list of register classes per bank; it
/// belongs to a bank when it equals, or is a subclass of, one of the listed
/// classes. The LLT parameter is ignored.
///
/// Hits llvm_unreachable when \p RC matches none of the listed classes.
const RegisterBank &
X86RegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
                                            LLT) const {
  // Classes checked for the GPR bank.
  static const TargetRegisterClass *const GPRClasses[] = {
      &X86::GR8RegClass,
      &X86::GR16RegClass,
      &X86::GR32RegClass,
      &X86::GR64RegClass,
      &X86::LOW32_ADDR_ACCESSRegClass,
      &X86::LOW32_ADDR_ACCESS_RBPRegClass};

  // Classes checked for the VECR bank.
  static const TargetRegisterClass *const VecClasses[] = {
      &X86::FR32XRegClass, &X86::FR64XRegClass, &X86::VR128XRegClass,
      &X86::VR256XRegClass, &X86::VR512RegClass};

  // Classes checked for the PSR bank.
  static const TargetRegisterClass *const PSRClasses[] = {
      &X86::RFP80RegClass, &X86::RFP32RegClass, &X86::RFP64RegClass};

  // True when RC is Super itself or one of its subclasses.
  const auto CoversRC = [&RC](const TargetRegisterClass *Super) {
    return Super->hasSubClassEq(&RC);
  };

  if (any_of(GPRClasses, CoversRC))
    return getRegBank(X86::GPRRegBankID);

  if (any_of(VecClasses, CoversRC))
    return getRegBank(X86::VECRRegBankID);

  if (any_of(PSRClasses, CoversRC))
    return getRegBank(X86::PSRRegBankID);

  llvm_unreachable("Unsupported register kind yet.");
}

/// \returns true if the intrinsic called by \p MI only uses and defines FPRs.
///
/// \p MI is cast to GIntrinsic, so callers must only pass intrinsic
/// instructions. \p MRI is not consulted at the moment.
static bool isFPIntrinsic(const MachineRegisterInfo &MRI,
                          const MachineInstr &MI) {
  // TODO: Add more intrinsics.
  switch (cast<GIntrinsic>(MI).getIntrinsicID()) {
  // SSE1
  case Intrinsic::x86_sse_rcp_ss:
  case Intrinsic::x86_sse_rcp_ps:
  case Intrinsic::x86_sse_rsqrt_ss:
  case Intrinsic::x86_sse_rsqrt_ps:
  case Intrinsic::x86_sse_min_ss:
  case Intrinsic::x86_sse_min_ps:
  case Intrinsic::x86_sse_max_ss:
  case Intrinsic::x86_sse_max_ps:
    return true;
  default:
    // Every intrinsic not listed above is treated as not FP-only.
    return false;
  }
}

bool X86RegisterBankInfo::hasFPConstraints(const MachineInstr &MI,
                                           const MachineRegisterInfo &MRI,
                                           const TargetRegisterInfo &TRI,
                                           unsigned Depth) const {
  unsigned Op = MI.getOpcode();
  if (Op == TargetOpcode::G_INTRINSIC && isFPIntrinsic(MRI, MI))
    return true;

  // Do we have an explicit floating point instruction?
  if (isPreISelGenericFloatingPointOpcode(Op))
    return true;

  // No. Check if we have a copy-like instruction. If we do, then we could
  // still be fed by floating point instructions.
  if (Op != TargetOpcode::COPY && !MI.isPHI() &&
      !isPreISelGenericOptimizationHint(Op))
    return false;

  // Check if we already know the register bank.
  auto *RB = getRegBank(MI.getOperand(0).getReg(), MRI, TRI);
  if (RB == &getRegBank(X86::PSRRegBankID))
    return true;
  if (RB == &getRegBank(X86::GPRRegBankID))
    return false;

  // We don't know anything.
  //
  // If we have a phi, we may be able to infer that it will be assigned a fp
  // type based off of its inputs.
  if (!MI.isPHI() || Depth > MaxFPRSearchDepth)
    return false;

  return any_of(MI.explicit_uses(), [&](const MachineOperand &Op) {
    return Op.isReg() &&
           onlyDefinesFP(*MRI.getVRegDef(Op.getReg()), MRI, TRI, Depth + 1);
  });
}

/// \returns true if \p MI only consumes floating-point values.
///
/// The opcodes listed below are answered directly; every other instruction is
/// classified by hasFPConstraints, to which \p Depth is forwarded.
bool X86RegisterBankInfo::onlyUsesFP(const MachineInstr &MI,
                                     const MachineRegisterInfo &MRI,
                                     const TargetRegisterInfo &TRI,
                                     unsigned Depth) const {
  switch (MI.getOpcode()) {
  default:
    // Not one of the special opcodes: fall back to the generic FP check.
    return hasFPConstraints(MI, MRI, TRI, Depth);
  case TargetOpcode::G_FPTOSI:
  case TargetOpcode::G_FPTOUI:
  case TargetOpcode::G_FCMP:
  case TargetOpcode::G_LROUND:
  case TargetOpcode::G_LLROUND:
  case TargetOpcode::G_INTRINSIC_TRUNC:
  case TargetOpcode::G_INTRINSIC_ROUND:
    return true;
  }
}

/// \returns true if \p MI only produces floating-point values.
///
/// G_SITOFP and G_UITOFP are answered directly; every other instruction is
/// classified by hasFPConstraints, to which \p Depth is forwarded.
bool X86RegisterBankInfo::onlyDefinesFP(const MachineInstr &MI,
                                        const MachineRegisterInfo &MRI,
                                        const TargetRegisterInfo &TRI,
                                        unsigned Depth) const {
  const unsigned Opc = MI.getOpcode();
  const bool IsIntToFPConversion =
      Opc == TargetOpcode::G_SITOFP || Opc == TargetOpcode::G_UITOFP;
  return IsIntToFPConversion || hasFPConstraints(MI, MRI, TRI, Depth);
}

/// Selects the partial-mapping index for a value of type \p Ty used by \p MI.
///
/// \param MI   Instruction the value belongs to. It is only used to reach the
///             subtarget when a 32- or 64-bit floating-point scalar needs an
///             SSE feature check.
/// \param Ty   Low-level type of the value.
/// \param isFP Whether a scalar should be treated as floating point. 80-bit
///             types are always treated as floating point.
/// \returns the index matching the type's kind and size. Sizes that are not
///          handled for the given kind hit llvm_unreachable.
X86GenRegisterBankInfo::PartialMappingIdx
X86GenRegisterBankInfo::getPartialMappingIdx(const MachineInstr &MI,
                                             const LLT &Ty, bool isFP) {
  const unsigned SizeInBits = Ty.getSizeInBits();

  // 80 bits is only generated for X87 floating points.
  const bool TreatAsFP = isFP || SizeInBits == 80;

  // Pointers and non-FP scalars.
  if (Ty.isPointer() || (Ty.isScalar() && !TreatAsFP)) {
    switch (SizeInBits) {
    case 1:
    case 8:
      return PMI_GPR8;
    case 16:
      return PMI_GPR16;
    case 32:
      return PMI_GPR32;
    case 64:
      return PMI_GPR64;
    case 128:
      return PMI_VEC128;
    default:
      llvm_unreachable("Unsupported register size.");
    }
  }

  // Floating-point scalars: 32/64-bit values use the FP indices when the
  // matching SSE level is available and the PSR indices otherwise.
  if (Ty.isScalar()) {
    const X86Subtarget &ST = MI.getMF()->getSubtarget<X86Subtarget>();
    switch (SizeInBits) {
    case 32:
      return ST.hasSSE1() ? PMI_FP32 : PMI_PSR32;
    case 64:
      return ST.hasSSE2() ? PMI_FP64 : PMI_PSR64;
    case 80:
      return PMI_PSR80;
    case 128:
      return PMI_VEC128;
    default:
      llvm_unreachable("Unsupported register size.");
    }
  }

  // Everything else (neither pointer nor scalar) is picked by total size.
  switch (SizeInBits) {
  case 128:
    return PMI_VEC128;
  case 256:
    return PMI_VEC256;
  case 512:
    return PMI_VEC512;
  default:
    llvm_unreachable("Unsupported register size.");
  }
}

void X86RegisterBankInfo::getInstrPartialMappingIdxs(
    const MachineInstr &MI, const MachineRegisterInfo &MRI, const bool isFP,
    SmallVectorImpl<PartialMappingIdx> &OpRegBankIdx) {

  unsigned NumOperands = MI.getNumOperands();
  for (unsigned Idx = 0; Idx < NumOperands; ++Idx) {
    auto &MO = MI.getOperand(Idx);
    if (!MO.isReg() || !MO.getReg())
      OpRegBankIdx[Idx] = PMI_None;
    else
      OpRegBankIdx[Idx] =
          getPartialMappingIdx(MI, MRI.getType(MO.getReg()), isFP);
  }
}

bool X86RegisterBankInfo::getInstrValueMapping(
    const MachineInstr &MI,
    const SmallVectorImpl<PartialMappingIdx> &OpRegBankIdx,
    SmallVectorImpl<const ValueMapping *> &OpdsMapping) {

  unsigned NumOperands = MI.getNumOperands();
  for (unsigned Idx = 0; Idx < NumOperands; ++Idx) {
    if (!MI.getOperand(Idx).isReg())
      continue;
    if (!MI.getOperand(Idx).getReg())
      continue;

    auto Mapping = getValueMapping(OpRegBankIdx[Idx], 1);
    if (!Mapping->isValid())
      return false;

    OpdsMapping[Idx] = Mapping;
  }
  return true;
}

/// Builds the mapping for an instruction whose three operands all share one
/// type, placing every operand on the same partial mapping.
///
/// \param MI   Instruction to map. It must have exactly three operands and the
///             types of operands 1 and 2 must equal the type of operand 0;
///             anything else hits llvm_unreachable.
/// \param isFP Whether the shared type should be treated as floating point.
const RegisterBankInfo::InstructionMapping &
X86RegisterBankInfo::getSameOperandsMapping(const MachineInstr &MI,
                                            bool isFP) const {
  const MachineRegisterInfo &MRI = MI.getParent()->getParent()->getRegInfo();

  const unsigned NumOperands = MI.getNumOperands();
  const LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  const bool IsSupported = NumOperands == 3 &&
                           DstTy == MRI.getType(MI.getOperand(1).getReg()) &&
                           DstTy == MRI.getType(MI.getOperand(2).getReg());
  if (!IsSupported)
    llvm_unreachable("Unsupported operand mapping yet.");

  const PartialMappingIdx SharedIdx = getPartialMappingIdx(MI, DstTy, isFP);
  const ValueMapping *SharedMapping = getValueMapping(SharedIdx, 3);
  return getInstructionMapping(DefaultMappingID, /*Cost=*/1, SharedMapping,
                               NumOperands);
}

/// Computes the default register bank mapping for \p MI.
///
/// The work happens in three stages:
///  1. Non-generic instructions and G_PHI first try getInstrMappingImpl; a
///     valid result is returned as is.
///  2. Integer/FP arithmetic and shifts return a single shared mapping for
///     all of their operands.
///  3. Every other opcode gets a per-operand partial-mapping index
///     (OpRegBankIdx), which is then turned into value mappings.
///
/// \returns the computed mapping, or the invalid instruction mapping when a
///          value mapping for some operand turns out to be invalid.
const RegisterBankInfo::InstructionMapping &
X86RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
  const MachineFunction &MF = *MI.getParent()->getParent();
  const TargetSubtargetInfo &STI = MF.getSubtarget();
  const TargetRegisterInfo &TRI = *STI.getRegisterInfo();
  const MachineRegisterInfo &MRI = MF.getRegInfo();
  unsigned Opc = MI.getOpcode();

  // Try the default logic for non-generic instructions that are either
  // copies or already have some operands assigned to banks.
  if (!isPreISelGenericOpcode(Opc) || Opc == TargetOpcode::G_PHI) {
    const InstructionMapping &Mapping = getInstrMappingImpl(MI);
    if (Mapping.isValid())
      return Mapping;
    // An invalid result is not an error: fall through to the logic below.
  }

  // Opcodes whose operands all share one mapping.
  switch (Opc) {
  case TargetOpcode::G_ADD:
  case TargetOpcode::G_SUB:
  case TargetOpcode::G_MUL:
    // Integer arithmetic: scalars are treated as non-FP.
    return getSameOperandsMapping(MI, false);
  case TargetOpcode::G_FADD:
  case TargetOpcode::G_FSUB:
  case TargetOpcode::G_FMUL:
  case TargetOpcode::G_FDIV:
    // Floating-point arithmetic: scalars are treated as FP.
    return getSameOperandsMapping(MI, true);
  case TargetOpcode::G_SHL:
  case TargetOpcode::G_LSHR:
  case TargetOpcode::G_ASHR: {
    // Shifts take their mapping from the type of operand 0 alone. Unlike
    // getSameOperandsMapping, the operand count and the types of the other
    // operands are not checked here.
    unsigned NumOperands = MI.getNumOperands();
    LLT Ty = MRI.getType(MI.getOperand(0).getReg());

    auto Mapping = getValueMapping(getPartialMappingIdx(MI, Ty, false), 3);
    return getInstructionMapping(DefaultMappingID, 1, Mapping, NumOperands);
  }
  default:
    break;
  }

  // One partial-mapping index per operand, filled in by the switch below.
  unsigned NumOperands = MI.getNumOperands();
  SmallVector<PartialMappingIdx, 4> OpRegBankIdx(NumOperands);

  switch (Opc) {
  case TargetOpcode::G_FPEXT:
  case TargetOpcode::G_FPTRUNC:
  case TargetOpcode::G_FCONSTANT:
    // Instruction having only floating-point operands (all scalars in
    // VECRReg)
    getInstrPartialMappingIdxs(MI, MRI, /* isFP= */ true, OpRegBankIdx);
    break;
  case TargetOpcode::G_SITOFP:
  case TargetOpcode::G_FPTOSI: {
    // Some of the floating-point instructions have mixed GPR and FP
    // operands: fine-tune the computed mapping.
    auto &Op0 = MI.getOperand(0);
    auto &Op1 = MI.getOperand(1);
    const LLT Ty0 = MRI.getType(Op0.getReg());
    const LLT Ty1 = MRI.getType(Op1.getReg());

    // G_SITOFP: operand 0 is FP, operand 1 is not.
    // G_FPTOSI: operand 1 is FP, operand 0 is not.
    bool FirstArgIsFP = Opc == TargetOpcode::G_SITOFP;
    bool SecondArgIsFP = Opc == TargetOpcode::G_FPTOSI;
    OpRegBankIdx[0] = getPartialMappingIdx(MI, Ty0, /* isFP= */ FirstArgIsFP);
    OpRegBankIdx[1] = getPartialMappingIdx(MI, Ty1, /* isFP= */ SecondArgIsFP);
    break;
  }
  case TargetOpcode::G_FCMP: {
    // Operands 2 and 3 are the compared values; their sizes must match and
    // be 32 or 64 bits (checked only in asserts-enabled builds).
    LLT Ty1 = MRI.getType(MI.getOperand(2).getReg());
    LLT Ty2 = MRI.getType(MI.getOperand(3).getReg());
    (void)Ty2;
    assert(Ty1.getSizeInBits() == Ty2.getSizeInBits() &&
           "Mismatched operand sizes for G_FCMP");

    unsigned Size = Ty1.getSizeInBits();
    (void)Size;
    assert((Size == 32 || Size == 64) && "Unsupported size for G_FCMP");

    // Operand 0 is mapped to PMI_GPR8, the predicate operand gets no mapping,
    // and both compared values share the FP index.
    auto FpRegBank = getPartialMappingIdx(MI, Ty1, /* isFP= */ true);
    OpRegBankIdx = {PMI_GPR8,
                    /* Predicate */ PMI_None, FpRegBank, FpRegBank};
    break;
  }
  case TargetOpcode::G_TRUNC:
  case TargetOpcode::G_ANYEXT: {
    auto &Op0 = MI.getOperand(0);
    auto &Op1 = MI.getOperand(1);
    const LLT Ty0 = MRI.getType(Op0.getReg());
    const LLT Ty1 = MRI.getType(Op1.getReg());

    // A G_TRUNC from 128 bits down to 32/64 bits, or a G_ANYEXT from 32/64
    // bits up to 128 bits, is mapped as floating point; every other
    // truncation or extension is mapped as non-FP.
    bool isFPTrunc = (Ty0.getSizeInBits() == 32 || Ty0.getSizeInBits() == 64) &&
                     Ty1.getSizeInBits() == 128 && Opc == TargetOpcode::G_TRUNC;
    bool isFPAnyExt =
        Ty0.getSizeInBits() == 128 &&
        (Ty1.getSizeInBits() == 32 || Ty1.getSizeInBits() == 64) &&
        Opc == TargetOpcode::G_ANYEXT;

    getInstrPartialMappingIdxs(MI, MRI, /* isFP= */ isFPTrunc || isFPAnyExt,
                               OpRegBankIdx);
    break;
  }
  case TargetOpcode::G_LOAD: {
    // Check if that load feeds fp instructions.
    // In that case, we want the default mapping to be on FPR
    // instead of blind map every scalar to GPR.
    bool IsFP = any_of(MRI.use_nodbg_instructions(cast<GLoad>(MI).getDstReg()),
                       [&](const MachineInstr &UseMI) {
                         // If we have at least one direct use in a FP
                         // instruction, assume this was a floating point load
                         // in the IR. If it was not, we would have had a
                         // bitcast before reaching that instruction.
                         return onlyUsesFP(UseMI, MRI, TRI);
                       });
    getInstrPartialMappingIdxs(MI, MRI, IsFP, OpRegBankIdx);
    break;
  }
  case TargetOpcode::G_STORE: {
    // Check if that store is fed by fp instructions.
    Register VReg = cast<GStore>(MI).getValueReg();
    // Without a value register there is no definition to inspect;
    // OpRegBankIdx keeps the contents it was constructed with above.
    if (!VReg)
      break;
    // The stored value is FP when the instruction defining it only defines
    // FP values.
    MachineInstr *DefMI = MRI.getVRegDef(VReg);
    bool IsFP = onlyDefinesFP(*DefMI, MRI, TRI);
    getInstrPartialMappingIdxs(MI, MRI, IsFP, OpRegBankIdx);
    break;
  }
  default:
    // Track the bank of each register, use NotFP mapping (all scalars in
    // GPRs)
    getInstrPartialMappingIdxs(MI, MRI, /* isFP= */ false, OpRegBankIdx);
    break;
  }

  // Finally construct the computed mapping.
  SmallVector<const ValueMapping *, 8> OpdsMapping(NumOperands);
  if (!getInstrValueMapping(MI, OpRegBankIdx, OpdsMapping))
    return getInvalidInstructionMapping();

  return getInstructionMapping(DefaultMappingID, /* Cost */ 1,
                               getOperandsMapping(OpdsMapping), NumOperands);
}

/// Applies the mapping described by \p OpdMapper by forwarding it to
/// applyDefaultMapping. \p Builder is not used.
void X86RegisterBankInfo::applyMappingImpl(
    MachineIRBuilder &Builder, const OperandsMapper &OpdMapper) const {
  applyDefaultMapping(OpdMapper);
}

/// Returns alternative mappings for \p MI.
///
/// For G_LOAD, G_STORE and G_IMPLICIT_DEF whose operand 0 is 32, 64 or 80
/// bits wide, a single alternative (ID 1, cost 1) is offered in which every
/// register operand is classified as floating point. All other instructions,
/// and the cases where that alternative cannot be built, defer to the generic
/// RegisterBankInfo implementation.
RegisterBankInfo::InstructionMappings
X86RegisterBankInfo::getInstrAlternativeMappings(const MachineInstr &MI) const {
  const MachineFunction &MF = *MI.getParent()->getParent();
  const TargetSubtargetInfo &STI = MF.getSubtarget();
  const TargetRegisterInfo &TRI = *STI.getRegisterInfo();
  const MachineRegisterInfo &MRI = MF.getRegInfo();

  const unsigned Opc = MI.getOpcode();
  const bool IsCandidate = Opc == TargetOpcode::G_LOAD ||
                           Opc == TargetOpcode::G_STORE ||
                           Opc == TargetOpcode::G_IMPLICIT_DEF;
  if (!IsCandidate)
    return RegisterBankInfo::getInstrAlternativeMappings(MI);

  // Only 32/64/80-bit values are candidates for the floating-point
  // alternative (the indices getPartialMappingIdx returns for FP scalars of
  // those sizes, e.g. PMI_FP32/PMI_FP64/PMI_PSR80).
  const unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
  const bool IsFPSize = Size == 32 || Size == 64 || Size == 80;
  if (!IsFPSize)
    return RegisterBankInfo::getInstrAlternativeMappings(MI);

  const unsigned NumOperands = MI.getNumOperands();

  // Classify every register operand as floating point.
  SmallVector<PartialMappingIdx, 4> OpRegBankIdx(NumOperands);
  getInstrPartialMappingIdxs(MI, MRI, /* isFP= */ true, OpRegBankIdx);

  // Turn the indices into value mappings; give up on the alternative if any
  // of them is invalid.
  SmallVector<const ValueMapping *, 8> OpdsMapping(NumOperands);
  if (!getInstrValueMapping(MI, OpRegBankIdx, OpdsMapping))
    return RegisterBankInfo::getInstrAlternativeMappings(MI);

  InstructionMappings AltMappings;
  AltMappings.push_back(&getInstructionMapping(
      /*ID*/ 1, /*Cost*/ 1, getOperandsMapping(OpdsMapping), NumOperands));
  return AltMappings;
}
