//===- CIRDialect.cpp - MLIR CIR ops implementation -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the CIR dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "clang/CIR/Dialect/IR/CIRDialect.h"

#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"

#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionImplementation.h"

#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
#include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
#include "clang/CIR/MissingFeatures.h"
#include "llvm/Support/LogicalResult.h"

#include <numeric>

using namespace mlir;
using namespace cir;

//===----------------------------------------------------------------------===//
// CIR Dialect
//===----------------------------------------------------------------------===//
namespace {
/// Dialect-level printer hooks that let the MLIR printer emit short aliases
/// for commonly used CIR types and attributes.
struct CIROpAsmDialectInterface : public OpAsmDialectInterface {
  using OpAsmDialectInterface::OpAsmDialectInterface;

  /// Alias record types (by name, or by kind when anonymous), integer types
  /// whose width is a power of two and at least 8, and the void type.
  AliasResult getAlias(Type type, raw_ostream &os) const final {
    if (auto record = dyn_cast<cir::RecordType>(type)) {
      if (StringAttr recordName = record.getName())
        os << "rec_" << recordName.getValue();
      else
        os << "rec_anon_" << record.getKindAsStr();
      return AliasResult::OverridableAlias;
    }

    if (auto integer = dyn_cast<cir::IntType>(type)) {
      // Only "standard" widths (powers of two, >= 8) get an alias.
      const unsigned bits = integer.getWidth();
      if (bits >= 8 && llvm::isPowerOf2_32(bits)) {
        os << integer.getAlias();
        return AliasResult::OverridableAlias;
      }
      return AliasResult::NoAlias;
    }

    if (auto voidTy = dyn_cast<cir::VoidType>(type)) {
      os << voidTy.getAlias();
      return AliasResult::OverridableAlias;
    }

    return AliasResult::NoAlias;
  }

  /// Alias boolean constants as `true`/`false` and bitfield descriptors as
  /// `bfi_<name>`.
  AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
    if (auto boolean = mlir::dyn_cast<cir::BoolAttr>(attr)) {
      if (boolean.getValue())
        os << "true";
      else
        os << "false";
      return AliasResult::FinalAlias;
    }

    if (auto bitfieldInfo = mlir::dyn_cast<cir::BitfieldInfoAttr>(attr)) {
      os << "bfi_" << bitfieldInfo.getName().str();
      return AliasResult::FinalAlias;
    }

    return AliasResult::NoAlias;
  }
};
} // namespace

/// Register the CIR dialect's types, attributes and operations, and attach the
/// assembly-printer interface that supplies CIR type/attribute aliases.
void cir::CIRDialect::initialize() {
  registerTypes();
  registerAttributes();
  // The ODS-generated operation list expands directly into the template
  // argument list of addOperations.
  addOperations<
#define GET_OP_LIST
#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
      >();
  addInterfaces<CIROpAsmDialectInterface>();
}

/// Materialize a `cir.const` holding `value` with result type `type`, e.g. for
/// a folded result. Returns nullptr when `value` is not a TypedAttr (and thus
/// cannot be the value of a `cir.const`), as the dialect hook allows, instead
/// of asserting inside `cast`.
Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
                                                mlir::Attribute value,
                                                mlir::Type type,
                                                mlir::Location loc) {
  auto typedValue = mlir::dyn_cast<mlir::TypedAttr>(value);
  if (!typedValue)
    return nullptr;
  return builder.create<cir::ConstantOp>(loc, type, typedValue);
}

//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//

// Try each entry of `keywords`, in order, as an optional keyword. Returns the
// position of the first keyword that is present, or -1 when none of them is.
static int parseOptionalKeywordAlternative(AsmParser &parser,
                                           ArrayRef<llvm::StringRef> keywords) {
  for (size_t idx = 0, count = keywords.size(); idx < count; ++idx)
    if (succeeded(parser.parseOptionalKeyword(keywords[idx])))
      return static_cast<int>(idx);
  return -1;
}

namespace {
/// Maps a CIR enum type to its ODS-generated `stringify*` and
/// `getMaxEnumValFor*` helpers, so the keyword parsers below can enumerate
/// every spelling of the enum generically.
template <typename Ty> struct EnumTraits {};

#define REGISTER_ENUM_TYPE(Ty)                                                 \
  template <> struct EnumTraits<cir::Ty> {                                     \
    static unsigned getMaxEnumVal() { return cir::getMaxEnumValFor##Ty(); }    \
    static llvm::StringRef stringify(cir::Ty value) {                          \
      return stringify##Ty(value);                                             \
    }                                                                          \
  }

REGISTER_ENUM_TYPE(GlobalLinkageKind);
REGISTER_ENUM_TYPE(VisibilityKind);
REGISTER_ENUM_TYPE(SideEffect);
} // namespace

/// Parse an enum from the keyword, or default to the provided default value.
/// The return type is the enum type by default, unless overriden with the
/// second template argument.
template <typename EnumTy, typename RetTy = EnumTy>
static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue) {
  llvm::SmallVector<llvm::StringRef, 10> names;
  for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
    names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));

  int index = parseOptionalKeywordAlternative(parser, names);
  if (index == -1)
    return static_cast<RetTy>(defaultValue);
  return static_cast<RetTy>(index);
}

/// Parse an enum from the keyword, return failure if the keyword is not found.
template <typename EnumTy, typename RetTy = EnumTy>
static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result) {
  llvm::SmallVector<llvm::StringRef, 10> names;
  for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
    names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));

  int index = parseOptionalKeywordAlternative(parser, names);
  if (index == -1)
    return failure();
  result = static_cast<RetTy>(index);
  return success();
}

// Make sure `region` ends in a terminator after parsing. An empty region first
// receives an empty block. If the last block already ends in a terminator
// nothing happens; if it does not, a single-block region gets an implicit
// `cir.yield` appended, while a multi-block region is rejected with an error
// reported at `errLoc`.
static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region,
                                      SMLoc errLoc) {
  Location yieldLoc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
  OpBuilder builder(parser.getBuilder().getContext());

  // An empty region needs a block to hold the terminator.
  if (region.empty())
    builder.createBlock(&region);

  Block &lastBlock = region.back();
  const bool alreadyTerminated =
      !lastBlock.empty() && lastBlock.back().hasTrait<OpTrait::IsTerminator>();
  if (alreadyTerminated)
    return success();

  // Omitting the terminator is only allowed for single-block regions.
  if (!region.hasOneBlock())
    return parser.emitError(errLoc,
                            "multi-block region must not omit terminator");

  builder.setInsertionPointToEnd(&lastBlock);
  builder.create<cir::YieldOp>(yieldLoc);
  return success();
}

// True if the region's terminator should be omitted when printing: the region
// has exactly one non-empty block whose last operation is a `cir.yield` with
// no operands (the form ensureRegionTerm recreates when parsing).
static bool omitRegionTerm(mlir::Region &r) {
  if (!r.hasOneBlock() || r.back().empty())
    return false;
  // Look at the last operation directly rather than via
  // Block::getTerminator(), which asserts when the block does not end in a
  // terminator (e.g. when printing IR that has not been verified).
  auto yield = dyn_cast<cir::YieldOp>(&r.back().back());
  return yield && yield.getArgs().empty();
}

// Print the visibility keyword for `visibility`. Default visibility is
// implicit and prints nothing.
void printVisibilityAttr(OpAsmPrinter &printer,
                         cir::VisibilityAttr &visibility) {
  switch (visibility.getValue()) {
  case cir::VisibilityKind::Default:
    return;
  case cir::VisibilityKind::Hidden:
    printer << "hidden";
    return;
  case cir::VisibilityKind::Protected:
    printer << "protected";
    return;
  }
}

// Parse an optional visibility keyword into `visibility`, falling back to
// default visibility when no keyword is present.
void parseVisibilityAttr(OpAsmParser &parser, cir::VisibilityAttr &visibility) {
  mlir::MLIRContext *ctx = parser.getContext();
  visibility = cir::VisibilityAttr::get(
      ctx, parseOptionalCIRKeyword(parser, cir::VisibilityKind::Default));
}

//===----------------------------------------------------------------------===//
// CIR Custom Parsers/Printers
//===----------------------------------------------------------------------===//

// Parse a region whose trailing terminator may have been omitted, and
// recreate the omitted terminator via ensureRegionTerm.
static mlir::ParseResult parseOmittedTerminatorRegion(mlir::OpAsmParser &parser,
                                                      mlir::Region &region) {
  const llvm::SMLoc startLoc = parser.getCurrentLocation();
  if (failed(parser.parseRegion(region)))
    return failure();
  return ensureRegionTerm(parser, region, startLoc);
}

// Print `region` without entry-block arguments, eliding its terminator when
// omitRegionTerm says it can be reconstructed by the parser. `op` is not used.
static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer,
                                         cir::ScopeOp &op,
                                         mlir::Region &region) {
  const bool printTerminators = !omitRegionTerm(region);
  printer.printRegion(region, /*printEntryBlockArgs=*/false,
                      /*printBlockTerminators=*/printTerminators);
}

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

void cir::AllocaOp::build(mlir::OpBuilder &odsBuilder,
                          mlir::OperationState &odsState, mlir::Type addr,
                          mlir::Type allocaType, llvm::StringRef name,
                          mlir::IntegerAttr alignment) {
  const mlir::OperationName &opName = odsState.name;
  odsState.addAttribute(getAllocaTypeAttrName(opName),
                        mlir::TypeAttr::get(allocaType));
  odsState.addAttribute(getNameAttrName(opName),
                        odsBuilder.getStringAttr(name));
  // The alignment attribute is only recorded when one was provided.
  if (alignment)
    odsState.addAttribute(getAlignmentAttrName(opName), alignment);
  odsState.addTypes(addr);
}

//===----------------------------------------------------------------------===//
// BreakOp
//===----------------------------------------------------------------------===//

// A `cir.break` must be nested (at any depth) inside a loop or a switch.
LogicalResult cir::BreakOp::verify() {
  assert(!cir::MissingFeatures::switchOp());
  mlir::Operation *op = getOperation();
  if (op->getParentOfType<LoopOpInterface>() || op->getParentOfType<SwitchOp>())
    return success();
  return emitOpError("must be within a loop");
}

//===----------------------------------------------------------------------===//
// ConditionOp
//===----------------------------------------------------------------------===//

//===----------------------------------
// BranchOpTerminatorInterface Methods
//===----------------------------------

void cir::ConditionOp::getSuccessorRegions(
    ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
  // TODO(cir): The condition value may be folded to a constant, narrowing
  // down its list of possible successors.

  // Inside a loop the condition either enters the body or exits to the
  // parent op.
  auto loop = dyn_cast<LoopOpInterface>(getOperation()->getParentOp());
  if (loop) {
    Region &body = loop.getBody();
    regions.emplace_back(&body, body.getArguments());
    regions.emplace_back(loop->getResults());
  }

  assert(!cir::MissingFeatures::awaitOp());
}

