//===- AMDGPUDisassembler.hpp - Disassembler for AMDGPU ISA -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
///
/// This file contains the declarations of the AMDGPU ISA disassembler.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_AMDGPU_DISASSEMBLER_AMDGPUDISASSEMBLER_H
#define LLVM_LIB_TARGET_AMDGPU_DISASSEMBLER_AMDGPUDISASSEMBLER_H

#include "SIDefines.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/MC/MCDisassembler/MCDisassembler.h"
#include "llvm/MC/MCInst.h"
#include "llvm/MC/MCInstrInfo.h"
#include "llvm/Support/DataExtractor.h"
#include <memory>

namespace llvm {

class MCAsmInfo;
class MCInst;
class MCOperand;
class MCSubtargetInfo;
class Twine;

// Exposes an interface expected by autogenerated code in
// FixedLenDecoderEmitter: a 128-bit value stored as two 64-bit halves that
// supports bit-field insertion/extraction, bitwise AND/NOT, comparison and
// printing.
class DecoderUInt128 {
private:
  uint64_t Lo = 0; ///< Bits [63:0].
  uint64_t Hi = 0; ///< Bits [127:64].

public:
  DecoderUInt128() = default;
  /// Build a value from its low and (optional) high 64-bit halves.
  /// NOTE(review): left implicit; generated decoder code presumably relies on
  /// integers converting implicitly — confirm before making it explicit.
  DecoderUInt128(uint64_t Lo, uint64_t Hi = 0) : Lo(Lo), Hi(Hi) {}

  /// True if any of the 128 bits is set.
  operator bool() const { return Lo || Hi; }

  /// OR the low \p NumBits bits of \p SubBits into this value, starting at
  /// bit \p BitPosition. \p SubBits must have no bits set at or above
  /// \p NumBits, and the field must lie entirely within the 128-bit value.
  void insertBits(uint64_t SubBits, unsigned BitPosition, unsigned NumBits) {
    assert(NumBits && NumBits <= 64);
    // Written as two shifts so NumBits == 64 does not shift by 64 (UB).
    assert(SubBits >> 1 >> (NumBits - 1) == 0);
    assert(BitPosition < 128);
    // Without this, bits that would land above bit 127 are silently dropped.
    assert(BitPosition + NumBits <= 128);
    if (BitPosition < 64) {
      Lo |= SubBits << BitPosition;
      // Split shift avoids an undefined shift by 64 when BitPosition == 0.
      Hi |= SubBits >> 1 >> (63 - BitPosition);
    } else {
      Hi |= SubBits << (BitPosition - 64);
    }
  }

  /// Return the \p NumBits bits starting at \p BitPosition, zero-extended to
  /// 64 bits. Bits beyond bit 127 read as zero.
  uint64_t extractBitsAsZExtValue(unsigned NumBits,
                                  unsigned BitPosition) const {
    assert(NumBits && NumBits <= 64);
    assert(BitPosition < 128);
    uint64_t Val;
    if (BitPosition < 64)
      // Split shift avoids an undefined shift by 64 when BitPosition == 0.
      Val = Lo >> BitPosition | Hi << 1 << (63 - BitPosition);
    else
      Val = Hi >> (BitPosition - 64);
    // For NumBits == 64, 2 << 63 wraps to 0 and the mask becomes all ones.
    return Val & ((uint64_t(2) << (NumBits - 1)) - 1);
  }

  /// Bitwise AND of both halves.
  DecoderUInt128 operator&(const DecoderUInt128 &RHS) const {
    return DecoderUInt128(Lo & RHS.Lo, Hi & RHS.Hi);
  }
  /// Bitwise AND with a 64-bit value (its high half is treated as zero).
  DecoderUInt128 operator&(const uint64_t &RHS) const {
    return *this & DecoderUInt128(RHS);
  }
  /// Bitwise complement of all 128 bits.
  DecoderUInt128 operator~() const { return DecoderUInt128(~Lo, ~Hi); }

  // Comparisons are const so they work on const objects/references and do
  // not provoke C++20 reversed-operator ambiguity diagnostics.
  bool operator==(const DecoderUInt128 &RHS) const {
    return Lo == RHS.Lo && Hi == RHS.Hi;
  }
  bool operator!=(const DecoderUInt128 &RHS) const {
    return Lo != RHS.Lo || Hi != RHS.Hi;
  }
  /// Compare against an int, which is converted to uint64_t and placed in the
  /// low half (high half zero).
  bool operator!=(const int &RHS) const {
    return *this != DecoderUInt128(RHS);
  }

