//===- DXILOpBuilder.cpp - Helper class to build DXILOp functions ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file This file contains a class to help build DXIL op functions.
//===----------------------------------------------------------------------===//

#include "DXILOpBuilder.h"
#include "DXILConstants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/DXILABI.h"
#include "llvm/Support/ErrorHandling.h"

using namespace llvm;
using namespace llvm::dxil;

// Prefix shared by the names of all DXIL op functions built in this file; the
// op class name (and, for non-void overloads, a type suffix) is appended to it.
constexpr StringLiteral DXILOpNamePrefix = "dx.op.";

namespace {

/// Flags identifying the types a DXIL op can be overloaded on. Every
/// enumerator occupies its own bit, so a set of kinds fits in one 16-bit mask
/// and membership can be tested with a bitwise AND.
enum OverloadKind : uint16_t {
  // Scalar kinds.
  VOID = 0x0001,
  HALF = 0x0002,
  FLOAT = 0x0004,
  DOUBLE = 0x0008,
  I1 = 0x0010,
  I8 = 0x0020,
  I16 = 0x0040,
  I32 = 0x0080,
  I64 = 0x0100,
  // Non-scalar kinds; these must stay numerically above every scalar kind
  // because callers compare against UserDefineType with '<'.
  UserDefineType = 0x0200,
  ObjectType = 0x0400,
};

} // namespace

/// Return the suffix that spells a scalar overload kind in DXIL names, e.g.
/// "f32" for FLOAT or "i64" for I64.
///
/// Only the floating-point and integer kinds have such a suffix. Passing VOID,
/// UserDefineType, ObjectType, or a value that is not exactly one enumerator
/// is a programming error and reaches llvm_unreachable.
static const char *getOverloadTypeName(OverloadKind Kind) {
  switch (Kind) {
  case OverloadKind::HALF:
    return "f16";
  case OverloadKind::FLOAT:
    return "f32";
  case OverloadKind::DOUBLE:
    return "f64";
  case OverloadKind::I1:
    return "i1";
  case OverloadKind::I8:
    return "i8";
  case OverloadKind::I16:
    return "i16";
  case OverloadKind::I32:
    return "i32";
  case OverloadKind::I64:
    return "i64";
  case OverloadKind::VOID:
  case OverloadKind::ObjectType:
  case OverloadKind::UserDefineType:
    break;
  }
  // llvm_unreachable does not return, so no fallback value follows it. The
  // switch deliberately has no default label so that a newly added enumerator
  // triggers a missing-case warning.
  llvm_unreachable("invalid overload type for name");
}

/// Map an LLVM type to the OverloadKind flag that represents it.
///
/// void, half, float and double map to their namesake kinds; integers of
/// width 1, 8, 16, 32 or 64 map to I1 .. I64; pointer types map to
/// UserDefineType and struct types map to ObjectType. Any other type,
/// including an integer of another width, is a programming error and reaches
/// llvm_unreachable.
static OverloadKind getOverloadKind(Type *Ty) {
  switch (Ty->getTypeID()) {
  case Type::VoidTyID:
    return OverloadKind::VOID;
  case Type::HalfTyID:
    return OverloadKind::HALF;
  case Type::FloatTyID:
    return OverloadKind::FLOAT;
  case Type::DoubleTyID:
    return OverloadKind::DOUBLE;
  case Type::IntegerTyID:
    switch (cast<IntegerType>(Ty)->getBitWidth()) {
    case 1:
      return OverloadKind::I1;
    case 8:
      return OverloadKind::I8;
    case 16:
      return OverloadKind::I16;
    case 32:
      return OverloadKind::I32;
    case 64:
      return OverloadKind::I64;
    default:
      break;
    }
    // Unsupported integer width: fall out to the shared failure below.
    break;
  case Type::PointerTyID:
    return OverloadKind::UserDefineType;
  case Type::StructTyID:
    return OverloadKind::ObjectType;
  default:
    break;
  }
  // llvm_unreachable does not return, so no placeholder value follows it.
  llvm_unreachable("invalid overload type");
}