MutableOperandRange
cir::ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
  // The condition forwards no values to any successor: an empty operand
  // range starting at operand 0.
  const unsigned start = 0;
  const unsigned length = 0;
  return MutableOperandRange(getOperation(), start, length);
}

// Only loops are currently accepted as the direct parent of a condition.
LogicalResult cir::ConditionOp::verify() {
  assert(!cir::MissingFeatures::awaitOp());
  mlir::Operation *parent = getOperation()->getParentOp();
  if (isa<LoopOpInterface>(parent))
    return success();
  return emitOpError("condition must be within a conditional region");
}

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//

// Check that the constant attribute `attrType` is consistent with the type
// `opType` of the operation `op` it initializes. Emits an op error and
// returns failure on mismatch or on attribute kinds that are not supported.
static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
                                        mlir::Attribute attrType) {
  if (isa<cir::ConstPtrAttr>(attrType)) {
    if (!mlir::isa<cir::PointerType>(opType))
      return op->emitOpError(
          "pointer constant initializing a non-pointer type");
    return success();
  }

  if (isa<cir::ZeroAttr>(attrType)) {
    if (isa<cir::RecordType, cir::ArrayType, cir::VectorType, cir::ComplexType>(
            opType))
      return success();
    return op->emitOpError(
        "zero expects struct, array, vector, or complex type");
  }

  if (mlir::isa<cir::BoolAttr>(attrType)) {
    if (!mlir::isa<cir::BoolType>(opType))
      return op->emitOpError("result type (")
             << opType << ") must be '!cir.bool' for '" << attrType << "'";
    return success();
  }

  if (mlir::isa<cir::IntAttr, cir::FPAttr>(attrType)) {
    auto at = cast<TypedAttr>(attrType);
    if (at.getType() != opType) {
      return op->emitOpError("result type (")
             << opType << ") does not match value type (" << at.getType()
             << ")";
    }
    return success();
  }

  if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
                cir::ConstComplexAttr>(attrType))
    return success();

  // Everything else is unsupported. Report it instead of asserting, so that
  // release builds do not reach `cast<TypedAttr>` on an untyped attribute.
  if (auto typedAttr = mlir::dyn_cast<TypedAttr>(attrType))
    return op->emitOpError("global with type ")
           << typedAttr.getType() << " not yet supported";
  return op->emitOpError("unsupported constant attribute ") << attrType;
}

LogicalResult cir::ConstantOp::verify() {
  // ODS-generated checks already validate the result type; here we only make
  // sure the value attribute agrees with it.
  mlir::Operation *op = getOperation();
  return checkConstantTypes(op, getType(), getValue());
}

OpFoldResult cir::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
  // A constant always folds to its own value attribute.
  mlir::Attribute value = getValue();
  return value;
}

//===----------------------------------------------------------------------===//
// ContinueOp
//===----------------------------------------------------------------------===//

// A `cir.continue` must be nested (at any depth) inside a loop.
LogicalResult cir::ContinueOp::verify() {
  if (getOperation()->getParentOfType<LoopOpInterface>())
    return success();
  return emitOpError("must be within a loop");
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

// Verify that the source and result types are valid for the cast kind. For
// vector-to-vector casts the check is applied to the element types.
LogicalResult cir::CastOp::verify() {
  mlir::Type resType = getType();
  mlir::Type srcType = getSrc().getType();

  auto srcVecTy = mlir::dyn_cast<cir::VectorType>(srcType);
  auto resVecTy = mlir::dyn_cast<cir::VectorType>(resType);
  if (srcVecTy && resVecTy) {
    // Use the element types of the vectors to verify the cast kind.
    srcType = srcVecTy.getElementType();
    resType = resVecTy.getElementType();
  }

  switch (getKind()) {
  case cir::CastKind::int_to_bool: {
    if (!mlir::isa<cir::BoolType>(resType))
      return emitOpError() << "requires !cir.bool type for result";
    if (!mlir::isa<cir::IntType>(srcType))
      return emitOpError() << "requires !cir.int type for source";
    return success();
  }
  case cir::CastKind::ptr_to_bool: {
    if (!mlir::isa<cir::BoolType>(resType))
      return emitOpError() << "requires !cir.bool type for result";
    if (!mlir::isa<cir::PointerType>(srcType))
      return emitOpError() << "requires !cir.ptr type for source";
    return success();
  }
  case cir::CastKind::integral: {
    if (!mlir::isa<cir::IntType>(resType))
      return emitOpError() << "requires !cir.int type for result";
    if (!mlir::isa<cir::IntType>(srcType))
      return emitOpError() << "requires !cir.int type for source";
    return success();
  }
  case cir::CastKind::array_to_ptrdecay: {
    if (!mlir::isa<cir::PointerType>(srcType) ||
        !mlir::isa<cir::PointerType>(resType))
      return emitOpError() << "requires !cir.ptr type for source and result";

    // TODO(CIR): Make sure the AddrSpace of both types are equals
    return success();
  }
  case cir::CastKind::bitcast: {
    // TODO(CIR): bitcast operand types are not checked yet; any source/result
    // pair (pointer or otherwise) is currently accepted.
    return success();
  }
  case cir::CastKind::floating: {
    if (!mlir::isa<cir::FPTypeInterface>(srcType) ||
        !mlir::isa<cir::FPTypeInterface>(resType))
      return emitOpError() << "requires !cir.float type for source and result";
    return success();
  }
  case cir::CastKind::float_to_int: {
    if (!mlir::isa<cir::FPTypeInterface>(srcType))
      return emitOpError() << "requires !cir.float type for source";
    if (!mlir::isa<cir::IntType>(resType))
      return emitOpError() << "requires !cir.int type for result";
    return success();
  }
  case cir::CastKind::int_to_ptr: {
    if (!mlir::isa<cir::IntType>(srcType))
      return emitOpError() << "requires !cir.int type for source";
    if (!mlir::isa<cir::PointerType>(resType))
      return emitOpError() << "requires !cir.ptr type for result";
    return success();
  }
  case cir::CastKind::ptr_to_int: {
    if (!mlir::isa<cir::PointerType>(srcType))
      return emitOpError() << "requires !cir.ptr type for source";
    if (!mlir::isa<cir::IntType>(resType))
      return emitOpError() << "requires !cir.int type for result";
    return success();
  }
  case cir::CastKind::float_to_bool: {
    if (!mlir::isa<cir::FPTypeInterface>(srcType))
      return emitOpError() << "requires !cir.float type for source";
    if (!mlir::isa<cir::BoolType>(resType))
      return emitOpError() << "requires !cir.bool type for result";
    return success();
  }
  case cir::CastKind::bool_to_int: {
    if (!mlir::isa<cir::BoolType>(srcType))
      return emitOpError() << "requires !cir.bool type for source";
    if (!mlir::isa<cir::IntType>(resType))
      return emitOpError() << "requires !cir.int type for result";
    return success();
  }
  case cir::CastKind::int_to_float: {
    if (!mlir::isa<cir::IntType>(srcType))
      return emitOpError() << "requires !cir.int type for source";
    if (!mlir::isa<cir::FPTypeInterface>(resType))
      return emitOpError() << "requires !cir.float type for result";
    return success();
  }
  case cir::CastKind::bool_to_float: {
    if (!mlir::isa<cir::BoolType>(srcType))
      return emitOpError() << "requires !cir.bool type for source";
    if (!mlir::isa<cir::FPTypeInterface>(resType))
      return emitOpError() << "requires !cir.float type for result";
    return success();
  }
  case cir::CastKind::address_space: {
    auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcType);
    auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
    if (!srcPtrTy || !resPtrTy)
      return emitOpError() << "requires !cir.ptr type for source and result";
    if (srcPtrTy.getPointee() != resPtrTy.getPointee())
      return emitOpError() << "requires two types differ in addrspace only";
    return success();
  }
  default:
    // Cast kinds without a check here (e.g. the complex kinds handled by
    // fold()) are valid enum values that can appear in parsed IR; report them
    // instead of crashing the verifier.
    return emitOpError() << "unsupported cast kind '"
                         << stringifyCastKind(getKind()) << "'";
  }
}

// True for the cast kinds that only convert between integer and bool values.
static bool isIntOrBoolCast(cir::CastOp op) {
  switch (op.getKind()) {
  case cir::CastKind::bool_to_int:
  case cir::CastKind::int_to_bool:
  case cir::CastKind::integral:
    return true;
  default:
    return false;
  }
}

// Walk backwards from `op` through the chain of int/bool casts feeding it and
// try to fold the whole chain when `op` is an int_to_bool:
//   bool_to_int -> ... -> int_to_bool  ==>  the bool value before all casts
//   int_to_bool -> ... -> int_to_bool  ==>  the result of the first cast
// Returns a null Value when nothing can be folded.
static Value tryFoldCastChain(cir::CastOp op) {
  cir::CastOp last = op;
  cir::CastOp first = op;

  for (cir::CastOp cur = op; cur && isIntOrBoolCast(cur);
       cur = dyn_cast_or_null<cir::CastOp>(cur.getSrc().getDefiningOp()))
    first = cur;

  // A chain consisting only of `op` has nothing to fold.
  if (first == last)
    return {};

  if (last.getKind() != cir::CastKind::int_to_bool)
    return {};

  if (first.getKind() == cir::CastKind::bool_to_int)
    return first.getSrc();
  if (first.getKind() == cir::CastKind::int_to_bool)
    return first.getResult();
  return {};
}

// Fold identity-typed casts. Integral casts fold to the constant produced by
// folding their source's defining op; bitcast, address_space, float_complex
// and int_complex casts fold to their source. Otherwise, try to fold a chain
// of int/bool casts.
OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
  if (getSrc().getType() == getType()) {
    switch (getKind()) {
    case cir::CastKind::integral: {
      // TODO: for sign differences, it's possible in certain conditions to
      // create a new attribute that's capable of representing the source.

      // The source may be a block argument, which has no defining op.
      auto srcResult = mlir::dyn_cast<mlir::OpResult>(getSrc());
      if (!srcResult)
        return {};

      llvm::SmallVector<mlir::OpFoldResult, 1> foldResults;
      if (srcResult.getOwner()->fold(foldResults).failed())
        return {};

      // An in-place fold succeeds without producing results, and the producer
      // may have several results: select the one that feeds this cast.
      unsigned resultNo = srcResult.getResultNumber();
      if (resultNo >= foldResults.size() ||
          !mlir::isa<mlir::Attribute>(foldResults[resultNo]))
        return {};
      return mlir::cast<mlir::Attribute>(foldResults[resultNo]);
    }
    case cir::CastKind::bitcast:
    case cir::CastKind::address_space:
    case cir::CastKind::float_complex:
    case cir::CastKind::int_complex: {
      return getSrc();
    }
    default:
      return {};
    }
  }
  return tryFoldCastChain(*this);
}