  /// Print the value as a 128-bit APInt (low word first in the word list).
  friend raw_ostream &operator<<(raw_ostream &OS, const DecoderUInt128 &RHS) {
    return OS << APInt(128, {RHS.Lo, RHS.Hi});
  }
};

//===----------------------------------------------------------------------===//
// AMDGPUDisassembler
//===----------------------------------------------------------------------===//

class AMDGPUDisassembler : public MCDisassembler {
private:
  std::unique_ptr<MCInstrInfo const> const MCII;
  const MCRegisterInfo &MRI;
  const MCAsmInfo &MAI;
  const unsigned TargetMaxInstBytes;
  // Mutable decode state, updated from const methods. tryDecodeInst restores
  // Bytes to its entry value when an attempt fails; presumably this is the
  // window of input bytes not yet consumed (e.g. by literal decoding).
  mutable ArrayRef<uint8_t> Bytes;
  // Literal-operand state; HasLiteral is reset at the start of every
  // tryDecodeInst attempt.
  mutable uint32_t Literal;
  mutable uint64_t Literal64;
  mutable bool HasLiteral;
  mutable std::optional<bool> EnableWavefrontSize32;
  unsigned CodeObjectVersion;
  const MCExpr *UCVersionW64Expr;
  const MCExpr *UCVersionW32Expr;
  const MCExpr *UCVersionMDPExpr;

  const MCExpr *createConstantSymbolExpr(StringRef Id, int64_t Val);

public:
  AMDGPUDisassembler(const MCSubtargetInfo &STI, MCContext &Ctx,
                     MCInstrInfo const *MCII);
  ~AMDGPUDisassembler() override = default;

  void setABIVersion(unsigned Version) override;

  DecodeStatus getInstruction(MCInst &MI, uint64_t &Size,
                              ArrayRef<uint8_t> Bytes, uint64_t Address,
                              raw_ostream &CS) const override;

  const char* getRegClassName(unsigned RegClassID) const;

  MCOperand createRegOperand(unsigned int RegId) const;
  MCOperand createRegOperand(unsigned RegClassID, unsigned Val) const;
  MCOperand createSRegOperand(unsigned SRegClassID, unsigned Val) const;
  MCOperand createVGPR16Operand(unsigned RegIdx, bool IsHi) const;

  MCOperand errOperand(unsigned V, const Twine& ErrMsg) const;

  /// Attempt to decode \p Inst with the generated decoder table \p Table.
  ///
  /// Decoding happens into a scratch MCInst with comments buffered locally,
  /// so \p MI and \p Comments are written only on success. On failure Bytes
  /// is restored to its value on entry. In both cases CommentStream is
  /// restored to whatever it held on entry instead of being cleared.
  ///
  /// \returns MCDisassembler::Success for any non-Fail decoder status
  /// (SoftFail included), MCDisassembler::Fail otherwise.
  template <typename InsnType>
  DecodeStatus tryDecodeInst(const uint8_t *Table, MCInst &MI, InsnType Inst,
                             uint64_t Address, raw_ostream &Comments) const {
    assert(MI.getOpcode() == 0);
    assert(MI.getNumOperands() == 0);
    MCInst TmpInst;
    HasLiteral = false;
    const auto SavedBytes = Bytes;

    // Buffer decoder comments so a failed attempt does not leak them into
    // Comments. Remember the installed stream so it can be put back rather
    // than clobbered with nullptr.
    raw_ostream *const SavedCommentStream = CommentStream;
    SmallString<64> LocalComments;
    raw_svector_ostream LocalCommentStream(LocalComments);
    CommentStream = &LocalCommentStream;

    DecodeStatus Res =
        decodeInstruction(Table, TmpInst, Inst, Address, this, STI);

    // LocalCommentStream dies with this scope: never leave a pointer to it
    // behind, and hand back the caller's stream (if any).
    CommentStream = SavedCommentStream;

    if (Res != Fail) {
      MI = TmpInst;
      Comments << LocalComments;
      return MCDisassembler::Success;
    }
    Bytes = SavedBytes;
    return MCDisassembler::Fail;
  }