/// Return the textual name used for overload type \p Ty of kind \p Kind.
///
/// Kinds below UserDefineType are spelled by getOverloadTypeName.
/// UserDefineType and ObjectType use the name of the struct type \p Ty. Any
/// larger value falls back to the printed form of \p Ty.
///
/// NOTE(review): getOverloadKind returns UserDefineType for pointer types, but
/// this function casts \p Ty to StructType for that kind; a pointer type would
/// fail the cast. Confirm which types are expected to reach this path.
static std::string getTypeName(OverloadKind Kind, Type *Ty) {
  if (Kind < OverloadKind::UserDefineType)
    return getOverloadTypeName(Kind);

  const bool IsNamedByStruct = Kind == OverloadKind::UserDefineType ||
                               Kind == OverloadKind::ObjectType;
  if (IsNamedByStruct)
    return cast<StructType>(Ty)->getStructName().str();

  std::string Printed;
  raw_string_ostream Stream(Printed);
  Ty->print(Stream);
  return Stream.str();
}

// Static properties of a single DXIL operation. getOpCodeProperty (generated
// by TableGen and included below) hands out pointers to instances of this
// struct.
// NOTE(review): the generated table presumably initializes these fields
// positionally, so member order and types should not be changed without
// checking the TableGen emitter.
struct OpCodeProperty {
  dxil::OpCode OpCode;
  // Offset in DXILOpCodeNameTable.
  unsigned OpCodeNameOffset;
  dxil::OpCodeClass OpCodeClass;
  // Offset in DXILOpCodeClassNameTable.
  unsigned OpCodeClassNameOffset;
  // Bitwise OR of the OverloadKind values this operation can be overloaded on.
  uint16_t OverloadTys;
  llvm::Attribute::AttrKind FuncAttr;
  int OverloadParamIndex;        // Index of the parameter which controls the
                                 // overload; index 0 denotes the return value.
                                 // When < 0, should be only 1 overload type.
  unsigned NumOfParameters;      // Number of parameters, including the return
                                 // value.
  unsigned ParameterTableOffset; // Offset in ParameterTable.
};

// Include getOpCodeClassName, getOpCodeProperty, getOpCodeName and
// getOpCodeParameterKind, which are generated by TableGen.
#define DXIL_OP_OPERATION_TABLE
#include "DXILOperation.inc"
#undef DXIL_OP_OPERATION_TABLE

/// Build the name of the DXIL op function for \p Prop: the "dx.op." prefix
/// followed by the op class name and, unless \p Kind is VOID, a "." and the
/// name of the overload type \p Ty.
static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
                                         const OpCodeProperty &Prop) {
  std::string Name =
      (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
  if (Kind != OverloadKind::VOID) {
    Name += ".";
    Name += getTypeName(Kind, Ty);
  }
  return Name;
}

/// Append the scalar suffix of \p Kind to \p TypeName. For VOID the name is
/// returned unchanged; any other kind must be a scalar kind (asserted).
static std::string constructOverloadTypeName(OverloadKind Kind,
                                             StringRef TypeName) {
  std::string Result = TypeName.str();
  if (Kind != OverloadKind::VOID) {
    assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");
    Result += getOverloadTypeName(Kind);
  }
  return Result;
}

/// Look up the struct type called \p Name in \p Ctx and return it if present;
/// otherwise create a new struct type with that name and elements \p EltTys.
/// An existing type is returned as-is, without comparing it to \p EltTys.
static StructType *getOrCreateStructType(StringRef Name,
                                         ArrayRef<Type *> EltTys,
                                         LLVMContext &Ctx) {
  if (StructType *Existing = StructType::getTypeByName(Ctx, Name))
    return Existing;
  return StructType::create(Ctx, EltTys, Name);
}

/// Get or create the "dx.types.ResRet.<suffix>" struct type for
/// \p OverloadTy: four elements of the overload type followed by one i32.
static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
  const std::string StructName = constructOverloadTypeName(
      getOverloadKind(OverloadTy), "dx.types.ResRet.");

  SmallVector<Type *, 5> Members(4, OverloadTy);
  Members.push_back(Type::getInt32Ty(Ctx));
  return getOrCreateStructType(StructName, Members, Ctx);
}