//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//

mlir::OperandRange cir::CallOp::getArgOperands() {
  // For an indirect call the first operand is the callee, not an argument.
  mlir::OperandRange args = getArgs();
  return isIndirect() ? args.drop_front(1) : args;
}

mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() {
  mlir::MutableOperandRange args = getArgsMutable();
  if (!isIndirect())
    return args;
  // Skip the callee operand of an indirect call.
  return args.slice(1, args.size() - 1);
}

mlir::Value cir::CallOp::getIndirectCall() {
  // The callee of an indirect call is its first operand.
  assert(isIndirect());
  const unsigned calleeOperandIdx = 0;
  return getOperand(calleeOperandIdx);
}

/// Return the call argument at index 'i', skipping over the callee operand of
/// an indirect call.
Value cir::CallOp::getArgOperand(unsigned i) {
  const unsigned operandIdx = isIndirect() ? i + 1 : i;
  return getOperand(operandIdx);
}

/// Return the number of operands.
unsigned cir::CallOp::getNumArgOperands() {
  if (isIndirect())
    return this->getOperation()->getNumOperands() - 1;
  return this->getOperation()->getNumOperands();
}

// Parse the assembly shared by call-like operations:
//   (@callee | %callee_value) `(` operands `)` [`nothrow`]
//   [`side_effect` `(` kind `)`] attr-dict `:` function-type
// For an indirect call the callee value becomes the first operand.
static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
                                         mlir::OperationState &result) {
  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> ops;
  mlir::FlatSymbolRefAttr calleeAttr;

  // If we cannot parse a symbol callee, this is an indirect call.
  mlir::OptionalParseResult calleeParse = parser.parseOptionalAttribute(
      calleeAttr, CIRDialect::getCalleeAttrName(), result.attributes);
  if (calleeParse.has_value()) {
    // A callee attribute was present but could not be parsed; propagate the
    // failure instead of continuing as if it had succeeded.
    if (failed(*calleeParse))
      return mlir::failure();
  } else {
    OpAsmParser::UnresolvedOperand indirectVal;
    // Do not resolve right now, since we need to figure out the type.
    if (parser.parseOperand(indirectVal).failed())
      return mlir::failure();
    ops.push_back(indirectVal);
  }

  if (parser.parseLParen())
    return mlir::failure();

  llvm::SMLoc opsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList(ops))
    return mlir::failure();
  if (parser.parseRParen())
    return mlir::failure();

  if (parser.parseOptionalKeyword("nothrow").succeeded())
    result.addAttribute(CIRDialect::getNoThrowAttrName(),
                        mlir::UnitAttr::get(parser.getContext()));

  if (parser.parseOptionalKeyword("side_effect").succeeded()) {
    if (parser.parseLParen().failed())
      return mlir::failure();
    cir::SideEffect sideEffect;
    if (parseCIRKeyword<cir::SideEffect>(parser, sideEffect).failed())
      return mlir::failure();
    if (parser.parseRParen().failed())
      return mlir::failure();
    auto attr = cir::SideEffectAttr::get(parser.getContext(), sideEffect);
    result.addAttribute(CIRDialect::getSideEffectAttrName(), attr);
  }

  if (parser.parseOptionalAttrDict(result.attributes))
    return mlir::failure();

  if (parser.parseColon())
    return mlir::failure();

  mlir::FunctionType opsFnTy;
  if (parser.parseType(opsFnTy))
    return mlir::failure();

  result.addTypes(opsFnTy.getResults());

  // The function type's inputs cover every operand, including an indirect
  // callee value.
  if (parser.resolveOperands(ops, opsFnTy.getInputs(), opsLoc, result.operands))
    return mlir::failure();

  return mlir::success();
}

// Print the assembly shared by call-like operations; the inverse of
// parseCallCommon. Exactly one of `calleeSym` (direct call) or
// `indirectCallee` (indirect call) is expected to be set.
static void printCallCommon(mlir::Operation *op,
                            mlir::FlatSymbolRefAttr calleeSym,
                            mlir::Value indirectCallee,
                            mlir::OpAsmPrinter &printer, bool isNothrow,
                            cir::SideEffect sideEffect) {
  auto args = mlir::cast<cir::CIRCallOpInterface>(op).getArgOperands();

  printer << ' ';
  if (!calleeSym) {
    assert(indirectCallee);
    printer << indirectCallee;
  } else {
    printer.printAttributeWithoutType(calleeSym);
  }
  printer << "(" << args << ")";

  if (isNothrow)
    printer << " nothrow";

  // `all` is the implicit default and is not printed.
  if (sideEffect != cir::SideEffect::All)
    printer << " side_effect(" << stringifySideEffect(sideEffect) << ")";

  // Attributes already represented in the custom syntax are elided.
  printer.printOptionalAttrDict(op->getAttrs(),
                                {CIRDialect::getCalleeAttrName(),
                                 CIRDialect::getNoThrowAttrName(),
                                 CIRDialect::getSideEffectAttrName()});

  printer << " : ";
  printer.printFunctionalType(op->getOperands().getTypes(),
                              op->getResultTypes());
}

mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,
                                     mlir::OperationState &result) {
  // The syntax is shared with the other call-like operations.
  return parseCallCommon(parser, result);
}

void cir::CallOp::print(mlir::OpAsmPrinter &p) {
  // Direct calls carry a callee symbol; indirect calls a callee value.
  mlir::Value indirectCallee;
  if (isIndirect())
    indirectCallee = getIndirectCall();
  printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
                  getSideEffect());
}

// Verify a direct call against the `cir.func` its callee symbol references:
// the symbol must resolve to a function, and (unless the callee has no
// prototype) argument counts and types must match. The result count and type
// must agree with the callee's return type. Indirect calls are not checked.
static LogicalResult
verifyCallCommInSymbolUses(mlir::Operation *op,
                           SymbolTableCollection &symbolTable) {
  auto fnAttr =
      op->getAttrOfType<FlatSymbolRefAttr>(CIRDialect::getCalleeAttrName());
  if (!fnAttr) {
    // This is an indirect call, thus we don't have to check the symbol uses.
    return mlir::success();
  }

  auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
  if (!fn)
    return op->emitOpError() << "'" << fnAttr.getValue()
                             << "' does not reference a valid function";

  auto callIf = dyn_cast<cir::CIRCallOpInterface>(op);
  assert(callIf && "expected CIR call interface to be always available");

  // Verify that the operand and result types match the callee. Note that
  // argument-checking is disabled for functions without a prototype.
  auto fnType = fn.getFunctionType();
  if (!fn.getNoProto()) {
    unsigned numCallOperands = callIf.getNumArgOperands();
    unsigned numFnOpOperands = fnType.getNumInputs();

    if (!fnType.isVarArg() && numCallOperands != numFnOpOperands)
      return op->emitOpError("incorrect number of operands for callee");
    if (fnType.isVarArg() && numCallOperands < numFnOpOperands)
      return op->emitOpError("too few operands for callee");

    for (unsigned i = 0, e = numFnOpOperands; i != e; ++i) {
      // Report the same argument operand that was compared.
      mlir::Type providedTy = callIf.getArgOperand(i).getType();
      if (providedTy != fnType.getInput(i))
        return op->emitOpError("operand type mismatch: expected operand type ")
               << fnType.getInput(i) << ", but provided " << providedTy
               << " for operand number " << i;
    }
  }

  assert(!cir::MissingFeatures::opCallCallConv());

  // Void function must not return any results.
  if (fnType.hasVoidReturn() && op->getNumResults() != 0)
    return op->emitOpError("callee returns void but call has results");

  // Non-void function calls must return exactly one result.
  if (!fnType.hasVoidReturn() && op->getNumResults() != 1)
    return op->emitOpError("incorrect number of results for callee");

  // The call's result type must match the callee's return type.
  if (!fnType.hasVoidReturn() &&
      op->getResultTypes().front() != fnType.getReturnType()) {
    return op->emitOpError("result type mismatch: expected ")
           << fnType.getReturnType() << ", but provided "
           << op->getResult(0).getType();
  }

  return mlir::success();
}

LogicalResult
cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Callee checks are shared with the other call-like operations.
  mlir::Operation *op = *this;
  return verifyCallCommInSymbolUses(op, symbolTable);
}

//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//

// Check that `op` returns a value whose type matches the return type of the
// enclosing `function`. A return without operands is treated as returning
// `!cir.void`.
static mlir::LogicalResult checkReturnAndFunction(cir::ReturnOp op,
                                                  cir::FuncOp function) {
  // ReturnOps currently only have a single optional operand.
  const unsigned numOperands = op.getNumOperands();
  if (numOperands > 1)
    return op.emitOpError() << "expects at most 1 return operand";

  mlir::Type actualTy;
  if (numOperands == 0)
    actualTy = cir::VoidType::get(op.getContext());
  else
    actualTy = op.getOperand(0).getType();

  auto expectedTy = function.getFunctionType().getReturnType();
  if (actualTy == expectedTy)
    return mlir::success();
  return op.emitOpError() << "returns " << actualTy
                          << " but enclosing function returns " << expectedTy;
}

// Verify the return against the nearest enclosing function. Returns can be
// nested in several scopes, so the function may be any number of levels up.
mlir::LogicalResult cir::ReturnOp::verify() {
  // getParentOfType stops at the top instead of walking past a null parent,
  // which the previous manual loop did when no function encloses the return.
  auto fnOp = getOperation()->getParentOfType<cir::FuncOp>();
  if (!fnOp)
    return emitOpError() << "must be nested within a function";

  // Make sure return types match function return type.
  return checkReturnAndFunction(*this, fnOp);
}

//===----------------------------------------------------------------------===//
// IfOp
//===----------------------------------------------------------------------===//

// Parse `cir.if %cond { ... } [else { ... }] attr-dict`. Both regions are always
// added to the op; the 'else' region stays empty when `else` is absent.
ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  OpAsmParser::UnresolvedOperand condOperand;
  Type condType = cir::BoolType::get(parser.getBuilder().getContext());
  if (parser.parseOperand(condOperand) ||
      parser.resolveOperand(condOperand, condType, result.operands))
    return failure();

  // Parse one branch region and recreate its terminator if it was omitted.
  auto parseBranch = [&](Region &region) -> ParseResult {
    mlir::SMLoc branchLoc = parser.getCurrentLocation();
    if (parser.parseRegion(region, /*arguments=*/{}, /*argTypes=*/{}))
      return failure();
    return ensureRegionTerm(parser, region, branchLoc);
  };

  if (failed(parseBranch(*thenRegion)))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("else")) &&
      failed(parseBranch(*elseRegion)))
    return failure();

  return parser.parseOptionalAttrDict(result.attributes);
}