  /// Try \p Table1 first, then \p Table2. Returns the first non-Fail result,
  /// or MCDisassembler::Fail if neither table decodes \p Inst.
  template <typename InsnType>
  DecodeStatus tryDecodeInst(const uint8_t *Table1, const uint8_t *Table2,
                             MCInst &MI, InsnType Inst, uint64_t Address,
                             raw_ostream &Comments) const {
    for (const uint8_t *T : {Table1, Table2}) {
      if (DecodeStatus Res = tryDecodeInst(T, MI, Inst, Address, Comments))
        return Res;
    }
    return MCDisassembler::Fail;
  }

  Expected<bool> onSymbolStart(SymbolInfoTy &Symbol, uint64_t &Size,
                               ArrayRef<uint8_t> Bytes,
                               uint64_t Address) const override;

  Expected<bool> decodeKernelDescriptor(StringRef KdName,
                                        ArrayRef<uint8_t> Bytes,
                                        uint64_t KdAddress) const;

  Expected<bool>
  decodeKernelDescriptorDirective(DataExtractor::Cursor &Cursor,
                                  ArrayRef<uint8_t> Bytes,
                                  raw_string_ostream &KdStream) const;

  /// Decode as directives that handle COMPUTE_PGM_RSRC1.
  /// \param FourByteBuffer - Bytes holding contents of COMPUTE_PGM_RSRC1.
  /// \param KdStream       - Stream to write the disassembled directives to.
  // NOLINTNEXTLINE(readability-identifier-naming)
  Expected<bool> decodeCOMPUTE_PGM_RSRC1(uint32_t FourByteBuffer,
                                         raw_string_ostream &KdStream) const;

  /// Decode as directives that handle COMPUTE_PGM_RSRC2.
  /// \param FourByteBuffer - Bytes holding contents of COMPUTE_PGM_RSRC2.
  /// \param KdStream       - Stream to write the disassembled directives to.
  // NOLINTNEXTLINE(readability-identifier-naming)
  Expected<bool> decodeCOMPUTE_PGM_RSRC2(uint32_t FourByteBuffer,
                                         raw_string_ostream &KdStream) const;

  /// Decode as directives that handle COMPUTE_PGM_RSRC3.
  /// \param FourByteBuffer - Bytes holding contents of COMPUTE_PGM_RSRC3.
  /// \param KdStream       - Stream to write the disassembled directives to.
  // NOLINTNEXTLINE(readability-identifier-naming)
  Expected<bool> decodeCOMPUTE_PGM_RSRC3(uint32_t FourByteBuffer,
                                         raw_string_ostream &KdStream) const;

  // In-place post-processing of a decoded MCInst; the specific fixups are
  // implemented out of line (NOTE(review): see the .cpp for details).
  void convertEXPInst(MCInst &MI) const;
  void convertVINTERPInst(MCInst &MI) const;
  void convertFMAanyK(MCInst &MI, int ImmLitIdx) const;
  void convertSDWAInst(MCInst &MI) const;
  void convertDPP8Inst(MCInst &MI) const;
  void convertMIMGInst(MCInst &MI) const;
  void convertVOP3DPPInst(MCInst &MI) const;
  void convertVOP3PDPPInst(MCInst &MI) const;
  void convertVOPCDPPInst(MCInst &MI) const;
  void convertMacDPPInst(MCInst &MI) const;
  void convertTrue16OpSel(MCInst &MI) const;

  /// Operand widths passed to the register-class and source-operand decoders.
  /// OPW_LAST_ is a sentinel one past the last real width; OPW_FIRST_ aliases
  /// the first enumerator. Enumerator order must not change.
  enum OpWidthTy {
    OPW32,
    OPW64,
    OPW96,
    OPW128,
    OPW160,
    OPW256,
    OPW288,
    OPW320,
    OPW352,
    OPW384,
    OPW512,
    OPW1024,
    OPW16,
    OPWV216,
    OPWV232,
    OPW_LAST_,
    OPW_FIRST_ = OPW32
  };