/// Get or create the "dx.types.Handle" struct type, whose only element is an
/// unqualified pointer.
static StructType *getHandleType(LLVMContext &Ctx) {
  Type *HandleMembers[] = {PointerType::getUnqual(Ctx)};
  return getOrCreateStructType("dx.types.Handle", HandleMembers, Ctx);
}

/// Translate a DXIL parameter kind into the LLVM type it stands for.
///
/// Scalar kinds map to the matching type in the LLVMContext of \p OverloadTy.
/// Overload yields \p OverloadTy itself, ResourceRet yields the ResRet struct
/// built from \p OverloadTy, and DXILHandle yields the handle struct. Any
/// other kind is a programming error and reaches llvm_unreachable.
///
/// NOTE(review): ParameterKind::CBufferRet is referenced elsewhere in this
/// file but has no mapping here - confirm that this is intended.
static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
  LLVMContext &Ctx = OverloadTy->getContext();
  switch (Kind) {
  case ParameterKind::Void:
    return Type::getVoidTy(Ctx);
  case ParameterKind::Half:
    return Type::getHalfTy(Ctx);
  case ParameterKind::Float:
    return Type::getFloatTy(Ctx);
  case ParameterKind::Double:
    return Type::getDoubleTy(Ctx);
  case ParameterKind::I1:
    return Type::getInt1Ty(Ctx);
  case ParameterKind::I8:
    return Type::getInt8Ty(Ctx);
  case ParameterKind::I16:
    return Type::getInt16Ty(Ctx);
  case ParameterKind::I32:
    return Type::getInt32Ty(Ctx);
  case ParameterKind::I64:
    return Type::getInt64Ty(Ctx);
  case ParameterKind::Overload:
    return OverloadTy;
  case ParameterKind::ResourceRet:
    return getResRetType(OverloadTy, Ctx);
  case ParameterKind::DXILHandle:
    return getHandleType(Ctx);
  default:
    break;
  }
  // llvm_unreachable does not return, so no null fallback follows it.
  llvm_unreachable("Invalid parameter kind");
}

/// Construct DXIL function type. This is the type of a function with
/// the following prototype
///     ReturnTy dx.op.<opclass>.<overload-type>(i32 opcode, <param types>)
/// <param types> are constructed from the parameter kinds recorded for Prop.
/// \param Prop  Structure containing DXIL Operation properties based on
///               its specification in DXIL.td.
/// \param ReturnTy Return type of the constructed function type.
/// \param OverloadTy Type substituted for overload-dependent parameter kinds;
///               its LLVMContext is also used to create the i32 opcode type.
///
/// NOTE(review): getOverloadTy indexes the same parameter-kind table treating
/// entry 0 as the return value, while the loop below starts at entry 0.
/// Confirm against the TableGen emitter that the return kind is not being
/// added as a parameter here.
static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
                                           Type *ReturnTy, Type *OverloadTy) {
  auto ParamKinds = getOpCodeParameterKind(*Prop);

  SmallVector<Type *> ParamTys;
  ParamTys.reserve(Prop->NumOfParameters + 1);

  // The DXIL opcode value, an i32, is always the first argument.
  ParamTys.push_back(Type::getInt32Ty(OverloadTy->getContext()));

  // Append one type per parameter kind specified in the DXIL properties.
  for (unsigned I = 0; I < Prop->NumOfParameters; ++I)
    ParamTys.push_back(getTypeFromParameterKind(ParamKinds[I], OverloadTy));

  return FunctionType::get(ReturnTy, ParamTys, /*isVarArg=*/false);
}