void cir::IfOp::print(OpAsmPrinter &p) {
  // Print a branch without entry-block arguments, eliding an implicit yield.
  auto printBranch = [&](mlir::Region &region) {
    p.printRegion(region, /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/!omitRegionTerm(region));
  };

  p << " " << getCondition() << " ";
  printBranch(getThenRegion());

  // The 'else' region is only printed when it has a block.
  mlir::Region &elseRegion = getElseRegion();
  if (!elseRegion.empty()) {
    p << " else ";
    printBranch(elseRegion);
  }

  p.printOptionalAttrDict(getOperation()->getAttrs());
}

/// Default body callback for IfOp builders: terminate the current block with
/// an operand-less `cir.yield` at `loc`.
void cir::buildTerminatedBody(OpBuilder &builder, Location loc) {
  builder.create<cir::YieldOp>(loc);
}

/// Report the possible successors of `point`. From either branch region,
/// control returns to the parent operation. From the parent, control may
/// enter the 'then' region or, when it has a block, the 'else' region. The
/// condition is not inspected, so both branches are always reported.
void cir::IfOp::getSuccessorRegions(mlir::RegionBranchPoint point,
                                    SmallVectorImpl<RegionSuccessor> &regions) {
  if (!point.isParent()) {
    regions.push_back(RegionSuccessor());
    return;
  }

  regions.push_back(RegionSuccessor(&getThenRegion()));

  // An empty 'else' region is not a viable successor.
  Region &elseRegion = getElseRegion();
  if (!elseRegion.empty())
    regions.push_back(RegionSuccessor(&elseRegion));
}

/// Build a `cir.if` on \p cond.
///
/// The 'then' region always gets an entry block, populated by \p thenBuilder.
/// The 'else' region is always added to \p result. It receives a block
/// (populated by \p elseBuilder) only when \p withElseRegion is true;
/// otherwise it stays empty, meaning "no else clause".
void cir::IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
                      bool withElseRegion, BuilderCallbackRef thenBuilder,
                      BuilderCallbackRef elseBuilder) {
  assert(thenBuilder && "the builder callback for 'then' must be present");
  // elseBuilder is invoked unconditionally below when an else region is
  // requested, so it must be non-null in that case.
  assert((!withElseRegion || elseBuilder) &&
         "the builder callback for 'else' must be present when an else "
         "region is requested");
  result.addOperands(cond);

  OpBuilder::InsertionGuard guard(builder);
  Region *thenRegion = result.addRegion();
  builder.createBlock(thenRegion);
  thenBuilder(builder, result.location);

  // The else region must always exist so the op has a fixed region count.
  Region *elseRegion = result.addRegion();
  if (!withElseRegion)
    return;

  builder.createBlock(elseRegion);
  elseBuilder(builder, result.location);
}

//===----------------------------------------------------------------------===//
// ScopeOp
//===----------------------------------------------------------------------===//

/// Report the regions control may flow to from \p point.
///
/// A scope has a single region. Entering from the parent op always goes to
/// that region. Leaving the region returns to the parent op, with the yielded
/// values becoming the op's results (if it has any).
void cir::ScopeOp::getSuccessorRegions(
    mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  if (point.isParent())
    regions.push_back(RegionSuccessor(&getScopeRegion()));
  else
    regions.push_back(RegionSuccessor(getODSResults(0)));
}

/// Build a `cir.scope` whose body is populated by \p scopeBuilder.
///
/// The callback runs with the builder positioned in the new entry block and
/// receives a type slot it may fill in. If it sets the type, the scope gets a
/// single result of that type; otherwise the scope has no results.
void cir::ScopeOp::build(
    OpBuilder &builder, OperationState &result,
    function_ref<void(OpBuilder &, Type &, Location)> scopeBuilder) {
  assert(scopeBuilder &&
         "the builder callback for the scope body must be present");

  OpBuilder::InsertionGuard guard(builder);
  Region *scopeRegion = result.addRegion();
  builder.createBlock(scopeRegion);
  assert(!cir::MissingFeatures::opScopeCleanupRegion());

  mlir::Type yieldTy;
  scopeBuilder(builder, yieldTy, result.location);

  if (yieldTy)
    result.addTypes(TypeRange{yieldTy});
}

/// Build a result-less `cir.scope` whose body is populated by
/// \p scopeBuilder, invoked with the builder positioned in the new entry
/// block.
void cir::ScopeOp::build(
    OpBuilder &builder, OperationState &result,
    function_ref<void(OpBuilder &, Location)> scopeBuilder) {
  assert(scopeBuilder &&
         "the builder callback for the scope body must be present");
  OpBuilder::InsertionGuard guard(builder);
  Region *scopeRegion = result.addRegion();
  builder.createBlock(scopeRegion);
  assert(!cir::MissingFeatures::opScopeCleanupRegion());
  scopeBuilder(builder, result.location);
}

/// A scope must contain at least one block, and its last block must end in an
/// operation that carries the IsTerminator trait.
LogicalResult cir::ScopeOp::verify() {
  mlir::Region &body = getRegion();
  if (body.empty())
    return emitOpError() << "cir.scope must not be empty since it should "
                            "include at least an implicit cir.yield ";

  mlir::Block &tail = body.back();
  const bool terminated =
      !tail.empty() && tail.mightHaveTerminator() &&
      tail.getTerminator()->hasTrait<OpTrait::IsTerminator>();
  if (!terminated)
    return emitOpError() << "last block of cir.scope must be terminated";

  return success();
}

//===----------------------------------------------------------------------===//
// BrOp
//===----------------------------------------------------------------------===//

/// `cir.br` has exactly one successor, and all of its destination operands are
/// forwarded to that successor.
mlir::SuccessorOperands cir::BrOp::getSuccessorOperands(unsigned index) {
  assert(index == 0 && "invalid successor index");
  mlir::MutableOperandRange forwarded = getDestOperandsMutable();
  return mlir::SuccessorOperands(forwarded);
}

/// An unconditional branch always goes to its single destination, so the
/// constant operand values are irrelevant.
Block *cir::BrOp::getSuccessorForOperands(ArrayRef<Attribute> /*operands*/) {
  Block *target = getDest();
  return target;
}

//===----------------------------------------------------------------------===//
// BrCondOp
//===----------------------------------------------------------------------===//

/// Successor 0 is the "true" destination and successor 1 is the "false" one;
/// each has its own list of forwarded operands.
mlir::SuccessorOperands cir::BrCondOp::getSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  if (index == 0)
    return SuccessorOperands(getDestOperandsTrueMutable());
  return SuccessorOperands(getDestOperandsFalseMutable());
}

/// If the condition operand is known to be a constant, return the destination
/// control will take; otherwise return null.
///
/// Both CIR boolean constants (cir::BoolAttr, the attribute SelectOp::fold
/// expects for a !cir.bool condition) and builtin IntegerAttr constants are
/// recognized. NOTE(review): the condition is presumably a !cir.bool per ODS;
/// the IntegerAttr path is kept for backward compatibility.
Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
  mlir::Attribute cond = operands.empty() ? mlir::Attribute() : operands.front();
  if (auto boolAttr = dyn_cast_if_present<cir::BoolAttr>(cond))
    return boolAttr.getValue() ? getDestTrue() : getDestFalse();
  if (auto condAttr = dyn_cast_if_present<IntegerAttr>(cond))
    return condAttr.getValue().isOne() ? getDestTrue() : getDestFalse();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// CaseOp
//===----------------------------------------------------------------------===//

/// Entering from the parent op goes to the case region. Leaving the case
/// region returns to the parent op with no forwarded values.
void cir::CaseOp::getSuccessorRegions(
    mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  if (point.isParent())
    regions.push_back(RegionSuccessor(&getCaseRegion()));
  else
    regions.push_back(RegionSuccessor());
}

/// Build a `cir.case` with the case value list \p value and kind \p kind, and
/// create an empty entry block in its region.
///
/// On return, \p insertPoint holds the insertion point inside that new block,
/// so the caller can populate the body later. The builder's own insertion
/// point is restored when this function returns.
void cir::CaseOp::build(OpBuilder &builder, OperationState &result,
                        ArrayAttr value, CaseOpKind kind,
                        OpBuilder::InsertPoint &insertPoint) {
  OpBuilder::InsertionGuard restoreIP(builder);
  mlir::MLIRContext *ctx = builder.getContext();

  result.addAttribute("value", value);
  Properties &props = result.getOrAddProperties<Properties>();
  props.kind = cir::CaseOpKindAttr::get(ctx, kind);

  builder.createBlock(result.addRegion());
  insertPoint = builder.saveInsertionPoint();
}

//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//

/// Parse the custom switch syntax `(` operand `:` int-type `)` region.
/// \p condType is set only once the integer type has parsed successfully.
static ParseResult parseSwitchOp(OpAsmParser &parser, mlir::Region &regions,
                                 mlir::OpAsmParser::UnresolvedOperand &cond,
                                 mlir::Type &condType) {
  cir::IntType intCondType;
  if (parser.parseLParen() || parser.parseOperand(cond) ||
      parser.parseColon() || parser.parseCustomTypeWithFallback(intCondType))
    return mlir::failure();
  condType = intCondType;

  if (parser.parseRParen() ||
      parser.parseRegion(regions, /*arguments=*/{}, /*argTypes=*/{}))
    return mlir::failure();

  return mlir::success();
}

/// Print the custom switch syntax `(` cond `:` int-type `)` region. Block
/// terminators inside the body are always printed.
static void printSwitchOp(OpAsmPrinter &p, cir::SwitchOp op,
                          mlir::Region &bodyRegion, mlir::Value condition,
                          mlir::Type condType) {
  p << "(" << condition << " : ";
  p.printStrippedAttrOrType(condType);
  p << ") ";
  p.printRegion(bodyRegion, /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/true);
}

/// Entering from the parent op goes to the switch body. Leaving the body
/// returns to the parent op with no forwarded values.
void cir::SwitchOp::getSuccessorRegions(
    mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &region) {
  if (point.isParent())
    region.push_back(RegionSuccessor(&getBody()));
  else
    region.push_back(RegionSuccessor());
}

/// Build a `cir.switch` on \p cond. The body region gets an entry block, and
/// \p switchBuilder is invoked with the builder positioned in it and with
/// access to the operation state.
void cir::SwitchOp::build(OpBuilder &builder, OperationState &result,
                          Value cond, BuilderOpStateCallbackRef switchBuilder) {
  assert(switchBuilder && "the builder callback for regions must be present");
  OpBuilder::InsertionGuard restoreIP(builder);
  builder.createBlock(result.addRegion());
  result.addOperands({cond});
  switchBuilder(builder, result.location, result);
}

/// Append to \p cases every `cir.case` nested under this switch, in pre-order.
/// The walk does not descend into nested `cir.switch` ops, so their cases are
/// not collected.
void cir::SwitchOp::collectCases(llvm::SmallVectorImpl<CaseOp> &cases) {
  walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) -> WalkResult {
    if (auto caseOp = dyn_cast<cir::CaseOp>(op)) {
      cases.push_back(caseOp);
      return WalkResult::advance();
    }

    const bool isNestedSwitch = isa<cir::SwitchOp>(op) && op != *this;
    return isNestedSwitch ? WalkResult::skip() : WalkResult::advance();
  });
}