  unsigned getVgprClassId(const OpWidthTy Width) const;
  unsigned getAgprClassId(const OpWidthTy Width) const;
  unsigned getSgprClassId(const OpWidthTy Width) const;
  unsigned getTtmpClassId(const OpWidthTy Width) const;

  static MCOperand decodeIntImmed(unsigned Imm);
  static MCOperand decodeFPImmed(unsigned ImmWidth, unsigned Imm,
                                 AMDGPU::OperandSemantics Sema);

  MCOperand decodeMandatoryLiteralConstant(unsigned Imm) const;
  MCOperand decodeLiteralConstant(bool ExtendFP64) const;

  MCOperand decodeSrcOp(
      const OpWidthTy Width, unsigned Val, bool MandatoryLiteral = false,
      unsigned ImmWidth = 0,
      AMDGPU::OperandSemantics Sema = AMDGPU::OperandSemantics::INT) const;

  MCOperand decodeNonVGPRSrcOp(
      const OpWidthTy Width, unsigned Val, bool MandatoryLiteral = false,
      unsigned ImmWidth = 0,
      AMDGPU::OperandSemantics Sema = AMDGPU::OperandSemantics::INT) const;

  MCOperand decodeVOPDDstYOp(MCInst &Inst, unsigned Val) const;
  MCOperand decodeSpecialReg32(unsigned Val) const;
  MCOperand decodeSpecialReg64(unsigned Val) const;

  MCOperand decodeSDWASrc(const OpWidthTy Width, unsigned Val,
                          unsigned ImmWidth,
                          AMDGPU::OperandSemantics Sema) const;
  MCOperand decodeSDWASrc16(unsigned Val) const;
  MCOperand decodeSDWASrc32(unsigned Val) const;
  MCOperand decodeSDWAVopcDst(unsigned Val) const;

  MCOperand decodeBoolReg(unsigned Val) const;
  MCOperand decodeSplitBarrier(unsigned Val) const;
  MCOperand decodeDpp8FI(unsigned Val) const;

  MCOperand decodeVersionImm(unsigned Imm) const;

  int getTTmpIdx(unsigned Val) const;

  /// Non-owning access to the instruction info owned by this disassembler.
  const MCInstrInfo *getMCII() const { return MCII.get(); }

  bool isVI() const;
  bool isGFX9() const;
  bool isGFX90A() const;
  bool isGFX9Plus() const;
  bool isGFX10() const;
  bool isGFX10Plus() const;
  bool isGFX11() const;
  bool isGFX11Plus() const;
  bool isGFX12() const;
  bool isGFX12Plus() const;

  bool hasArchitectedFlatScratch() const;
  bool hasKernargPreload() const;

  bool isMacDPP(MCInst &MI) const;
};

//===----------------------------------------------------------------------===//
// AMDGPUSymbolizer
//===----------------------------------------------------------------------===//

class AMDGPUSymbolizer : public MCSymbolizer {
private:
  // Opaque context pointer supplied at construction and stored unchanged.
  void *DisInfo;
  // Backing storage for getReferencedAddresses(); presumably filled by the
  // out-of-line tryAddingSymbolicOperand implementation — confirm in .cpp.
  std::vector<uint64_t> ReferencedAddresses;

public:
  /// Create a symbolizer for \p Ctx that takes ownership of \p RelInfo and
  /// carries the opaque \p Info pointer.
  AMDGPUSymbolizer(MCContext &Ctx, std::unique_ptr<MCRelocationInfo> &&RelInfo,
                   void *Info)
      : MCSymbolizer(Ctx, std::move(RelInfo)), DisInfo(Info) {}

  bool tryAddingSymbolicOperand(MCInst &Inst, raw_ostream &CStream,
                                int64_t Value, uint64_t Address, bool IsBranch,
                                uint64_t Offset, uint64_t OpSize,
                                uint64_t InstSize) override;

  void tryAddingPcLoadReferenceComment(raw_ostream &CStream, int64_t Value,
                                       uint64_t Address) override;

  /// Non-owning view of the addresses held in ReferencedAddresses.
  ArrayRef<uint64_t> getReferencedAddresses() const override {
    return ArrayRef<uint64_t>(ReferencedAddresses);
  }
};

} // end namespace llvm

#endif // LLVM_LIB_TARGET_AMDGPU_DISASSEMBLER_AMDGPUDISASSEMBLER_H