namespace llvm {
namespace dxil {

/// Create, through builder B, a call to the DXIL op function for \p OpCode
/// overloaded on \p OverloadTy, passing \p Args.
///
/// Reports a fatal error when the kind of \p OverloadTy is not among the
/// overload kinds recorded for the operation. A function already present in
/// module M under the computed name is reused as-is; otherwise one is
/// declared with a type built from \p ReturnTy and \p OverloadTy.
CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
                                          Type *OverloadTy,
                                          SmallVector<Value *> Args) {
  const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
  const OverloadKind Kind = getOverloadKind(OverloadTy);

  const bool IsSupportedOverload =
      (Prop->OverloadTys & static_cast<uint16_t>(Kind)) != 0;
  if (!IsSupportedOverload)
    report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false);

  const std::string DXILFnName =
      constructOverloadName(Kind, OverloadTy, *Prop);

  // Prefer a function that already exists under this name; only build the
  // function type when a new declaration has to be inserted.
  Function *Existing = M.getFunction(DXILFnName);
  FunctionCallee Callee =
      Existing ? FunctionCallee(Existing)
               : M.getOrInsertFunction(
                     DXILFnName,
                     getDXILOpFunctionType(Prop, ReturnTy, OverloadTy));

  return B.CreateCall(Callee, Args);
}

/// Determine the overload type of DXIL operation \p OpCode from the function
/// type \p FT.
///
/// When the operation has no overload parameter (index < 0), the type is the
/// scalar type named by the single OverloadKind recorded for the operation;
/// any other recorded value reaches llvm_unreachable. Otherwise the type is
/// taken from the return type (index 0) or from a parameter of \p FT, and for
/// CBufferRet / ResourceRet parameter kinds it is the first element of that
/// struct type.
Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
  const OpCodeProperty *Prop = getOpCodeProperty(OpCode);

  // If DXIL Op has no overload parameter, just return the
  // precise type specified.
  if (Prop->OverloadParamIndex < 0) {
    LLVMContext &Ctx = FT->getContext();
    switch (Prop->OverloadTys) {
    case OverloadKind::VOID:
      return Type::getVoidTy(Ctx);
    case OverloadKind::HALF:
      return Type::getHalfTy(Ctx);
    case OverloadKind::FLOAT:
      return Type::getFloatTy(Ctx);
    case OverloadKind::DOUBLE:
      return Type::getDoubleTy(Ctx);
    case OverloadKind::I1:
      return Type::getInt1Ty(Ctx);
    case OverloadKind::I8:
      return Type::getInt8Ty(Ctx);
    case OverloadKind::I16:
      return Type::getInt16Ty(Ctx);
    case OverloadKind::I32:
      return Type::getInt32Ty(Ctx);
    case OverloadKind::I64:
      return Type::getInt64Ty(Ctx);
    default:
      break;
    }
    // llvm_unreachable does not return, so no null fallback follows it.
    llvm_unreachable("invalid overload type");
  }

  // Index 0 designates the return type; index N > 0 designates parameter
  // N - 1 of FT, because FT's parameter list does not include the return type.
  Type *OverloadType = Prop->OverloadParamIndex == 0
                           ? FT->getReturnType()
                           : FT->getParamType(Prop->OverloadParamIndex - 1);

  auto ParamKinds = getOpCodeParameterKind(*Prop);
  auto Kind = ParamKinds[Prop->OverloadParamIndex];
  // For ResRet and CBufferRet, OverloadTy is in field of StructType.
  if (Kind == ParameterKind::CBufferRet || Kind == ParameterKind::ResourceRet)
    return cast<StructType>(OverloadType)->getElementType(0);

  return OverloadType;
}

/// Return the name of DXIL operation \p DXILOp by forwarding to the
/// file-level getOpCodeName generated by TableGen.
const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) {
  // The leading '::' selects the global-namespace function; an unqualified
  // call here would resolve to this member function instead.
  return ::getOpCodeName(DXILOp);
}
} // namespace dxil
} // namespace llvm