/// Collect this switch's cases into \p cases (via collectCases) and report
/// whether the switch is in "simple form", meaning all of the following hold:
///   - the body region has at least one block;
///   - its first block is non-empty and ends with a `cir.yield`;
///   - every op in that first block is a `cir.case` or a `cir.yield`;
///   - the closest enclosing `cir.switch` of every collected case is this op.
/// \p cases is populated even when false is returned.
bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
  collectCases(cases);

  mlir::Region &body = getBody();
  if (body.empty())
    return false;

  mlir::Block &entry = body.front();
  // Block::back() on an empty block is invalid, so check emptiness first.
  if (entry.empty() || !isa<YieldOp>(entry.back()))
    return false;

  if (!llvm::all_of(entry,
                    [](Operation &op) { return isa<CaseOp, YieldOp>(op); }))
    return false;

  return llvm::all_of(cases, [this](CaseOp op) {
    return op->getParentOfType<SwitchOp>() == *this;
  });
}

//===----------------------------------------------------------------------===//
// SwitchFlatOp
//===----------------------------------------------------------------------===//

/// Convenience builder taking raw APInt case values. Each value is wrapped in
/// a `#cir.int` of \p value's type before delegating to the generated builder.
void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
                              Value value, Block *defaultDestination,
                              ValueRange defaultOperands,
                              ArrayRef<APInt> caseValues,
                              BlockRange caseDestinations,
                              ArrayRef<ValueRange> caseOperands) {
  mlir::Type flagType = value.getType();

  llvm::SmallVector<mlir::Attribute> caseValueAttrs;
  caseValueAttrs.reserve(caseValues.size());
  for (const APInt &caseValue : caseValues)
    caseValueAttrs.push_back(cir::IntAttr::get(flagType, caseValue));

  build(builder, result, value, defaultOperands, caseOperands,
        ArrayAttr::get(builder.getContext(), caseValueAttrs),
        defaultDestination, caseDestinations);
}

/// Parse the case list of a `cir.switch.flat`:
///
///   <cases> ::= `[` (case (`,` case )* )? `]`
///   <case>  ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
///
/// Each case value becomes a `#cir.int` of \p flagType. Destinations, operand
/// lists and operand-type lists are appended in parallel to the output
/// vectors. For an empty list (`[]`), \p caseValues is left untouched.
static ParseResult parseSwitchFlatOpCases(
    OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
    SmallVectorImpl<Block *> &caseDestinations,
    SmallVectorImpl<llvm::SmallVector<OpAsmParser::UnresolvedOperand>>
        &caseOperands,
    SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
  if (failed(parser.parseLSquare()))
    return failure();
  if (succeeded(parser.parseOptionalRSquare()))
    return success();

  llvm::SmallVector<mlir::Attribute> parsedValues;

  auto parseOneCase = [&]() -> ParseResult {
    int64_t caseValue = 0;
    if (failed(parser.parseInteger(caseValue)))
      return failure();
    parsedValues.push_back(cir::IntAttr::get(flagType, caseValue));

    Block *dest = nullptr;
    if (parser.parseColon() || parser.parseSuccessor(dest))
      return failure();

    // Optional parenthesized operand list with a trailing `: types`.
    llvm::SmallVector<OpAsmParser::UnresolvedOperand> destOperands;
    llvm::SmallVector<Type> destOperandTypes;
    if (succeeded(parser.parseOptionalLParen()) &&
        (parser.parseOperandList(destOperands, OpAsmParser::Delimiter::None,
                                 /*allowResultNumber=*/false) ||
         parser.parseColonTypeList(destOperandTypes) ||
         parser.parseRParen()))
      return failure();

    caseDestinations.push_back(dest);
    caseOperands.emplace_back(destOperands);
    caseOperandTypes.emplace_back(destOperandTypes);
    return success();
  };

  if (failed(parser.parseCommaSeparatedList(parseOneCase)))
    return failure();

  caseValues = ArrayAttr::get(flagType.getContext(), parsedValues);
  return parser.parseRSquare();
}

/// Print the case list of a `cir.switch.flat` as `[`, then one
/// `value: ^dest(operands : types)` entry per line separated by commas, then
/// `]`. Operand types are printed by printSuccessorAndUseList, so
/// \p caseOperandTypes is not consulted; the signature mirrors the parser's.
static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
                                   Type flagType, mlir::ArrayAttr caseValues,
                                   SuccessorRange caseDestinations,
                                   OperandRangeRange caseOperands,
                                   const TypeRangeRange &caseOperandTypes) {
  p << '[';
  p.printNewline();
  if (!caseValues) {
    p << ']';
    return;
  }

  size_t caseIdx = 0;
  for (auto [valueAttr, dest] : llvm::zip(caseValues, caseDestinations)) {
    if (caseIdx != 0) {
      p << ',';
      p.printNewline();
    }
    p << "  ";
    p << mlir::cast<cir::IntAttr>(valueAttr).getValue();
    p << ": ";
    p.printSuccessorAndUseList(dest, caseOperands[caseIdx]);
    ++caseIdx;
  }

  p.printNewline();
  p << ']';
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

/// Parse an attribute into \p valueAttr. The parser API also records it as a
/// named "value" entry in a scratch list, which is discarded.
static ParseResult parseConstantValue(OpAsmParser &parser,
                                      mlir::Attribute &valueAttr) {
  NamedAttrList scratchAttrs;
  return parser.parseAttribute(valueAttr, "value", scratchAttrs);
}

/// Print \p value with the generic attribute printer (counterpart of
/// parseConstantValue).
static void printConstant(OpAsmPrinter &p, Attribute value) {
  Attribute toPrint = value;
  p.printAttribute(toPrint);
}

/// Verify a `cir.global`. If an initial value is present, it must pass
/// checkConstantTypes against the global's symbol type.
mlir::LogicalResult cir::GlobalOp::verify() {
  auto initialValue = getInitialValue();
  if (initialValue.has_value() &&
      checkConstantTypes(getOperation(), getSymType(), *initialValue)
          .failed())
    return failure();

  // TODO(CIR): Many other checks for properties that haven't been upstreamed
  // yet.

  return success();
}

/// Build a `cir.global` named \p sym_name of type \p sym_type with the given
/// \p linkage. The CIR global visibility is set to a default-constructed
/// cir::VisibilityAttr.
void cir::GlobalOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                          llvm::StringRef sym_name, mlir::Type sym_type,
                          cir::GlobalLinkageKind linkage) {
  mlir::MLIRContext *ctx = odsBuilder.getContext();
  const mlir::OperationName &opName = odsState.name;

  odsState.addAttribute(getSymNameAttrName(opName),
                        odsBuilder.getStringAttr(sym_name));
  odsState.addAttribute(getSymTypeAttrName(opName),
                        mlir::TypeAttr::get(sym_type));
  odsState.addAttribute(getLinkageAttrName(opName),
                        cir::GlobalLinkageKindAttr::get(ctx, linkage));
  odsState.addAttribute(getGlobalVisibilityAttrName(opName),
                        cir::VisibilityAttr::get(ctx));
}

/// Print either `: type` (for a declaration) or `= <init>` (for a definition).
/// In the definition form, the printed attribute carries the type, so the
/// type is not printed separately.
static void printGlobalOpTypeAndInitialValue(OpAsmPrinter &p, cir::GlobalOp op,
                                             TypeAttr type,
                                             Attribute initAttr) {
  if (op.isDeclaration()) {
    p << ": " << type;
    return;
  }

  p << "= ";
  if (initAttr)
    printConstant(p, initAttr);
}

/// Parse the type / initial-value part of a `cir.global`:
///   `: type`            -- a declaration; only the type is parsed.
///   `= typed-attribute` -- a definition; the type comes from the attribute.
/// Because this parses user-provided text, an untyped initializer is reported
/// as a parse error rather than asserted on.
static ParseResult
parseGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                 Attribute &initialValueAttr) {
  mlir::Type opTy;
  if (parser.parseOptionalEqual().failed()) {
    // Absence of equal means a declaration, so we need to parse the type.
    //  cir.global @a : !cir.int<s, 32>
    if (parser.parseColonType(opTy))
      return failure();
  } else {
    // Parse constant with initializer, examples:
    //  cir.global @y = #cir.fp<1.250000e+00> : !cir.double
    //  cir.global @rgb = #cir.const_array<[...] : !cir.array<i8 x 3>>
    mlir::SMLoc initLoc = parser.getCurrentLocation();
    if (parseConstantValue(parser, initialValueAttr).failed())
      return failure();

    auto typedAttr = mlir::dyn_cast<mlir::TypedAttr>(initialValueAttr);
    if (!typedAttr)
      return parser.emitError(
          initLoc, "initial value of a global must be a typed attribute");
    opTy = typedAttr.getType();
  }

  typeAttr = TypeAttr::get(opTy);
  return success();
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

/// Verify that this op names a `cir.global` or `cir.func`, that its result is
/// a pointer, and that the pointee type matches the referenced symbol's type
/// (the global's symbol type, or the function's type).
LogicalResult
cir::GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  mlir::Operation *op =
      symbolTable.lookupNearestSymbolFrom(*this, getNameAttr());
  if (!llvm::isa_and_present<GlobalOp, FuncOp>(op))
    return emitOpError("'")
           << getName()
           << "' does not reference a valid cir.global or cir.func";

  mlir::Type symTy;
  if (auto g = dyn_cast<GlobalOp>(op)) {
    symTy = g.getSymType();
    assert(!cir::MissingFeatures::addressSpace());
    assert(!cir::MissingFeatures::opGlobalThreadLocal());
  } else if (auto f = dyn_cast<FuncOp>(op)) {
    symTy = f.getFunctionType();
  } else {
    llvm_unreachable("Unexpected operation for GetGlobalOp");
  }

  // Check the result kind separately so the mismatch diagnostic below never
  // dereferences a null PointerType.
  auto resultType = dyn_cast<PointerType>(getAddr().getType());
  if (!resultType)
    return emitOpError("result type ")
           << getAddr().getType() << " is not a pointer type";

  if (symTy != resultType.getPointee())
    return emitOpError("result type pointee type '")
           << resultType.getPointee() << "' does not match type " << symTy
           << " of the global @" << getName();

  return success();
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

/// Raw name under which FuncOp's linkage attribute is stored. The builder and
/// parser add the attribute by this string, so it *must* stay identical to
/// the attribute name declared in ODS.
static llvm::StringRef getLinkageAttrNameString() {
  return llvm::StringRef("linkage");
}

/// Build a `cir.func` named \p name with type \p type and linkage \p linkage.
/// The body region is created empty, so the result is a declaration until a
/// body is added. The CIR global visibility is default-constructed.
void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, FuncType type,
                        GlobalLinkageKind linkage) {
  mlir::MLIRContext *ctx = builder.getContext();

  result.addRegion();
  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getLinkageAttrNameString(),
                      GlobalLinkageKindAttr::get(ctx, linkage));
  result.addAttribute(getGlobalVisibilityAttrName(result.name),
                      cir::VisibilityAttr::get(ctx));
}

/// Parse the custom `cir.func` syntax:
///   [comdat] [linkage] [private|public|nested] [cir-visibility] [dso_local]
///   @name(signature) [alias(@aliasee)] [body]
/// This is the inverse of FuncOp::print. A function alias must not have a
/// body, and a parsed body must be non-empty.
ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
  llvm::SMLoc loc = parser.getCurrentLocation();
  mlir::Builder &builder = parser.getBuilder();

  mlir::StringAttr comdatNameAttr = getComdatAttrName(state.name);
  mlir::StringAttr visNameAttr = getSymVisibilityAttrName(state.name);
  mlir::StringAttr visibilityNameAttr = getGlobalVisibilityAttrName(state.name);
  mlir::StringAttr dsoLocalNameAttr = getDsoLocalAttrName(state.name);

  // The printer emits `comdat` first, ahead of the linkage keyword; accept it
  // here so that printed functions round-trip.
  if (parser.parseOptionalKeyword("comdat").succeeded())
    state.addAttribute(comdatNameAttr, builder.getUnitAttr());

  // Default to external linkage if no keyword is provided.
  state.addAttribute(getLinkageAttrNameString(),
                     GlobalLinkageKindAttr::get(
                         parser.getContext(),
                         parseOptionalCIRKeyword<GlobalLinkageKind>(
                             parser, GlobalLinkageKind::ExternalLinkage)));

  ::llvm::StringRef visAttrStr;
  if (parser.parseOptionalKeyword(&visAttrStr, {"private", "public", "nested"})
          .succeeded()) {
    state.addAttribute(visNameAttr, builder.getStringAttr(visAttrStr));
  }

  cir::VisibilityAttr cirVisibilityAttr;
  parseVisibilityAttr(parser, cirVisibilityAttr);
  state.addAttribute(visibilityNameAttr, cirVisibilityAttr);

  if (parser.parseOptionalKeyword(dsoLocalNameAttr).succeeded())
    state.addAttribute(dsoLocalNameAttr, builder.getUnitAttr());

  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             state.attributes))
    return failure();
  llvm::SmallVector<OpAsmParser::Argument, 8> arguments;
  llvm::SmallVector<mlir::Type> resultTypes;
  llvm::SmallVector<DictionaryAttr> resultAttrs;
  bool isVariadic = false;
  if (function_interface_impl::parseFunctionSignatureWithArguments(
          parser, /*allowVariadic=*/true, arguments, isVariadic, resultTypes,
          resultAttrs))
    return failure();
  llvm::SmallVector<mlir::Type> argTypes;
  for (OpAsmParser::Argument &arg : arguments)
    argTypes.push_back(arg.type);

  if (resultTypes.size() > 1) {
    return parser.emitError(
        loc, "functions with multiple return types are not supported");
  }

  // No explicit result type means the function returns !cir.void.
  mlir::Type returnType =
      (resultTypes.empty() ? cir::VoidType::get(builder.getContext())
                           : resultTypes.front());

  cir::FuncType fnType = cir::FuncType::get(argTypes, returnType, isVariadic);
  if (!fnType)
    return failure();
  state.addAttribute(getFunctionTypeAttrName(state.name),
                     TypeAttr::get(fnType));

  bool hasAlias = false;
  mlir::StringAttr aliaseeNameAttr = getAliaseeAttrName(state.name);
  if (parser.parseOptionalKeyword("alias").succeeded()) {
    // parseSymbolName (unlike parseOptionalSymbolName) emits a diagnostic
    // when the aliasee is not a valid @-identifier.
    mlir::StringAttr aliaseeAttr;
    if (parser.parseLParen().failed() ||
        parser.parseSymbolName(aliaseeAttr).failed() ||
        parser.parseRParen().failed())
      return failure();
    state.addAttribute(aliaseeNameAttr, FlatSymbolRefAttr::get(aliaseeAttr));
    hasAlias = true;
  }

  // Parse the optional function body.
  auto *body = state.addRegion();
  OptionalParseResult parseResult = parser.parseOptionalRegion(
      *body, arguments, /*enableNameShadowing=*/false);
  if (parseResult.has_value()) {
    if (hasAlias)
      return parser.emitError(loc, "function alias shall not have a body");
    if (failed(*parseResult))
      return failure();
    // Function body was parsed, make sure it's not empty.
    if (body->empty())
      return parser.emitError(loc, "expected non-empty function body");
  }

  return success();
}

// Counterpart of `llvm::GlobalValue::isDeclaration`; keep the two
// implementations in step. IFuncs and materializable functions are not
// supported yet, and should be handled here once they are implemented.
bool cir::FuncOp::isDeclaration() {
  assert(!cir::MissingFeatures::supportIFuncAttr());

  // An alias is always a definition, regardless of its body region.
  if (getAliasee())
    return false;

  // Otherwise, a function without a body is a declaration.
  return getFunctionBody().empty();
}

/// Return the function's body region, which may be empty for declarations.
mlir::Region *cir::FuncOp::getCallableRegion() {
  // TODO(CIR): This function will have special handling for aliases and a
  // check for an external function, once those features have been upstreamed.
  mlir::Region &body = getBody();
  return &body;
}

/// Print the custom `cir.func` form:
///   [comdat] [linkage] [sym-visibility] [cir-visibility] [dso_local]
///   @name(signature) [alias(@aliasee)] [body]
/// Linkage, symbol visibility and CIR visibility are printed only when they
/// differ from external linkage, public visibility and the default CIR
/// visibility, respectively.
void cir::FuncOp::print(OpAsmPrinter &p) {
  if (getComdat())
    p << " comdat";

  const GlobalLinkageKind linkage = getLinkage();
  if (linkage != GlobalLinkageKind::ExternalLinkage)
    p << ' ' << stringifyGlobalLinkageKind(linkage);

  if (mlir::SymbolTable::Visibility symVis = getVisibility();
      symVis != mlir::SymbolTable::Visibility::Public)
    p << ' ' << symVis;

  if (cir::VisibilityAttr globalVis = getGlobalVisibilityAttr();
      !globalVis.isDefault()) {
    p << ' ';
    printVisibilityAttr(p, globalVis);
  }

  if (getDsoLocal())
    p << " dso_local";

  p << ' ';
  p.printSymbolName(getSymName());
  cir::FuncType signature = getFunctionType();
  function_interface_impl::printFunctionSignature(
      p, *this, signature.getInputs(), signature.isVarArg(),
      signature.getReturnTypes());

  if (std::optional<StringRef> aliasee = getAliasee()) {
    p << " alias(";
    p.printSymbolName(*aliasee);
    p << ")";
  }

  // A declaration (empty body region) prints nothing further.
  Region &body = getOperation()->getRegion(0);
  if (body.empty())
    return;

  p << ' ';
  p.printRegion(body, /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/true);
}

// TODO(CIR): Function properties that need verification have not been
// implemented yet.
mlir::LogicalResult cir::FuncOp::verify() {
  // Nothing is checked yet; every FuncOp is accepted.
  return success();
}

//===----------------------------------------------------------------------===//
// BinOp
//===----------------------------------------------------------------------===//
/// Check the overflow-related flags of a binary op:
///   - nsw/nuw require an integer result type and an add/sub/mul opcode;
///   - the saturated flag requires an add/sub opcode;
///   - nsw/nuw and saturated cannot be combined.
LogicalResult cir::BinOp::verify() {
  const cir::BinOpKind kind = getKind();
  const bool isAdd = kind == cir::BinOpKind::Add;
  const bool isSub = kind == cir::BinOpKind::Sub;
  const bool isMul = kind == cir::BinOpKind::Mul;

  const bool noWrap = getNoUnsignedWrap() || getNoSignedWrap();
  const bool saturated = getSaturated();

  if (noWrap && !isa<cir::IntType>(getType()))
    return emitError()
           << "only operations on integer values may have nsw/nuw flags";

  if (noWrap && !(isAdd || isSub || isMul))
    return emitError() << "The nsw/nuw flags are applicable to opcodes: 'add', "
                          "'sub' and 'mul'";

  if (saturated && !(isAdd || isSub))
    return emitError() << "The saturated flag is applicable to opcodes: 'add' "
                          "and 'sub'";

  if (noWrap && saturated)
    return emitError() << "The nsw/nuw flags and the saturated flag are "
                          "mutually exclusive";

  assert(!cir::MissingFeatures::complexType());
  // TODO(cir): verify for complex binops

  return mlir::success();
}

//===----------------------------------------------------------------------===//
// TernaryOp
//===----------------------------------------------------------------------===//

/// Report the regions control may flow to from \p point.
///
/// From the parent op, both the true and the false region are candidates; the
/// condition is not inspected. From either region, control returns to the
/// parent op, with the yielded value becoming the op's result (if any).
void cir::TernaryOp::getSuccessorRegions(
    mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  if (!point.isParent()) {
    regions.push_back(RegionSuccessor(getODSResults(0)));
    return;
  }

  for (Region *candidate : {&getTrueRegion(), &getFalseRegion()})
    regions.push_back(RegionSuccessor(candidate));
}

/// Build a `cir.ternary` on \p cond, populating the true and false regions
/// with \p trueBuilder and \p falseBuilder respectively.
///
/// The op's result type (if any) is taken from the `cir.yield` that
/// terminates the true region's entry block: zero yielded values means no
/// result, and one yielded value means a single result of its type.
void cir::TernaryOp::build(
    OpBuilder &builder, OperationState &result, Value cond,
    function_ref<void(OpBuilder &, Location)> trueBuilder,
    function_ref<void(OpBuilder &, Location)> falseBuilder) {
  result.addOperands(cond);
  OpBuilder::InsertionGuard guard(builder);
  Region *trueRegion = result.addRegion();
  Block *block = builder.createBlock(trueRegion);
  trueBuilder(builder, result.location);
  Region *falseRegion = result.addRegion();
  builder.createBlock(falseRegion);
  falseBuilder(builder, result.location);

  auto yield = dyn_cast<YieldOp>(block->getTerminator());
  assert((yield && yield.getNumOperands() <= 1) &&
         "expected zero or one result type");
  // Re-check `yield` so builds with assertions disabled do not dereference a
  // null op when the terminator is not a cir.yield.
  if (yield && yield.getNumOperands() == 1)
    result.addTypes(TypeRange{yield.getOperandTypes().front()});
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

/// Fold a `cir.select`:
///   - a constant `#cir.bool` condition picks the corresponding operand;
///   - identical SSA operands fold to that operand;
///   - operands that are both the same non-null constant fold to it.
OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
  // cir.select if #true then x else y -> x   (and #false -> y)
  if (auto condition =
          mlir::dyn_cast_if_present<cir::BoolAttr>(adaptor.getCondition()))
    return condition.getValue() ? getTrueValue() : getFalseValue();

  // cir.select if %0 then x else x -> x
  // Compare SSA values first: when neither operand is constant, both adaptor
  // attributes are null and would spuriously compare equal.
  if (getTrueValue() == getFalseValue())
    return getTrueValue();

  mlir::Attribute trueValue = adaptor.getTrueValue();
  if (trueValue && trueValue == adaptor.getFalseValue())
    return trueValue;

  return {};
}

//===----------------------------------------------------------------------===//
// ShiftOp
//===----------------------------------------------------------------------===//
/// A shift is either scalar-by-scalar or vector-by-vector. For the vector
/// form, both operands must have the same length and element width, and the
/// result must be a vector whose element width matches the operands'.
LogicalResult cir::ShiftOp::verify() {
  mlir::Operation *op = getOperation();
  auto lhsVecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(0).getType());
  auto rhsVecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(1).getType());
  if (!lhsVecTy != !rhsVecTy)
    return emitOpError() << "input types cannot be one vector and one scalar";

  // Scalar shifts need no further checks here.
  if (!lhsVecTy)
    return mlir::success();

  if (lhsVecTy.getSize() != rhsVecTy.getSize())
    return emitOpError() << "input vector types must have the same size";

  auto resultVecTy = mlir::dyn_cast<cir::VectorType>(getType());
  if (!resultVecTy)
    return emitOpError() << "the type of the result must be a vector "
                         << "if it is vector shift";

  const unsigned lhsWidth =
      mlir::cast<cir::IntType>(lhsVecTy.getElementType()).getWidth();
  const unsigned rhsWidth =
      mlir::cast<cir::IntType>(rhsVecTy.getElementType()).getWidth();
  if (lhsWidth != rhsWidth)
    return emitOpError()
           << "vector operands do not have the same elements sizes";

  const unsigned resultWidth =
      mlir::cast<cir::IntType>(resultVecTy.getElementType()).getWidth();
  if (lhsWidth != resultWidth)
    return emitOpError() << "vector operands and result type do not have the "
                            "same elements sizes";

  return mlir::success();
}

//===----------------------------------------------------------------------===//
// UnaryOp
//===----------------------------------------------------------------------===//

/// Every unary kind is currently accepted without further checks. The switch
/// deliberately has no `default`, so adding a new UnaryOpKind prompts a
/// compiler warning here (assuming -Wswitch is enabled).
LogicalResult cir::UnaryOp::verify() {
  const cir::UnaryOpKind kind = getKind();
  switch (kind) {
  case cir::UnaryOpKind::Inc:
  case cir::UnaryOpKind::Dec:
  case cir::UnaryOpKind::Plus:
  case cir::UnaryOpKind::Minus:
  case cir::UnaryOpKind::Not:
    return success();
  }

  llvm_unreachable("Unknown UnaryOp kind?");
}

/// Return true if \p op is a `not` whose operand has type !cir.bool.
static bool isBoolNot(cir::UnaryOp op) {
  if (op.getKind() != cir::UnaryOpKind::Not)
    return false;
  return isa<cir::BoolType>(op.getInput().getType());
}

// Fold a boolean `not` applied directly to another boolean `not`. Given
//
// ```mlir
// %1 = cir.unary(not, %0) : !cir.bool, !cir.bool
// %2 = cir.unary(not, %1) : !cir.bool, !cir.bool
// ```
//
// %2 folds to %0, the operand of the inner `not`.
OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) {
  if (!isBoolNot(*this))
    return {};

  auto inner = dyn_cast_or_null<UnaryOp>(getInput().getDefiningOp());
  if (!inner || !isBoolNot(inner))
    return {};

  return inner.getInput();
}

//===----------------------------------------------------------------------===//
// GetMemberOp Definitions
//===----------------------------------------------------------------------===//

/// The address operand must point to a record. The member index must be in
/// range, and the result's pointee type must equal that member's type.
LogicalResult cir::GetMemberOp::verify() {
  auto recordTy = dyn_cast<RecordType>(getAddrTy().getPointee());
  if (!recordTy)
    return emitError() << "expected pointer to a record type";

  auto members = recordTy.getMembers();
  const auto memberIdx = getIndex();
  if (memberIdx >= members.size())
    return emitError() << "member index out of bounds";

  if (members[memberIdx] != getType().getPointee())
    return emitError() << "member type mismatch";

  return mlir::success();
}

//===----------------------------------------------------------------------===//
// VecCreateOp
//===----------------------------------------------------------------------===//

/// Fold a `cir.vec.create` whose elements are all produced by `cir.const`
/// into a `#cir.const_vector` built from the elements' constant attributes.
OpFoldResult cir::VecCreateOp::fold(FoldAdaptor adaptor) {
  // Block arguments have no defining op; isa_and_present treats that as
  // "not a constant" instead of asserting inside isa<>.
  if (!llvm::all_of(getElements(), [](mlir::Value value) {
        return llvm::isa_and_present<cir::ConstantOp>(value.getDefiningOp());
      }))
    return {};

  return cir::ConstVectorAttr::get(
      getType(), mlir::ArrayAttr::get(getContext(), adaptor.getElements()));
}

/// The operand count must equal the vector type's element count, and every
/// operand's type must equal the vector's element type. The first mismatching
/// operand is reported.
LogicalResult cir::VecCreateOp::verify() {
  const cir::VectorType vecTy = getType();
  auto elements = getElements();
  if (elements.size() != vecTy.getSize())
    return emitOpError() << "operand count of " << elements.size()
                         << " doesn't match vector type " << vecTy
                         << " element count of " << vecTy.getSize();

  const mlir::Type elementType = vecTy.getElementType();
  auto mismatch = llvm::find_if(elements, [&](mlir::Value element) {
    return element.getType() != elementType;
  });
  if (mismatch != elements.end())
    return emitOpError() << "operand type " << (*mismatch).getType()
                         << " doesn't match vector element type "
                         << elementType;

  return success();
}

//===----------------------------------------------------------------------===//
// VecExtractOp
//===----------------------------------------------------------------------===//

/// Folds `cir.vec.extract` of a constant vector at a constant index into the
/// selected lane's attribute.
///
/// Returns null (no fold) when either operand is not constant or when the
/// index is outside the vector.
OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
  auto constVec =
      llvm::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec());
  auto constIdx = llvm::dyn_cast_if_present<cir::IntAttr>(adaptor.getIndex());
  if (!constVec || !constIdx)
    return {};

  mlir::ArrayAttr lanes = constVec.getElts();
  const uint64_t lane = constIdx.getUInt();
  if (lane < lanes.size())
    return lanes[lane];

  // Out-of-range index: leave the op in place.
  return {};
}

//===----------------------------------------------------------------------===//
// VecCmpOp
//===----------------------------------------------------------------------===//

OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
  auto lhsVecAttr =
      mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getLhs());
  auto rhsVecAttr =
      mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getRhs());
  if (!lhsVecAttr || !rhsVecAttr)
    return {};

  mlir::Type inputElemTy =
      mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType();
  if (!isAnyIntegerOrFloatingPointType(inputElemTy))
    return {};

  cir::CmpOpKind opKind = adaptor.getKind();
  mlir::ArrayAttr lhsVecElhs = lhsVecAttr.getElts();
  mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts();
  uint64_t vecSize = lhsVecElhs.size();

  SmallVector<mlir::Attribute, 16> elements(vecSize);
  bool isIntAttr = vecSize && mlir::isa<cir::IntAttr>(lhsVecElhs[0]);
  for (uint64_t i = 0; i < vecSize; i++) {
    mlir::Attribute lhsAttr = lhsVecElhs[i];
    mlir::Attribute rhsAttr = rhsVecElhs[i];
    int cmpResult = 0;
    switch (opKind) {
    case cir::CmpOpKind::lt: {
      if (isIntAttr) {
        cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <
                    mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
      } else {
        cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <
                    mlir::cast<cir::FPAttr>(rhsAttr).getValue();
      }
      break;
    }
    case cir::CmpOpKind::le: {
      if (isIntAttr) {
        cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <=
                    mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
      } else {
        cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <=
                    mlir::cast<cir::FPAttr>(rhsAttr).getValue();
      }
      break;
    }
    case cir::CmpOpKind::gt: {
      if (isIntAttr) {
        cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >
                    mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
      } else {
        cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >
                    mlir::cast<cir::FPAttr>(rhsAttr).getValue();
      }
      break;
    }
    case cir::CmpOpKind::ge: {
      if (isIntAttr) {
        cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >=
                    mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
      } else {
        cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >=
                    mlir::cast<cir::FPAttr>(rhsAttr).getValue();
      }
      break;
    }
    case cir::CmpOpKind::eq: {
      if (isIntAttr) {
        cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() ==
                    mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
      } else {
        cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() ==
                    mlir::cast<cir::FPAttr>(rhsAttr).getValue();
      }
      break;
    }
    case cir::CmpOpKind::ne: {
      if (isIntAttr) {
        cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() !=
                    mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
      } else {
        cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() !=
                    mlir::cast<cir::FPAttr>(rhsAttr).getValue();
      }
      break;
    }
    }

    elements[i] = cir::IntAttr::get(getType().getElementType(), cmpResult);
  }

  return cir::ConstVectorAttr::get(
      getType(), mlir::ArrayAttr::get(getContext(), elements));
}

//===----------------------------------------------------------------------===//
// VecShuffleOp
//===----------------------------------------------------------------------===//

/// Folds `cir.vec.shuffle` of two constant vectors into a constant vector.
///
/// Each index selects a lane from the concatenation of the two inputs
/// (`vec1` lanes first, then `vec2` lanes). An index of -1 yields a
/// `#cir.undef` lane of `vec1`'s element type.
OpFoldResult cir::VecShuffleOp::fold(FoldAdaptor adaptor) {
  auto lhsConst =
      mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec1());
  auto rhsConst =
      mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec2());
  if (!lhsConst || !rhsConst)
    return {};

  mlir::Type laneTy =
      mlir::cast<cir::VectorType>(lhsConst.getType()).getElementType();
  mlir::ArrayAttr lhsLanes = lhsConst.getElts();
  mlir::ArrayAttr rhsLanes = rhsConst.getElts();
  const uint64_t lhsLaneCount = lhsLanes.size();

  mlir::ArrayAttr mask = adaptor.getIndices();
  SmallVector<mlir::Attribute, 16> shuffled;
  shuffled.reserve(mask.size());

  for (cir::IntAttr maskElt : mask.getAsRange<cir::IntAttr>()) {
    if (maskElt.getSInt() == -1) {
      shuffled.push_back(cir::UndefAttr::get(laneTy));
    } else {
      // Indices past the first vector address lanes of the second one.
      const uint64_t pos = maskElt.getUInt();
      if (pos < lhsLaneCount)
        shuffled.push_back(lhsLanes[pos]);
      else
        shuffled.push_back(rhsLanes[pos - lhsLaneCount]);
    }
  }

  return cir::ConstVectorAttr::get(
      getType(), mlir::ArrayAttr::get(getContext(), shuffled));
}

/// Verifies `cir.vec.shuffle`.
///
/// - There is one index per result lane.
/// - `vec1` and the result share an element type.
/// - Every index is either -1 or a valid lane position in the concatenation
///   of the two input vectors.
LogicalResult cir::VecShuffleOp::verify() {
  const auto resultTy = getResult().getType();

  if (getIndices().size() != resultTy.getSize())
    return emitOpError() << ": the number of elements in " << getIndices()
                         << " and " << resultTy << " don't match";

  if (getVec1().getType().getElementType() != resultTy.getElementType())
    return emitOpError() << ": element types of " << getVec1().getType()
                         << " and " << resultTy << " don't match";

  const uint64_t maxValidIndex =
      getVec1().getType().getSize() + getVec2().getType().getSize() - 1;
  for (cir::IntAttr idx : getIndices().getAsRange<cir::IntAttr>()) {
    // -1 marks an undefined result lane and is always accepted.
    if (idx.getSInt() == -1)
      continue;
    if (idx.getUInt() > maxValidIndex)
      return emitOpError() << ": index for __builtin_shufflevector must be "
                              "less than the total number of vector elements";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// VecShuffleDynamicOp
//===----------------------------------------------------------------------===//

/// Folds a dynamic `cir.vec.shuffle` of a constant vector by a constant index
/// vector into a constant vector.
///
/// Each index is first masked with `NextPowerOf2(N - 1) - 1`, where N is the
/// input lane count, and the masked value selects the result lane.
///
/// Returns null (no fold) when:
/// - either operand is not a constant vector,
/// - an index lane is not an integer constant (e.g. `#cir.undef`), or
/// - a masked index still falls outside the input vector.
OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) {
  auto vecAttr =
      mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec());
  auto indicesAttr =
      mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getIndices());
  if (!vecAttr || !indicesAttr)
    return {};

  mlir::ArrayAttr vecElts = vecAttr.getElts();
  mlir::ArrayAttr indicesElts = indicesAttr.getElts();

  const uint64_t numElements = vecElts.size();
  // Guards the `numElements - 1` below against unsigned wrap-around.
  if (numElements == 0)
    return {};

  const uint64_t maskBits = llvm::NextPowerOf2(numElements - 1) - 1;

  SmallVector<mlir::Attribute, 16> elements;
  elements.reserve(indicesElts.size());
  for (mlir::Attribute idxElt : indicesElts) {
    auto idxAttr = mlir::dyn_cast<cir::IntAttr>(idxElt);
    if (!idxAttr)
      return {};

    const uint64_t newIdx = idxAttr.getUInt() & maskBits;
    // For a non-power-of-two lane count (e.g. 3) the mask is wider than the
    // vector, so a masked index can still be out of range. That does not name
    // a lane of `vec`; don't fold rather than index past the end.
    if (newIdx >= numElements)
      return {};
    elements.push_back(vecElts[newIdx]);
  }

  return cir::ConstVectorAttr::get(
      getType(), mlir::ArrayAttr::get(getContext(), elements));
}

/// Verifies a dynamic `cir.vec.shuffle`: the index vector must have the same
/// lane count as the input vector.
LogicalResult cir::VecShuffleDynamicOp::verify() {
  const auto inputTy = getVec().getType();
  const auto indicesTy = mlir::cast<cir::VectorType>(getIndices().getType());
  if (inputTy.getSize() == indicesTy.getSize())
    return success();

  return emitOpError() << ": the number of elements in " << inputTy << " and "
                       << indicesTy << " don't match";
}

//===----------------------------------------------------------------------===//
// VecTernaryOp
//===----------------------------------------------------------------------===//

/// Verifies `cir.vec.ternary`: the condition vector must have as many lanes
/// as the value operands.
///
/// The generated (ODS) verification has already checked that all operands
/// are vectors and that the two value operands have the same type.
LogicalResult cir::VecTernaryOp::verify() {
  const auto condTy = getCond().getType();
  const auto valueTy = getLhs().getType();
  if (condTy.getSize() == valueTy.getSize())
    return success();

  return emitOpError() << ": the number of elements in " << condTy << " and "
                       << valueTy << " don't match";
}

/// Folds `cir.vec.ternary` with three constant vector operands.
///
/// For each lane, a nonzero condition selects the `lhs` lane; otherwise the
/// `rhs` lane is selected.
///
/// Returns null (no fold) when:
/// - any operand is not a constant vector,
/// - the lane counts differ, or
/// - a condition lane is not an integer constant (e.g. `#cir.undef`, which
///   the shuffle folder can produce).
OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
  auto condVec =
      mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getCond());
  auto lhsVec =
      mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getLhs());
  auto rhsVec =
      mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getRhs());
  if (!condVec || !lhsVec || !rhsVec)
    return {};

  mlir::ArrayAttr condElts = condVec.getElts();
  mlir::ArrayAttr lhsElts = lhsVec.getElts();
  mlir::ArrayAttr rhsElts = rhsVec.getElts();
  if (lhsElts.size() != condElts.size() || rhsElts.size() != condElts.size())
    return {};

  SmallVector<mlir::Attribute, 16> elements;
  elements.reserve(condElts.size());
  for (size_t i = 0, e = condElts.size(); i != e; ++i) {
    auto condAttr = mlir::dyn_cast<cir::IntAttr>(condElts[i]);
    if (!condAttr)
      return {};
    // `isZero()` works for any bit width, unlike `getSInt()`.
    elements.push_back(condAttr.getValue().isZero() ? rhsElts[i] : lhsElts[i]);
  }

  cir::VectorType vecTy = getLhs().getType();
  return cir::ConstVectorAttr::get(
      vecTy, mlir::ArrayAttr::get(getContext(), elements));
}

//===----------------------------------------------------------------------===//
// ComplexCreateOp
//===----------------------------------------------------------------------===//

/// Verifies `cir.complex.create`: the real operand's type must equal the
/// element type of the resulting complex type.
LogicalResult cir::ComplexCreateOp::verify() {
  if (getReal().getType() == getType().getElementType())
    return success();

  return emitOpError()
         << "operand type of cir.complex.create does not match its result type";
}

/// Folds `cir.complex.create` into a `#cir.const_complex` attribute when both
/// the real and imaginary operands are constants.
OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
  mlir::Attribute realConst = adaptor.getReal();
  mlir::Attribute imagConst = adaptor.getImag();
  if (realConst && imagConst) {
    auto realTyped = mlir::cast<mlir::TypedAttr>(realConst);
    auto imagTyped = mlir::cast<mlir::TypedAttr>(imagConst);
    return cir::ConstComplexAttr::get(realTyped, imagTyped);
  }

  // At least one part is not constant.
  return {};
}

//===----------------------------------------------------------------------===//
// ComplexRealOp
//===----------------------------------------------------------------------===//

/// Verifies `cir.complex.real`: the result type must be the element type of
/// the complex operand.
LogicalResult cir::ComplexRealOp::verify() {
  mlir::Type expectedTy = getOperand().getType().getElementType();
  if (getType() == expectedTy)
    return success();

  return emitOpError() << ": result type does not match operand type";
}

/// Folds `cir.complex.real`.
///
/// - `real(complex.create(r, i))` folds to `r`.
/// - The real part of a constant `#cir.const_complex` folds to that constant
///   part.
/// - Any other constant or non-constant operand is left unfolded.
OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
  if (auto complexCreateOp =
          getOperand().getDefiningOp<cir::ComplexCreateOp>())
    return complexCreateOp.getReal();

  // A constant operand is not guaranteed to be a `#cir.const_complex`;
  // `cast_if_present` would assert on any other attribute kind.
  auto complex =
      mlir::dyn_cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
  return complex ? complex.getReal() : nullptr;
}

//===----------------------------------------------------------------------===//
// ComplexImagOp
//===----------------------------------------------------------------------===//

/// Verifies `cir.complex.imag`: the result type must be the element type of
/// the complex operand.
LogicalResult cir::ComplexImagOp::verify() {
  mlir::Type expectedTy = getOperand().getType().getElementType();
  if (getType() == expectedTy)
    return success();

  return emitOpError() << ": result type does not match operand type";
}

/// Folds `cir.complex.imag`.
///
/// - `imag(complex.create(r, i))` folds to `i`.
/// - The imaginary part of a constant `#cir.const_complex` folds to that
///   constant part.
/// - Any other constant or non-constant operand is left unfolded.
OpFoldResult cir::ComplexImagOp::fold(FoldAdaptor adaptor) {
  if (auto complexCreateOp =
          getOperand().getDefiningOp<cir::ComplexCreateOp>())
    return complexCreateOp.getImag();

  // A constant operand is not guaranteed to be a `#cir.const_complex`;
  // `cast_if_present` would assert on any other attribute kind.
  auto complex =
      mlir::dyn_cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
  return complex ? complex.getImag() : nullptr;
}

//===----------------------------------------------------------------------===//
// ComplexRealPtrOp
//===----------------------------------------------------------------------===//

/// Verifies `cir.complex.real_ptr`: given a pointer to a complex value, the
/// result must point at that complex type's element type.
///
/// NOTE(review): the `cast` assumes the operand pointee is already known to
/// be a `cir::ComplexType`, presumably enforced by the ODS operand
/// constraint.
LogicalResult cir::ComplexRealPtrOp::verify() {
  mlir::Type resultPointeeTy = getType().getPointee();
  auto complexTy =
      mlir::cast<cir::ComplexType>(getOperand().getType().getPointee());
  if (resultPointeeTy == complexTy.getElementType())
    return success();

  return emitOpError() << ": result type does not match operand type";
}

//===----------------------------------------------------------------------===//
// ComplexImagPtrOp
//===----------------------------------------------------------------------===//

/// Verifies `cir.complex.imag_ptr`: given a pointer to a complex value, the
/// result must point at that complex type's element type.
///
/// NOTE(review): the `cast` assumes the operand pointee is already known to
/// be a `cir::ComplexType`, presumably enforced by the ODS operand
/// constraint.
LogicalResult cir::ComplexImagPtrOp::verify() {
  mlir::Type resultPointeeTy = getType().getPointee();
  auto complexTy =
      mlir::cast<cir::ComplexType>(getOperand().getType().getPointee());
  if (resultPointeeTy == complexTy.getElementType())
    return success();

  return emitOpError()
         << "cir.complex.imag_ptr result type does not match operand type";
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
