//===- MIGraphXToTosa.cpp - Lowering MIGraphX to Tosa Dialect
//-------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the MIGraphX dialect to the Tosa dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MIGraphXToTosa/MIGraphXToTosa.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MHAL/IR/MHAL.h"
#include "mlir/Dialect/MIGraphX/IR/MIGraphX.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Rock/IR/Rock.h"
#include "mlir/Dialect/Rock/IR/RockTypes.h"
#include "mlir/Dialect/Rock/utility/builderUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using mlir::migraphx::MIXRShapedType;

#define DEBUG_TYPE "migraphx-to-tosa"

//===----------------------------------------------------------------------===//
// Type conversions
//===----------------------------------------------------------------------===//

/// Sets up the type converter that maps MIGraphX shaped types to ranked
/// tensors and strips explicit signedness from integer types.
///
/// Two conversions are registered, in this order: a generic one for any
/// `Type`, and one for `MIXRShapedType`. Materializations are registered so
/// that single values can be bridged between the shaped and the tensor form.
migraphx::MIXRShapedToTensorConverter::MIXRShapedToTensorConverter() {
  // Maps an integer type that carries explicit signedness to the signless
  // integer type of the same width; any other type is returned unchanged.
  auto dropSignedness = [](Type ty) -> Type {
    if (!ty.isInteger() || ty.isSignlessInteger())
      return ty;
    return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
                            IntegerType::SignednessSemantics::Signless);
  };

  addConversion([dropSignedness](Type type) { return dropSignedness(type); });

  addConversion([dropSignedness](MIXRShapedType shaped) {
    RankedTensorType tensorTy = shaped.asTensor();
    Type origElemTy = tensorTy.getElementType();
    Type signlessElemTy = dropSignedness(origElemTy);
    // Only rebuild the tensor type when the element type actually changed.
    if (signlessElemTy != origElemTy)
      tensorTy = RankedTensorType::get(tensorTy.getShape(), signlessElemTy);
    return tensorTy;
  });

  // Tensor -> shaped: wrap the single converted value in an
  // AsUnderlyingShapeOp producing the requested shaped type.
  addSourceMaterialization([](OpBuilder &builder, MIXRShapedType shapedTy,
                              ValueRange tensors, Location loc) -> Value {
    // 1-1 conversions only.
    if (tensors.size() == 1)
      return migraphx::AsUnderlyingShapeOp::create(builder, loc, shapedTy,
                                                   tensors[0]);
    return Value();
  });

  // Shaped -> tensor: wrap the single shaped value in an AsLogicalShapeOp
  // producing the requested type.
  addTargetMaterialization([](OpBuilder &builder, Type wantedTy,
                              ValueRange shapedValues, Location loc) -> Value {
    // 1-1 conversions only.
    if (shapedValues.size() == 1)
      return migraphx::AsLogicalShapeOp::create(builder, loc, wantedTy,
                                                shapedValues[0]);
    return Value();
  });
}

/// Sets up the converter that turns MIGraphX shaped types into the type
/// returned by `asFlatMemoryTensor()`; every other type is left as-is.
migraphx::MIXRShapedToMemoryLayoutConverter::
    MIXRShapedToMemoryLayoutConverter() {
  // Generic rule: keep the type unchanged.
  addConversion([](Type unchanged) { return unchanged; });
  // Rule for MIGraphX shaped types.
  addConversion([](MIXRShapedType shapedTy) {
    auto flatTy = shapedTy.asFlatMemoryTensor();
    return flatTy;
  });
}

//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//

/// Creates a `TosaOp` whose result shape is computed by the op's own shape
/// inference.
///
/// The op is first built with an unranked tensor result of element type
/// `elemType`; then `InferShapedTypeOpInterface` is queried and the first
/// result's type is replaced by a ranked tensor with the inferred dimensions.
///
/// \param rewriter  Rewriter used to create the op.
/// \param loc       Location attached to the new op.
/// \param elemType  Element type of the result tensor.
/// \param args      Remaining arguments passed to `TosaOp::create`.
/// \returns The created op, with its result type already refined.
///
/// Shape inference is expected to succeed and to produce at least one result
/// shape; this is checked by assertion only.
template <typename TosaOp, typename... Args>
static TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc,
                               Type elemType, Args &&...args) {
  auto op =
      TosaOp::create(rewriter, loc, UnrankedTensorType::get(elemType), args...);
  InferShapedTypeOpInterface shapeInterface =
      cast<InferShapedTypeOpInterface>(op.getOperation());
  SmallVector<ShapedTypeComponents> returnShape;
  LogicalResult shapeInferenceStatus = shapeInterface.inferReturnTypeComponents(
      op.getContext(), op.getLoc(), op->getOperands(), op->getAttrDictionary(),
      op->getPropertiesStorage(), op->getRegions(), returnShape);
  assert(shapeInferenceStatus.succeeded() && "shape inference failed");
  assert(!returnShape.empty() && "shape inference produced no result shape");
  // The status is only read by the assertions above; silence the
  // unused-variable warning in builds where assertions are compiled out.
  (void)shapeInferenceStatus;
  Type newOutTy = RankedTensorType::get({returnShape[0].getDims()}, elemType);
  op->getResult(0).setType(newOutTy);
  return op;
}

/// Casts `input` to the same shape with element type `resElementType`.
///
/// \param inputType  Element type of the input used for the signedness
///                   decision.
/// \param resElementTypeBeforeConvert  Result element type used for the
///                   signedness decision; falls back to `resElementType` when
///                   null.
///
/// When either of those two types is an unsigned integer, a `tosa.custom`
/// op named "unsigned_cast" is emitted; otherwise a (possibly folded)
/// `tosa.cast` is used.
static Value createCastOp(PatternRewriter &rewriter, Location loc,
                          Type resElementType, Value input, Type inputType,
                          Type resElementTypeBeforeConvert = nullptr) {
  Type resultTy =
      cast<ShapedType>(input.getType()).cloneWith({}, resElementType);

  Type origResElemTy = resElementTypeBeforeConvert
                           ? resElementTypeBeforeConvert
                           : resElementType;

  bool involvesUnsigned =
      inputType.isUnsignedInteger() || origResElemTy.isUnsignedInteger();
  if (!involvesUnsigned)
    return rewriter.createOrFold<tosa::CastOp>(loc, resultTy, input);

  // Mixing an unsigned type with an explicitly signed one is not expected.
  assert(!inputType.isSignedInteger() && !origResElemTy.isSignedInteger());
  return tosa::CustomOp::create(rewriter, loc, resultTy, "unsigned_cast",
                                "rocmlir", "", input)
      .getResult(0);
}

/// Returns the element type of `v`, whose type must be a `ShapedType`.
static Type getShapedElementTy(Value v) {
  auto shapedTy = cast<ShapedType>(v.getType());
  return shapedTy.getElementType();
}

/// Creates a `tosa.const` of type `type` holding the zero attribute for that
/// type.
static Value getZeroTensor(Location loc, RankedTensorType type,
                           ConversionPatternRewriter &rewriter) {
  auto zeroAttr = rewriter.getZeroAttr(type);
  return tosa::ConstOp::create(rewriter, loc, type,
                               cast<ElementsAttr>(zeroAttr));
}

/// Creates a zero-filled `tosa.const` for a ranked tensor with the given
/// `shape` and element type `elemType`.
static Value getZeroTensor(Location loc, Type elemType, ArrayRef<int64_t> shape,
                           ConversionPatternRewriter &rewriter) {
  // Build the tensor type here and reuse the type-based overload.
  RankedTensorType zeroTy = RankedTensorType::get(shape, elemType);
  return getZeroTensor(loc, zeroTy, rewriter);
}

/// Creates a `tosa.transpose` of `input` using `permutation`.
///
/// Result dimension `i` takes the extent of input dimension
/// `permutation[i]`; the element type is carried over from the input.
static tosa::TransposeOp getTransposeOp(Location loc, Value input,
                                        ConversionPatternRewriter &rewriter,
                                        ArrayRef<int32_t> permutation) {
  auto srcTy = cast<ShapedType>(input.getType());
  ArrayRef<int64_t> srcShape = srcTy.getShape();

  SmallVector<int64_t> permutedShape(permutation.size());
  for (size_t dstIdx = 0, e = permutation.size(); dstIdx < e; ++dstIdx)
    permutedShape[dstIdx] = srcShape[permutation[dstIdx]];

  Type permutedTy =
      RankedTensorType::get(permutedShape, srcTy.getElementType());
  return tosa::TransposeOp::create(rewriter, loc, permutedTy, input,
                                   permutation);
}

//===----------------------------------------------------------------------===//
// The general one-to-one conversion pattern
//===----------------------------------------------------------------------===//

namespace {
/// Generic pattern that replaces a `MIGraphXOp` with a `TosaOp` taking the
/// converted operands, the converted result types, and the op's discardable
/// attributes.
template <typename MIGraphXOp, typename TosaOp>
struct TrivialConverter final : OpConversionPattern<MIGraphXOp> {
  using OpAdaptor = typename OpConversionPattern<MIGraphXOp>::OpAdaptor;
  using OpConversionPattern<MIGraphXOp>::OpConversionPattern;
  using OpConversionPattern<MIGraphXOp>::getTypeConverter;

  LogicalResult
  matchAndRewrite(MIGraphXOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

/// Replaces `op` with a `TosaOp` built from the converted result types, the
/// adaptor's operands, and a copy of the op's discardable attributes.
/// Fails when the result types cannot be converted.
template <typename MIGraphXOp, typename TosaOp>
LogicalResult TrivialConverter<MIGraphXOp, TosaOp>::matchAndRewrite(
    MIGraphXOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  SmallVector<Type, 1> convertedResultTypes;
  LogicalResult typesConverted = getTypeConverter()->convertTypes(
      op->getResultTypes(), convertedResultTypes);
  if (failed(typesConverted))
    return failure();

  // All discardable attributes are carried over as-is.
  SmallVector<NamedAttribute> carriedAttrs =
      llvm::to_vector(op->getDiscardableAttrDictionary());
  rewriter.replaceOpWithNewOp<TosaOp>(op, convertedResultTypes,
                                      adaptor.getOperands(), carriedAttrs);
  return success();
}

//===----------------------------------------------------------------------===//
// Base kernels (convolution, gemm)
//===----------------------------------------------------------------------===//

namespace {
/// Pattern lowering a MIGraphX convolution-like op to a TOSA convolution.
///
/// The same lowering serves both migraphx.convolution and
/// migraphx.quant_convolution: the two differ only in that
/// quant_convolution lets the input and output element types differ, which
/// is captured by their tablegen definitions rather than by this pattern.
template <typename ConvType>
struct ConvConverter final : OpConversionPattern<ConvType> {
  using OpAdaptor = typename OpConversionPattern<ConvType>::OpAdaptor;
  using OpConversionPattern<ConvType>::OpConversionPattern;
  using OpConversionPattern<ConvType>::getTypeConverter;

  LogicalResult
  matchAndRewrite(ConvType op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern lowering a MIGraphX dot-like op to tosa.matmul.
template <typename DotType>
struct DotConverter final : OpConversionPattern<DotType> {
  using OpAdaptor = typename OpConversionPattern<DotType>::OpAdaptor;
  using OpConversionPattern<DotType>::OpConversionPattern;
  using OpConversionPattern<DotType>::getTypeConverter;

  LogicalResult
  matchAndRewrite(DotType op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

/// Lowers a MIGraphX convolution-like op (forward or backward-data) to the
/// matching TOSA convolution op.
///
/// Input and filter are transposed with the `toChannelLast` permutation
/// before the TOSA op is built, and the result is transposed back with
/// `fromChannelLast`. A 1-D convolution is expressed as a 2-D one by
/// inserting a unit dimension in front of the last dimension of the
/// operands, and the result is reshaped back afterwards. A zero-filled bias
/// and zero-valued zero points are always supplied.
///
/// Emits an error for unsigned output element types, for a number of spatial
/// dimensions outside 1-3, for 3-D backward-data convolutions, and for 1-D
/// convolutions whose dilation/stride/pad sizes are not 1/1/2.
template <typename ConvType>
LogicalResult ConvConverter<ConvType>::matchAndRewrite(
    ConvType op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();
  Value input = adaptor.getInput();
  auto inputType = cast<ShapedType>(input.getType());
  Value filter = adaptor.getFilter();
  ValueRange results = op->getResults();
  Type elementTy = inputType.getElementType();
  auto outputTy = cast<MIXRShapedType>(results[0].getType());
  Type outElementTy = outputTy.getElementType();
  Type newOutElementTy = getTypeConverter()->convertType(outElementTy);
  bool isBwdDataConvOp = isa<migraphx::ConvolutionBwdDataOp>(op);

  if (outElementTy.isUnsignedInteger())
    return op.emitError("No support for unsigned convolution.\n");

  // Number of spatial dimensions (output rank minus the two leading dims).
  int dims = outputTy.getShape().size() - 2;
  SmallVector<int32_t> toChannelLast{0};
  SmallVector<int32_t> fromChannelLast{0, dims + 1};
  for (int i = 0; i < dims; i++) {
    toChannelLast.push_back(i + 2);
    fromChannelLast.push_back(i + 1);
  }
  toChannelLast.push_back(1);

  // insert transpose to input and filter tensors
  input = getTransposeOp(loc, input, rewriter, toChannelLast);
  filter = getTransposeOp(loc, filter, rewriter, toChannelLast);
  ArrayRef<int64_t> outShape = outputTy.getShape();

  // original output shape was NCHW, change it into NHWC
  SmallVector<int64_t> newShape{outShape[0]};
  for (int i = 0; i < dims; i++)
    newShape.push_back(outShape[i + 2]);
  newShape.push_back(outShape[1]);
  Type newOutTy = RankedTensorType::get(newShape, newOutElementTy);

  // There is no tosa.conv1d or tosa.transpose_conv1d, so instead we'll add a
  // dummy x1 dimension to the input tensors, and make a tosa.conv2d.
  auto expandTo2D = [&rewriter, loc](mlir::Value value) {
    ArrayRef<int64_t> origShape = cast<ShapedType>(value.getType()).getShape();
    SmallVector<int64_t> expShape(origShape.drop_back());
    expShape.push_back(1);
    expShape.push_back(origShape.back());
    auto expShapeValue = tosa::getTosaConstShape(rewriter, loc, expShape);
    Value reshaped =
        tosa::ReshapeOp::create(rewriter, loc, value, expShapeValue);
    return reshaped;
  };

  // Construct a new Conv2DOp/TransposeConv2DOp.
  Operation *cop;
  Type new1DOutTy;
  Value inputZp, weightZp;
  switch (dims) {
  case 1:
    // Expand to do a conv2d/transpose_conv2d, because there's no 1d version of
    // the ops.
    newShape.insert(std::prev(newShape.end()), 1);
    new1DOutTy = RankedTensorType::get(newShape, newOutElementTy);
    input = expandTo2D(input);
    filter = expandTo2D(filter);
    inputZp =
        tosa::createZeroPointTensor(rewriter, loc, input.getType(), 0).value();
    weightZp =
        tosa::createZeroPointTensor(rewriter, loc, filter.getType(), 0).value();

    if (isBwdDataConvOp) {
      cop = tosa::TransposeConv2DOp::create(
          rewriter, loc, new1DOutTy,
          ValueRange{input, filter,
                     getZeroTensor(loc, newOutElementTy,
                                   cast<ShapedType>(new1DOutTy).getShape()[3],
                                   rewriter),
                     inputZp, weightZp});
    } else {
      cop = tosa::Conv2DOp::create(
          rewriter, loc, new1DOutTy,
          ValueRange{
              input, filter,
              getZeroTensor(loc, newOutElementTy,
                            cast<ShapedType>(filter.getType()).getShape()[0],
                            rewriter),
              inputZp, weightZp});
    }
    break;

  case 2:
    inputZp =
        tosa::createZeroPointTensor(rewriter, loc, input.getType(), 0).value();
    weightZp =
        tosa::createZeroPointTensor(rewriter, loc, filter.getType(), 0).value();
    if (isBwdDataConvOp) {
      cop = tosa::TransposeConv2DOp::create(
          rewriter, loc, newOutTy,
          ValueRange{input, filter,
                     getZeroTensor(loc, newOutElementTy,
                                   cast<ShapedType>(newOutTy).getShape()[3],
                                   rewriter),
                     inputZp, weightZp});
    } else {
      cop = tosa::Conv2DOp::create(
          rewriter, loc, newOutTy,
          ValueRange{
              input, filter,
              getZeroTensor(loc, newOutElementTy,
                            cast<ShapedType>(filter.getType()).getShape()[0],
                            rewriter),
              inputZp, weightZp});
    }
    break;
  case 3:
    if (isBwdDataConvOp)
      return op->emitError("Only 1-D and 2-D backwards convolution ops are "
                           "supported");

    inputZp =
        tosa::createZeroPointTensor(rewriter, loc, input.getType(), 0).value();
    weightZp =
        tosa::createZeroPointTensor(rewriter, loc, filter.getType(), 0).value();
    cop = tosa::Conv3DOp::create(
        rewriter, loc, newOutTy,
        ValueRange{
            input, filter,
            getZeroTensor(loc, newOutElementTy,
                          cast<ShapedType>(filter.getType()).getShape()[0],
                          rewriter),
            inputZp, weightZp});
    break;
  default:
    return op->emitError("Only 1-D, 2-D, and 3-D have been implemented.");
  }

  // translate attributes
  auto padAttr = cast<ArrayAttr>(op->getAttr("padding"));
  auto strideAttr = cast<ArrayAttr>(op->getAttr("stride"));
  auto dilationAttr = cast<ArrayAttr>(op->getAttr("dilation"));
  // MIGraphX padAttr is [hlow, wlow, hhigh, whigh] while TOSA padAttr
  // is [hlow, hhigh, wlow, whigh].
  // The array elements are required to be IntegerAttrs, hence cast<> (which
  // asserts on a mismatch) instead of an unchecked dyn_cast<>.
  SmallVector<int64_t> pads;
  for (int i = 0; i < dims; i++) {
    pads.push_back(cast<IntegerAttr>(padAttr[i]).getInt());
    pads.push_back(cast<IntegerAttr>(padAttr[i + dims]).getInt());
  }

  SmallVector<int64_t> strides;
  SmallVector<int64_t> dilations;
  for (size_t i = 0; i < strideAttr.size(); i++) {
    strides.push_back(cast<IntegerAttr>(strideAttr[i]).getInt());
    dilations.push_back(cast<IntegerAttr>(dilationAttr[i]).getInt());
  }

  // Determine the accumulation type based on the input element type.
  Type accType;
  if (isa<FloatType>(elementTy)) {
    accType = rewriter.getF32Type();
    // accType is not used by rocMLIR when converting tosa to rock.
  } else if (isa<IntegerType>(elementTy)) {
    accType = rewriter.getI32Type();
  } else {
    llvm_unreachable("not expected type");
  }
  // convolution config attributes
  if (dims == 1) {
    if ((dilations.size() != 1) || (strides.size() != 1) ||
        (pads.size() != 2)) {
      return op->emitError(
          "1-D convolution has improper dilation, stride, or pad.");
    }
    // Neutral values for the inserted unit dimension.
    dilations.push_back(1);
    strides.push_back(1);
    pads.push_back(0);
    pads.push_back(0);
  }

  // Set attributes common to both forwards and backwards conv
  cop->setAttr("dilation", rewriter.getDenseI64ArrayAttr(dilations));
  cop->setAttr("stride", rewriter.getDenseI64ArrayAttr(strides));
  cop->setAttr("acc_type", TypeAttr::get(accType));
  int64_t group = op.getGroup();
  cop->setAttr("group", rewriter.getI64IntegerAttr(group));

  // Set padding for forwards and backwards convolution. Note: the padding here
  // applies to input padding. TransposeConv2D will still require an output pad
  // attribute, so we can just set that to zeros
  if (isBwdDataConvOp) {
    SmallVector<int64_t> zeroPads(pads.size(), 0);
    cop->setAttr("out_pad", rewriter.getDenseI64ArrayAttr(zeroPads));
  }
  cop->setAttr("pad", rewriter.getDenseI64ArrayAttr(pads));

  // For both types of backwards convolution, we will be using
  // tosa.transpose_conv2d, so we are going to add a conv_kind attribute so
  // that we can distinguish between the two types in TosaToRock.
  // TODO: We will need to add conv_kind = "bwd_weight" when we eventually
  // add support for bwd_weight ops in MIGraphX.
  if (isBwdDataConvOp)
    cop->setAttr("conv_kind", rewriter.getStringAttr("bwd_data"));

  // Convert optional attributes
  if (auto attr = (*op).template getAttrOfType<StringAttr>("perf_config"))
    cop->setAttr("perf_config", attr);

  if (dims == 1) {
    // Drop the unit dimension that was inserted for the 2-D op.
    auto shapeValue = tosa::getTosaConstShape(
        rewriter, loc, cast<ShapedType>(newOutTy).getShape());
    cop = tosa::ReshapeOp::create(rewriter, loc, cop->getResult(0), shapeValue);
  }

  // transpose the output back to NCHW so that it can match following
  // operators.
  auto top = getTransposeOp(loc, cop->getResult(0), rewriter, fromChannelLast);
  rewriter.replaceOp(op, top);
  return success();
}

/// Lowers a MIGraphX dot-like op to tosa.matmul.
///
/// tosa.matmul takes rank-3 operands with one batch dimension. When the
/// operands are not already in that form (result rank != 3, operand ranks
/// differ, or leading dimensions differ), the leading dimensions of each
/// operand are collapsed into one batch dimension with tosa.reshape, and the
/// matmul result is reshaped back to the original output shape.
///
/// If the collapsed batch sizes differ, only the case where B's batch size
/// is 1 is supported: A's batch is folded into its row dimension.
///
/// Emits an error for unsigned output element types, for operands or results
/// of rank < 2, and for unsupported batch broadcasting.
template <typename DotType>
LogicalResult DotConverter<DotType>::matchAndRewrite(
    DotType op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();
  auto inA = cast<TypedValue<RankedTensorType>>(adaptor.getInA());
  auto inB = cast<TypedValue<RankedTensorType>>(adaptor.getInB());
  auto results = op->getResults();
  Type elementTy = inA.getType().getElementType();
  auto origOutputTy = cast<MIXRShapedType>(results[0].getType());
  Type outElementTy = origOutputTy.getElementType();
  Type newOutElementTy = getTypeConverter()->convertType(outElementTy);

  if (outElementTy.isUnsignedInteger())
    return op.emitError("No support for unsigned dot product.\n");

  // check batch dimension. Tosa matmul only allow a single dimension for it,
  // add reshape ops to flatten and restore the original dimension.
  ArrayRef<int64_t> origOutDims = origOutputTy.getShape();
  RankedTensorType newOutType =
      RankedTensorType::get(origOutDims, newOutElementTy);
  size_t outRank = origOutDims.size();
  ArrayRef<int64_t> orgDimsA = inA.getType().getShape();
  ArrayRef<int64_t> orgDimsB = inB.getType().getShape();
  size_t rankA = orgDimsA.size();
  size_t rankB = orgDimsB.size();

  // The `rank - 2` computations below are done on unsigned values and would
  // wrap around for smaller ranks.
  if (outRank < 2 || rankA < 2 || rankB < 2)
    return op->emitError("dot operands and result must have rank >= 2.");

  // True when the operands are not already in the rank-3, equal-batch form
  // that tosa.matmul takes directly. rank=2 assumes batch=1.
  const bool needsReshape = outRank != 3 || rankA != rankB ||
                            (outRank == 3 && orgDimsA[0] != orgDimsB[0]);

  if (needsReshape) {
    int64_t batchSizeA = 1, batchSizeB = 1, batchSizeC = 1;
    for (size_t i = 0; i < outRank - 2; i++) {
      batchSizeC *= origOutDims[i];
    }
    for (size_t i = 0; i < rankA - 2; i++) {
      batchSizeA *= orgDimsA[i];
    }
    for (size_t i = 0; i < rankB - 2; i++) {
      batchSizeB *= orgDimsB[i];
    }

    // Each shape is indexed with its own rank: the operand ranks may differ
    // from each other and from the result rank on this path.
    int64_t newDimsA[3] = {batchSizeA, orgDimsA[rankA - 2],
                           orgDimsA[rankA - 1]};
    int64_t newDimsB[3] = {batchSizeB, orgDimsB[rankB - 2],
                           orgDimsB[rankB - 1]};
    int64_t newDimsOut[3] = {batchSizeC, origOutDims[outRank - 2],
                             origOutDims[outRank - 1]};
    if (batchSizeA != batchSizeB) {
      // support when batchB dimension is broadcast
      if (batchSizeB == 1) {
        // modify A from [g, m, k] to [1, g*m, k] (and the output likewise)
        newDimsA[0] = 1;
        newDimsA[1] *= batchSizeA;
        newDimsOut[0] = 1;
        newDimsOut[1] *= batchSizeC;
      } else {
        // currently not supporting the other case, broadcast A could be
        // supported with an additional transpose.
        return op->emitError("tosa.matmul can't broadcast input.");
      }
    }
    RankedTensorType newAType = RankedTensorType::get(newDimsA, elementTy);
    RankedTensorType newBType = RankedTensorType::get(newDimsB, elementTy);
    newOutType = RankedTensorType::get(newDimsOut, newOutElementTy);
    auto newDimsAValue = tosa::getTosaConstShape(rewriter, loc, newDimsA);
    auto reshapeAOp =
        tosa::ReshapeOp::create(rewriter, loc, newAType, inA, newDimsAValue);
    auto newDimsBValue = tosa::getTosaConstShape(rewriter, loc, newDimsB);
    auto reshapeBOp =
        tosa::ReshapeOp::create(rewriter, loc, newBType, inB, newDimsBValue);

    // reassign inputs.
    inA = cast<TypedValue<RankedTensorType>>(reshapeAOp.getResult());
    inB = cast<TypedValue<RankedTensorType>>(reshapeBOp.getResult());
  }
  auto aZp =
      tosa::createZeroPointTensor(rewriter, loc, inA.getType(), 0).value();
  auto bZp =
      tosa::createZeroPointTensor(rewriter, loc, inB.getType(), 0).value();
  // Construct tosa.matmul.
  auto mop =
      tosa::MatMulOp::create(rewriter, loc, newOutType, inA, inB, aZp, bZp);

  // Determine the accumulation type based on the input element type.
  Type accType;
  if (isa<FloatType>(elementTy)) {
    accType = rewriter.getF32Type();
  } else if (isa<IntegerType>(elementTy)) {
    accType = rewriter.getI32Type();
  } else {
    llvm_unreachable("not expected type");
  }
  mop->setAttr("acc_type", TypeAttr::get(accType));

  // Convert optional attributes
  if (auto attr = (*op).template getAttrOfType<StringAttr>("perf_config"))
    mop->setAttr("perf_config", attr);

  if (needsReshape) {
    // Restore the original output shape.
    auto origOutDimsValue = tosa::getTosaConstShape(rewriter, loc, origOutDims);
    auto rop = tosa::ReshapeOp::create(
        rewriter, loc, getTypeConverter()->convertType(origOutputTy), mop,
        origOutDimsValue);
    rewriter.replaceOp(op, rop);
    return success();
  }
  rewriter.replaceOp(op, mop);
  return success();
}

//===----------------------------------------------------------------------===//
// Tensor views and shape manipulation
//===----------------------------------------------------------------------===//
namespace {
/// Lowers migraphx.broadcast to a tosa.reshape followed by a tosa.add with a
/// zero tensor of the output shape.
struct BroadcastConverter final : OpConversionPattern<migraphx::BroadcastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(migraphx::BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};

/// Lowers migraphx.multibroadcast to a splat constant, or to an optional
/// tosa.reshape followed by a tosa.add with a zero tensor.
struct MultiBroadcastConverter final
    : OpConversionPattern<migraphx::MultiBroadcastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(migraphx::MultiBroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};

/// Lowers migraphx.transpose to tosa.transpose.
struct TransposeConverter final : OpConversionPattern<migraphx::TransposeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(migraphx::TransposeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};

/// Lowers migraphx.reshape to tosa.reshape.
struct ReshapeConverter final : OpConversionPattern<migraphx::ReshapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(migraphx::ReshapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};

/// Lowers migraphx.slice to tosa.slice.
struct SliceConverter final : OpConversionPattern<migraphx::SliceOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(migraphx::SliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

/// Lowers migraphx.broadcast.
///
/// The input is reshaped to the output rank: output dimensions starting at
/// `axis` take the input's extents and every other dimension becomes 1. The
/// reshaped value is then added to a zero tensor of the full output shape.
LogicalResult
BroadcastConverter::matchAndRewrite(migraphx::BroadcastOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();
  auto resultTy = op.getOutput().getType();
  ArrayRef<int64_t> inShape = op.getInput().getType().getShape();
  ArrayRef<int64_t> outShape = resultTy.getShape();
  uint32_t outRank = resultTy.getRank();
  Type newOutElementTy =
      getTypeConverter()->convertType(resultTy.getElementType());
  auto axis =
      static_cast<size_t>(cast<IntegerAttr>(op->getAttr("axis")).getInt());

  // Same-rank shape of the input: unit extents outside [axis, axis + inRank).
  SmallVector<int64_t, 5> newShape;
  for (uint32_t outIdx = 0; outIdx < outRank; ++outIdx) {
    bool coveredByInput =
        outIdx >= axis && (outIdx - axis) < inShape.size();
    newShape.push_back(coveredByInput ? inShape[outIdx - axis] : 1);
  }
  auto newShapeValue = tosa::getTosaConstShape(rewriter, loc, newShape);
  tosa::ReshapeOp sameRankReshapedOp = createOpAndInfer<tosa::ReshapeOp>(
      rewriter, loc, newOutElementTy, adaptor.getInput(), newShapeValue);

  // tosa has no explicit broadcast op, so broadcast implicitly by adding the
  // reshaped input to a zero tensor of the output shape.
  auto zeroTensor = getZeroTensor(
      loc, RankedTensorType::get(outShape, newOutElementTy), rewriter);
  auto addWithZero = createOpAndInfer<tosa::AddOp>(
      rewriter, loc, newOutElementTy, zeroTensor, sameRankReshapedOp);

  rewriter.replaceOp(op, addWithZero);
  return success();
}

/// Lowers migraphx.multibroadcast.
///
/// A splat tosa.const input is replaced by a splat constant of the output
/// shape. Otherwise, when the rank grows, the input is first reshaped to the
/// output rank (dimensions whose output stride is 0 become 1), and the value
/// is added to a zero tensor of the output type. Emits an error when the
/// output rank is smaller than the input rank.
LogicalResult MultiBroadcastConverter::matchAndRewrite(
    migraphx::MultiBroadcastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();
  Value source = adaptor.getInput();
  auto inType = cast<RankedTensorType>(source.getType());
  auto outType = cast<RankedTensorType>(
      getTypeConverter()->convertType(op.getOutput().getType()));
  ArrayRef<int64_t> outShape = outType.getShape();
  ArrayRef<int64_t> outStrides = op.getOutput().getType().getStrides();
  uint32_t inRank = inType.getRank();
  uint32_t outRank = outType.getRank();
  Type elemType = outType.getElementType();

  // If its a splat constant, we can broadcast it trivially
  if (auto constOp = source.getDefiningOp<tosa::ConstOp>();
      constOp && constOp.getValuesAttr().isSplat()) {
    auto splatTy = RankedTensorType::get(outShape, elemType);
    Attribute splatValue = cast<DenseElementsAttr>(constOp.getValuesAttr())
                               .getSplatValue<Attribute>();
    auto splatAttr = DenseElementsAttr::get(splatTy, splatValue);
    tosa::ConstOp newConstOp =
        tosa::ConstOp::create(rewriter, loc, splatTy, splatAttr);
    rewriter.replaceOp(op, newConstOp);
    return success();
  }

  if (outRank < inRank)
    return op.emitError("MultiBroadcastOp shouldn't reduce rank.\n");

  Value replacingValue = source;
  if (outRank > inRank) {
    // Same-rank shape: broadcast dimensions (stride 0) get extent 1.
    SmallVector<int64_t, 5> newShape;
    newShape.reserve(outRank);
    for (auto [len, stride] : llvm::zip_equal(outShape, outStrides))
      newShape.push_back(stride == 0 ? 1 : len);

    auto newShapeValue = tosa::getTosaConstShape(rewriter, loc, newShape);
    tosa::ReshapeOp sameRankReshapedOp = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, elemType, source, newShapeValue);
    replacingValue = sameRankReshapedOp.getResult();
  }

  // tosa has no explicit broadcast op, so broadcast implicitly by adding the
  // value to a zero tensor of the output type.
  auto zeroTensor = getZeroTensor(loc, outType, rewriter);
  auto addWithZero = createOpAndInfer<tosa::AddOp>(rewriter, loc, elemType,
                                                   zeroTensor, replacingValue);

  rewriter.replaceOp(op, addWithZero);
  return success();
}

/// Lowers migraphx.transpose to tosa.transpose using the op's permutation
/// attribute, narrowed to 32-bit indices.
LogicalResult
TransposeConverter::matchAndRewrite(migraphx::TransposeOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  ArrayAttr permAttr = op.getPermutation();

  SmallVector<int32_t, 4> permutation;
  permutation.reserve(permAttr.size());
  for (IntegerAttr srcDim : permAttr.getAsRange<IntegerAttr>())
    permutation.push_back(srcDim.getInt());

  Value transposed =
      getTransposeOp(op.getLoc(), adaptor.getInput(), rewriter, permutation);
  rewriter.replaceOp(op, transposed);
  return success();
}

/// Lowers migraphx.reshape to tosa.reshape.
///
/// The target shape is read from the `dims` attribute and materialized as a
/// TOSA const shape; the result type is the converted output type of `op`.
LogicalResult
ReshapeConverter::matchAndRewrite(migraphx::ReshapeOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();
  Value input = adaptor.getInput();
  ArrayAttr dims = adaptor.getDims();
  Type outputTy = getTypeConverter()->convertType(op.getOutput().getType());

  SmallVector<int64_t, 5> newShape;
  newShape.reserve(dims.size());
  // Every element is required to be an IntegerAttr, hence cast<> (which
  // asserts on a mismatch) instead of an unchecked dyn_cast<>.
  for (Attribute dim : dims)
    newShape.push_back(cast<IntegerAttr>(dim).getInt());

  auto newShapeValue = tosa::getTosaConstShape(rewriter, loc, newShape);
  auto rop =
      tosa::ReshapeOp::create(rewriter, loc, outputTy, input, newShapeValue);

  rewriter.replaceOp(op, rop);
  return success();
}

/// Lowers migraphx.slice to tosa.slice.
///
/// Starts default to 0 and sizes to the full input extent for every
/// dimension; each (axis, start, end) triple from the op's attributes then
/// overrides the start and sets the size to `end - start` for that axis.
///
/// Emits an error when an axis lies outside `[0, rank)` of the input.
LogicalResult
SliceConverter::matchAndRewrite(migraphx::SliceOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();
  ArrayAttr axes = op.getAxes();
  ArrayAttr axesStarts = op.getStarts();
  ArrayAttr axesEnds = op.getEnds();

  Value input = adaptor.getInput();
  auto newInType = cast<RankedTensorType>(input.getType());
  ArrayRef<int64_t> inShape = newInType.getShape();
  const int64_t rank = static_cast<int64_t>(inShape.size());

  // Begin with a slice that covers the whole input.
  SmallVector<int64_t, 5> start(inShape.size(), 0);
  SmallVector<int64_t, 5> size(inShape.begin(), inShape.end());

  for (auto [axis, axisS, axisE] : llvm::zip(axes, axesStarts, axesEnds)) {
    int64_t axisInt = cast<IntegerAttr>(axis).getInt();
    int64_t axisSInt = cast<IntegerAttr>(axisS).getInt();
    int64_t axisEInt = cast<IntegerAttr>(axisE).getInt();
    // The axis indexes the vectors below; reject it rather than write out of
    // bounds.
    if (axisInt < 0 || axisInt >= rank)
      return op.emitError("slice axis is out of range for the input rank.");
    start[axisInt] = axisSInt;
    size[axisInt] = axisEInt - axisSInt;
  }

  auto startValue = tosa::getTosaConstShape(rewriter, loc, start);
  auto sizeValue = tosa::getTosaConstShape(rewriter, loc, size);
  auto sliceOp = createOpAndInfer<tosa::SliceOp>(
      rewriter, loc, newInType.getElementType(), input, startValue, sizeValue);
  rewriter.replaceOp(op, sliceOp);
  return success();
}

//===----------------------------------------------------------------------===//
// Reductions
//===----------------------------------------------------------------------===//
namespace {
/// Conversion pattern for migraphx.reduce_mean.
struct ReduceMeanConverter final
    : public OpConversionPattern<migraphx::ReduceMeanOp> {
  using OpConversionPattern<migraphx::ReduceMeanOp>::OpConversionPattern;

  /// Builds a one-element tosa.const holding the extent of `inputTensor`
  /// along the axis given by `axisAttr`, expressed in the tensor's element
  /// type.
  tosa::ConstOp createNumElementsTosaConst(
      Location loc, TypedValue<RankedTensorType> inputTensor,
      IntegerAttr axisAttr, ConversionPatternRewriter &rewriter) const;

  LogicalResult
  matchAndRewrite(migraphx::ReduceMeanOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};

/// Generic conversion pattern from a MIGraphX reduction op to `TosaOp`.
// The two patterns share one anonymous namespace; nesting a second anonymous
// namespace inside it adds nothing.
template <typename MIGraphXOp, typename TosaOp>
struct ReduceConverter final : public OpConversionPattern<MIGraphXOp> {
  using OpConversionPattern<MIGraphXOp>::OpConversionPattern;
  using OpAdaptor = typename OpConversionPattern<MIGraphXOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(MIGraphXOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

/// Creates a tosa::ConstOp of type tensor<1 x elementType> whose single value
/// is the size of `inputTensor` along the axis stored in `axisAttr`.
///
/// The value is an integer attribute when the element type is an integer or
/// index type, and a float attribute (converted via double) otherwise.
tosa::ConstOp ReduceMeanConverter::createNumElementsTosaConst(
    Location loc, TypedValue<RankedTensorType> inputTensor,
    IntegerAttr axisAttr, ConversionPatternRewriter &rewriter) const {
  RankedTensorType inputType = inputTensor.getType();
  Type elementType = inputType.getElementType();
  int64_t reducedAxis = axisAttr.getValue().getSExtValue();
  int64_t reducedExtent = inputType.getShape()[reducedAxis];

  Attribute extentAttr =
      elementType.isIntOrIndex()
          ? Attribute(rewriter.getIntegerAttr(elementType, reducedExtent))
          : Attribute(rewriter.getFloatAttr(
                elementType, static_cast<double>(reducedExtent)));

  RankedTensorType constType = RankedTensorType::get({1}, elementType);
  auto constValue = DenseElementsAttr::get(constType, {extentAttr});
  return tosa::ConstOp::create(rewriter, loc, constType, constValue);
}

/// Rewrites a migraphx::ReduceMeanOp over exactly one axis as
/// reduce_sum(input * reciprocal(N)), where N is the input extent along that
/// axis. Emits an error on the op when the axis list does not have exactly one
/// entry.
LogicalResult ReduceMeanConverter::matchAndRewrite(
    migraphx::ReduceMeanOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  ArrayRef<Attribute> reductionAxes = op.getAxes().getValue();
  if (reductionAxes.size() != 1)
    return op.emitError("We only support single axes reductions!");

  Location loc = op.getLoc();
  IntegerAttr axisAttr = rewriter.getI32IntegerAttr(
      cast<IntegerAttr>(reductionAxes[0]).getInt());
  auto source = cast<TypedValue<RankedTensorType>>(adaptor.getInput());
  RankedTensorType sourceType = source.getType();
  Type elementType = sourceType.getElementType();

  // 1 / N as a one-element tensor.
  tosa::ConstOp extentConst =
      createNumElementsTosaConst(loc, source, axisAttr, rewriter);
  auto inverseExtent = createOpAndInfer<tosa::ReciprocalOp>(
      rewriter, loc, elementType, extentConst);

  // Give the reciprocal the same rank as the input (all dimensions 1).
  llvm::SmallVector<int64_t> allOnesShape(sourceType.getRank(), 1);
  auto allOnesShapeValue =
      tosa::getTosaConstShape(rewriter, loc, allOnesShape);
  Value inverseExtentReshaped = createOpAndInfer<tosa::ReshapeOp>(
      rewriter, loc, elementType, inverseExtent, allOnesShapeValue);

  // tosa.mul takes a shift operand; build an i8 zero constant for it.
  Type i8Type = rewriter.getIntegerType(8);
  auto shiftTensorType = RankedTensorType::get({1}, i8Type);
  auto zeroShiftAttr =
      DenseElementsAttr::get(shiftTensorType, rewriter.getZeroAttr(i8Type));
  Value zeroShift =
      tosa::ConstOp::create(rewriter, loc, shiftTensorType, zeroShiftAttr);

  // Scale first, then sum along the reduced axis.
  auto scaledInput = createOpAndInfer<tosa::MulOp>(
      rewriter, loc, elementType, adaptor.getInput(), inverseExtentReshaped,
      /*shift=*/zeroShift);
  auto summed = createOpAndInfer<tosa::ReduceSumOp>(rewriter, loc, elementType,
                                                    scaledInput, axisAttr);
  rewriter.replaceOp(op, summed);
  return success();
}

/// Rewrites a MIGraphX reduction over exactly one axis as the TOSA reduction
/// `TosaOp` on the converted input, with the axis passed as an i32 attribute.
/// Emits an error on the op when the axis list does not have exactly one
/// entry.
template <typename MIGraphXOp, typename TosaOp>
LogicalResult ReduceConverter<MIGraphXOp, TosaOp>::matchAndRewrite(
    MIGraphXOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  ArrayRef<Attribute> reductionAxes = op.getAxes().getValue();
  if (reductionAxes.size() != 1)
    return op.emitError("We only support single axes reductions!");

  Location loc = op.getLoc();
  IntegerAttr axisAttr = rewriter.getI32IntegerAttr(
      cast<IntegerAttr>(reductionAxes[0]).getInt());
  auto source = cast<TypedValue<RankedTensorType>>(adaptor.getInput());
  Type elementType = source.getType().getElementType();

  auto reduced =
      createOpAndInfer<TosaOp>(rewriter, loc, elementType, source, axisAttr);
  rewriter.replaceOp(op, reduced);
  return success();
}

//===----------------------------------------------------------------------===//
// Binary operations
//===----------------------------------------------------------------------===//
namespace {
/// Conversion pattern for migraphx::DivOp; matchAndRewrite is defined out of
/// line below.
struct DivConverter final : OpConversionPattern<migraphx::DivOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(migraphx::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};

/// Conversion pattern for migraphx::MulOp; matchAndRewrite is defined out of
/// line below.
struct MulConverter final : OpConversionPattern<migraphx::MulOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(migraphx::MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

/// Rewrites a migraphx::DivOp.
///
/// - Non-integer element types become `inA * reciprocal(inB)`.
/// - Integer element types become tosa::IntDivOp, unless either original
///   operand element type is an unsigned integer; in that case a
///   tosa::CustomOp named "unsigned_div" is emitted, and an error is reported
///   if the two original element types differ.
LogicalResult
DivConverter::matchAndRewrite(migraphx::DivOp op, OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  auto lhs = cast<TypedValue<RankedTensorType>>(adaptor.getInA());
  auto rhs = cast<TypedValue<RankedTensorType>>(adaptor.getInB());
  Type elementType = lhs.getType().getElementType();

  if (!isa<IntegerType>(elementType)) {
    // Division as multiplication by the reciprocal of the divisor.
    Value rhsInverse = createOpAndInfer<tosa::ReciprocalOp>(rewriter, loc,
                                                            elementType, rhs);

    // tosa.mul takes a shift operand; build an i8 zero constant for it.
    Type i8Type = rewriter.getIntegerType(8);
    auto shiftTensorType = RankedTensorType::get({1}, i8Type);
    auto zeroShiftAttr =
        DenseElementsAttr::get(shiftTensorType, rewriter.getZeroAttr(i8Type));
    Value zeroShift =
        tosa::ConstOp::create(rewriter, loc, shiftTensorType, zeroShiftAttr);
    Value product = createOpAndInfer<tosa::MulOp>(
        rewriter, loc, elementType, lhs, rhsInverse, /*shift=*/zeroShift);
    rewriter.replaceOp(op, product);
    return success();
  }

  // Signedness only exists on the original (pre-conversion) element types.
  auto lhsOrigElemType = op.getInA().getType().getElementType();
  auto rhsOrigElemType = op.getInB().getType().getElementType();
  const bool anyUnsigned = lhsOrigElemType.isUnsignedInteger() ||
                           rhsOrigElemType.isUnsignedInteger();

  Value quotient;
  if (anyUnsigned) {
    if (lhsOrigElemType != rhsOrigElemType)
      return op->emitError("Types of A and B must be the same");
    SmallVector<Value, 2> operands = {lhs, rhs};
    auto unsignedDiv = tosa::CustomOp::create(
        rewriter, loc, lhs.getType(), "unsigned_div", "rocmlir", "", operands);
    quotient = unsignedDiv->getResult(0);
  } else {
    quotient = createOpAndInfer<tosa::IntDivOp>(rewriter, loc, elementType,
                                                lhs, rhs);
  }
  rewriter.replaceOp(op, quotient);
  return success();
}

/// Rewrites a migraphx::MulOp as a tosa::MulOp on the converted operands,
/// supplying a one-element i8 zero constant as the shift operand.
LogicalResult
MulConverter::matchAndRewrite(migraphx::MulOp op, OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
  Type i8Type = rewriter.getIntegerType(8);
  auto shiftTensorType = RankedTensorType::get({1}, i8Type);
  auto zeroShiftAttr =
      DenseElementsAttr::get(shiftTensorType, rewriter.getZeroAttr(i8Type));
  Value zeroShift = tosa::ConstOp::create(rewriter, op->getLoc(),
                                          shiftTensorType, zeroShiftAttr);

  Type convertedResultType =
      getTypeConverter()->convertType(op.getResult().getType());
  rewriter.replaceOpWithNewOp<tosa::MulOp>(op, convertedResultType,
                                           adaptor.getInA(), adaptor.getInB(),
                                           /*shift=*/zeroShift);
  return success();
}

//===----------------------------------------------------------------------===//
// Unary operations
//===----------------------------------------------------------------------===//
namespace {
// Each pattern below only declares matchAndRewrite; the definitions follow
// this namespace.

/// Conversion pattern for migraphx::DeQuantizeLinearOp.
struct DeQuantizeLinearConverter final
    : OpConversionPattern<migraphx::DeQuantizeLinearOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(migraphx::DeQuantizeLinearOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};

/// Conversion pattern for migraphx::QuantizeLinearOp.
struct QuantizeLinearConverter final
    : OpConversionPattern<migraphx::QuantizeLinearOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(migraphx::QuantizeLinearOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};

/// Conversion pattern for migraphx::ConvertOp.
struct ConvertConverter final : OpConversionPattern<migraphx::ConvertOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(migraphx::ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};

/// Conversion pattern for migraphx::NegOp.
struct NegConverter final : OpConversionPattern<migraphx::NegOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(migraphx::NegOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};

/// Conversion pattern for migraphx::ReluOp.
struct ReluConverter final : OpConversionPattern<migraphx::ReluOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(migraphx::ReluOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};

/// Conversion pattern for migraphx::SoftmaxOp.
struct SoftmaxConverter final : OpConversionPattern<migraphx::SoftmaxOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(migraphx::SoftmaxOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

// MIGraphX implements, with T being the scale element type:
//   output[i] = (convert<T>(input[i]) - convert<T>(zero_pts[i])) * scales[i];
// For f32 this matches the ONNX reference. Dequantizing to f16, if it is ever
// done, will be less precise than the reference, but that's probably fine.
//
// The lowering casts the input (and the bias, when present) to the output
// element type, subtracts the bias, and multiplies by the scale.
LogicalResult DeQuantizeLinearConverter::matchAndRewrite(
    migraphx::DeQuantizeLinearOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();
  Value quantized = adaptor.getInput();
  Value scale = adaptor.getScale();
  Value output = op.getOutput();

  Type origResultElemType = getShapedElementTy(output);
  Type resultElemType = getTypeConverter()->convertType(origResultElemType);

  Value widenedInput = createCastOp(
      rewriter, loc, resultElemType, quantized,
      op.getInput().getType().getElementType(), origResultElemType);

  // The bias is optional; without it the widened input is scaled directly.
  Value centered = widenedInput;
  if (auto bias = adaptor.getBias()) {
    Value widenedBias = createCastOp(
        rewriter, loc, resultElemType, bias,
        op.getBias().getType().getElementType(), origResultElemType);
    centered = createOpAndInfer<tosa::SubOp>(rewriter, loc, resultElemType,
                                             widenedInput, widenedBias);
  }

  // tosa.mul takes a shift operand; build an i8 zero constant for it.
  Type i8Type = rewriter.getIntegerType(8);
  auto shiftTensorType = RankedTensorType::get({1}, i8Type);
  auto zeroShiftAttr =
      DenseElementsAttr::get(shiftTensorType, rewriter.getZeroAttr(i8Type));
  Value zeroShift =
      tosa::ConstOp::create(rewriter, loc, shiftTensorType, zeroShiftAttr);
  Value dequantized = createOpAndInfer<tosa::MulOp>(
      rewriter, loc, resultElemType, centered, scale, /*shift=*/zeroShift);

  rewriter.replaceOp(op, dequantized);
  return success();
}

// MIGraphX pseudo code:
// int32_t quantized = static_cast<int32>(
//      std::round(input[i] / scales[i])) + zero_pts[i];
// output[i] = std::max(-128, std::min(127, quantized));
//
// Lowering steps:
//   1. scaled = input * reciprocal(scale)
//   2. cast `scaled` to an intermediate "bias type" and add the bias, if any
//   3. when the intermediate type differs from the output type, clamp to the
//      output type's representable range and cast down to it.
LogicalResult QuantizeLinearConverter::matchAndRewrite(
    migraphx::QuantizeLinearOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value input = adaptor.getInput();
  Value scale = adaptor.getScale();
  Value bias = adaptor.getBias();
  Value output = op.getOutput();

  Type elementType = getShapedElementTy(input);
  Value inverseScale =
      createOpAndInfer<tosa::ReciprocalOp>(rewriter, loc, elementType, scale);

  // tosa.mul takes a shift operand; build an i8 zero constant for it.
  auto shiftType = RankedTensorType::get({1}, rewriter.getIntegerType(8));
  auto shiftZeroAttr = DenseElementsAttr::get(
      shiftType, rewriter.getZeroAttr(rewriter.getIntegerType(8)));
  Value constZero =
      tosa::ConstOp::create(rewriter, loc, shiftType, shiftZeroAttr);
  Value scaled = createOpAndInfer<tosa::MulOp>(
      rewriter, loc, elementType, input, inverseScale, /*shift=*/constZero);

  Type origOutputType = getShapedElementTy(output);
  Type outputType = getTypeConverter()->convertType(origOutputType);
  // Signedness lives only on the original element type. Everything that is
  // not explicitly unsigned (including signless integers) is treated as
  // signed when computing clamp bounds below.
  const bool outputIsUnsigned = origOutputType.isUnsignedInteger();

  // If there is a bias, we upcast to the larger of the bias type and int32_t
  // or float (which is what the bias type is in dequantize, the MLIR
  // quantization implementation, and other ML frameworks) and then do a
  // clamping truncation to the output type so that adding a bias saturates
  // instead of overflowing.
  Type biasType = outputType;
  if (bias) {
    biasType = getShapedElementTy(bias);
  }
  if ((bias || origOutputType != outputType) &&
      biasType.getIntOrFloatBitWidth() < 32) {
    biasType = isa<IntegerType>(biasType) ? cast<Type>(rewriter.getI32Type())
                                          : cast<Type>(rewriter.getF32Type());
  }
  Value asShort = createCastOp(rewriter, loc, biasType, scaled, elementType);
  Value biased = asShort;
  if (bias) {
    bias = createCastOp(rewriter, loc, biasType, bias,
                        op.getBias().getType().getElementType());
    biased =
        createOpAndInfer<tosa::AddOp>(rewriter, loc, biasType, asShort, bias);
  }

  Value result = biased;
  if (biasType != outputType) {
    unsigned width = outputType.getIntOrFloatBitWidth();
    APInt minI(width, 0), maxI(width, 0);
    // Must be floats because tosa.clamp expects a f32 attribute specifically.
    APFloat minF(0.0f), maxF(0.0f);
    if (auto outFloatType = dyn_cast<FloatType>(outputType)) {
      // Float output: the bounds are the largest finite values of the output
      // format, re-expressed in the (wider) intermediate float format.
      const llvm::fltSemantics &outSem = outFloatType.getFloatSemantics();
      const llvm::fltSemantics &biasSem =
          cast<FloatType>(biasType).getFloatSemantics();
      minF = APFloat::getLargest(outSem, /*Negative=*/true);
      maxF = APFloat::getLargest(outSem, /*Negative=*/false);
      bool itsExtendNoWayWeCanLoseInfo = false;
      std::ignore = minF.convert(biasSem, APFloat::rmNearestTiesToEven,
                                 &itsExtendNoWayWeCanLoseInfo);
      std::ignore = maxF.convert(biasSem, APFloat::rmNearestTiesToEven,
                                 &itsExtendNoWayWeCanLoseInfo);
      minI = APInt(64, static_cast<int64_t>(minF.convertToFloat()));
      // The upper integer bound must come from maxF; it was previously
      // derived from minF, which made both bounds identical.
      maxI = APInt(64, static_cast<int64_t>(maxF.convertToFloat()));
    } else {
      // Integer output: the bounds are the min/max of the output width.
      minI = outputIsUnsigned ? APInt::getMinValue(width)
                              : APInt::getSignedMinValue(width);
      maxI = outputIsUnsigned ? APInt::getMaxValue(width)
                              : APInt::getSignedMaxValue(width);
      // Interpret the bit patterns with the same signedness that was used to
      // pick them; otherwise a signless output would turn the signed minimum
      // into a positive float.
      minF.convertFromAPInt(minI, /*IsSigned=*/!outputIsUnsigned,
                            APFloat::rmNearestTiesToEven);
      maxF.convertFromAPInt(maxI, /*IsSigned=*/!outputIsUnsigned,
                            APFloat::rmNearestTiesToEven);
    }

    Attribute minVal, maxVal;
    if (isa<IntegerType>(biasType)) {
      auto minValUI64 =
          outputIsUnsigned ? minI.getZExtValue() : minI.getSExtValue();
      auto maxValUI64 =
          outputIsUnsigned ? maxI.getZExtValue() : maxI.getSExtValue();
      minVal = rewriter.getIntegerAttr(biasType, minValUI64);
      maxVal = rewriter.getIntegerAttr(biasType, maxValUI64);
    } else if (isa<FloatType>(biasType)) {
      minVal = rewriter.getFloatAttr(biasType, minF);
      maxVal = rewriter.getFloatAttr(biasType, maxF);
    } else {
      llvm_unreachable("unknown type for QuantizeLinearConverter");
    }
    // Saturate in the wide type, then narrow to the output type.
    result = createOpAndInfer<tosa::ClampOp>(rewriter, loc, biasType, result,
                                             minVal, maxVal);
    result = createCastOp(rewriter, loc, outputType, result, biasType,
                          origOutputType);
  }
  rewriter.replaceOp(op, result);

  return success();
}

/// Rewrites a migraphx::ConvertOp.
///
/// When neither original element type is an unsigned integer, a tosa::CastOp
/// is emitted. Otherwise a tosa::CustomOp named "unsigned_cast" is emitted;
/// that path asserts that neither side is an explicitly signed integer.
LogicalResult
ConvertConverter::matchAndRewrite(migraphx::ConvertOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto srcElemType = op.getInA().getType().getElementType();
  auto dstElemType = op.getResult().getType().getElementType();
  Type convertedResultType =
      getTypeConverter()->convertType(op.getResult().getType());

  const bool involvesUnsigned =
      srcElemType.isUnsignedInteger() || dstElemType.isUnsignedInteger();
  if (!involvesUnsigned) {
    rewriter.replaceOpWithNewOp<tosa::CastOp>(op, convertedResultType,
                                              adaptor.getInA());
    return success();
  }

  assert(!srcElemType.isSignedInteger() && !dstElemType.isSignedInteger());
  rewriter.replaceOpWithNewOp<tosa::CustomOp>(op, convertedResultType,
                                              "unsigned_cast", "rocmlir", "",
                                              adaptor.getInA());
  return success();
}

/// Rewrites a migraphx::NegOp as a tosa::NegateOp. Reports an op error when
/// the original result element type is an unsigned integer.
LogicalResult
NegConverter::matchAndRewrite(migraphx::NegOp op, OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
  auto origResultType = op.getResult().getType();
  if (origResultType.getElementType().isUnsignedInteger())
    return op.emitOpError("can't negate an unsigned int type");

  Type convertedResultType = getTypeConverter()->convertType(origResultType);
  rewriter.replaceOpWithNewOp<tosa::NegateOp>(op, convertedResultType,
                                              adaptor.getInA());
  return success();
}

/// Rewrites a migraphx::ReluOp as tosa::MaximumOp(input, zero), where the
/// zero operand comes from getZeroTensor for the converted result type.
LogicalResult
ReluConverter::matchAndRewrite(migraphx::ReluOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  auto resultType = cast<RankedTensorType>(
      getTypeConverter()->convertType(op.getResult().getType()));
  Value operand = adaptor.getInA();
  auto zeros = getZeroTensor(op.getLoc(), resultType, rewriter);
  // Since the zero is second, this handles any implicit broadcast.
  rewriter.replaceOpWithNewOp<tosa::MaximumOp>(op, resultType, operand, zeros);
  return success();
}

/// Rewrites a migraphx::SoftmaxOp along its axis as
///   e = exp(x - reduce_max(x)); result = e * reciprocal(reduce_sum(e)).
LogicalResult
SoftmaxConverter::matchAndRewrite(migraphx::SoftmaxOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value logits = adaptor.getInput();
  IntegerAttr axis = rewriter.getI32IntegerAttr(op.getAxisAttr().getInt());
  Type elementType = cast<ShapedType>(logits.getType()).getElementType();

  // Subtract the per-axis maximum before exponentiating.
  auto axisMax = createOpAndInfer<tosa::ReduceMaxOp>(rewriter, loc,
                                                     elementType, logits, axis);
  auto shifted = createOpAndInfer<tosa::SubOp>(rewriter, loc, elementType,
                                               logits, axisMax);
  auto exponentials =
      createOpAndInfer<tosa::ExpOp>(rewriter, loc, elementType, shifted);

  // Normalize by the sum of the exponentials along the same axis.
  auto expSum = createOpAndInfer<tosa::ReduceSumOp>(rewriter, loc, elementType,
                                                    exponentials, axis);
  auto inverseSum = createOpAndInfer<tosa::ReciprocalOp>(rewriter, loc,
                                                         elementType, expSum);

  // tosa.mul takes a shift operand; build an i8 zero constant for it.
  Type i8Type = rewriter.getIntegerType(8);
  auto shiftTensorType = RankedTensorType::get({1}, i8Type);
  auto zeroShiftAttr =
      DenseElementsAttr::get(shiftTensorType, rewriter.getZeroAttr(i8Type));
  Value zeroShift =
      tosa::ConstOp::create(rewriter, loc, shiftTensorType, zeroShiftAttr);
  auto normalized = createOpAndInfer<tosa::MulOp>(
      rewriter, loc, elementType, exponentials, inverseSum,
      /*shift=*/zeroShift);

  rewriter.replaceOp(op, normalized);
  return success();
}

//===----------------------------------------------------------------------===//
// Misc. ops
//===----------------------------------------------------------------------===//
namespace {
// Each pattern below only declares matchAndRewrite; the definitions follow
// this namespace.

/// Conversion pattern for migraphx::LiteralOp.
struct LiteralConverter final : OpConversionPattern<migraphx::LiteralOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(migraphx::LiteralOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};

/// Conversion pattern for migraphx::ClipOp.
struct ClipConverter final : OpConversionPattern<migraphx::ClipOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(migraphx::ClipOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};

/// Conversion pattern for migraphx::WhereOp.
struct WhereConverter final : OpConversionPattern<migraphx::WhereOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(migraphx::WhereOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};

/// Conversion pattern for migraphx::Greater.
struct GreaterConverter final : OpConversionPattern<migraphx::Greater> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(migraphx::Greater op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

/// Rewrites a migraphx::LiteralOp as a tosa::ConstOp of the converted tensor
/// type.
///
/// When the literal's attribute type differs from the converted type (for
/// example SI8 vs. I8), every element is rebuilt as an IntegerAttr/FloatAttr
/// of the new element type carrying the same APInt/APFloat value. Fails if an
/// element is neither an integer nor a float attribute.
LogicalResult
LiteralConverter::matchAndRewrite(migraphx::LiteralOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  MIXRShapedType origType = op.getResult().getType();
  RankedTensorType newType =
      cast<RankedTensorType>(getTypeConverter()->convertType(origType));
  ElementsAttr value = op.getValue();

  // Nothing to retype: reuse the attribute as-is.
  if (value.getType() == newType) {
    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, newType, value);
    return success();
  }

  // Re-tags one element with the new element type while keeping its value.
  // Returns a null attribute for unsupported element kinds.
  Type newElementType = newType.getElementType();
  auto retypeElement = [newElementType](Attribute element) -> Attribute {
    if (auto intAttr = dyn_cast<IntegerAttr>(element))
      return IntegerAttr::get(newElementType, intAttr.getValue());
    if (auto floatAttr = dyn_cast<FloatAttr>(element))
      return FloatAttr::get(newElementType, floatAttr.getValue());
    return Attribute();
  };

  if (value.isSplat()) {
    // A splat only needs its single value converted.
    Attribute newSplatValue = retypeElement(value.getSplatValue<Attribute>());
    if (!newSplatValue)
      return failure();
    value = SplatElementsAttr::get(newType, newSplatValue);
  } else {
    // Otherwise convert element by element.
    SmallVector<Attribute> retyped;
    retyped.reserve(value.getNumElements());
    for (auto element : value.getValues<Attribute>()) {
      Attribute converted = retypeElement(element);
      if (!converted)
        return failure();
      retyped.push_back(converted);
    }
    value = DenseElementsAttr::get(newType, retyped);
  }

  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, newType, value);
  return success();
}

/// Rewrites a migraphx::ClipOp as minimum(maximum(x, minVals), maxVals), with
/// both TOSA ops typed as the converted result type.
LogicalResult
ClipConverter::matchAndRewrite(migraphx::ClipOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  auto resultType = cast<RankedTensorType>(
      getTypeConverter()->convertType(op.getResult().getType()));
  Value lowerBounds = adaptor.getMinVals();
  Value upperBounds = adaptor.getMaxVals();

  // Apply the lower bound first, then cap with the upper bound.
  Value boundedBelow = tosa::MaximumOp::create(
      rewriter, op.getLoc(), resultType, adaptor.getX(), lowerBounds);
  rewriter.replaceOpWithNewOp<tosa::MinimumOp>(op, resultType, boundedBelow,
                                               upperBounds);
  return success();
}

/// Rewrites a migraphx::WhereOp as a tosa::SelectOp. The condition operand is
/// first passed through createCastOp with an i1 target element type.
LogicalResult
WhereConverter::matchAndRewrite(migraphx::WhereOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value predicate =
      createCastOp(rewriter, loc, rewriter.getI1Type(), adaptor.getCond(),
                   op.getCond().getType().getElementType());

  Type convertedResultType =
      getTypeConverter()->convertType(op.getResult().getType());
  rewriter.replaceOpWithNewOp<tosa::SelectOp>(
      op, convertedResultType, predicate, adaptor.getInA(), adaptor.getInB());
  return success();
}

/// Rewrites a migraphx::Greater as a tosa::GreaterOp producing an i1 tensor
/// with the op's shape, followed by a tosa::CastOp to the type of the
/// converted first operand.
LogicalResult
GreaterConverter::matchAndRewrite(migraphx::Greater op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  Value lhs = adaptor.getInA();
  Value rhs = adaptor.getInB();

  // The comparison itself yields booleans.
  auto boolTensorType =
      RankedTensorType::get(op.getType().getShape(), rewriter.getI1Type());
  auto isGreater = rewriter.createOrFold<tosa::GreaterOp>(
      op->getLoc(), boolTensorType, lhs, rhs);

  // Cast the boolean result to the first operand's converted type.
  rewriter.replaceOpWithNewOp<tosa::CastOp>(op, adaptor.getInA().getType(),
                                            isGreater);
  return success();
}

//===----------------------------------------------------------------------===//
// Function boundaries
//===----------------------------------------------------------------------===//
namespace {
/// Conversion pattern for migraphx::AsLogicalShapeOp; matchAndRewrite is
/// defined out of line below.
struct AsLogicalShapeConverter final
    : OpConversionPattern<migraphx::AsLogicalShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(migraphx::AsLogicalShapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};

/// Conversion pattern for migraphx::AsUnderlyingShapeOp; matchAndRewrite is
/// defined out of line below.
struct AsUnderlyingShapeConverter final
    : OpConversionPattern<migraphx::AsUnderlyingShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(migraphx::AsUnderlyingShapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};

/// This mirrors the call op conversion pattern but works for mhal.launch.
struct MHALLaunchConverter final : OpConversionPattern<mhal::LaunchOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mhal::LaunchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

/// Rewrites a migraphx::AsLogicalShapeOp as a chain of up to four optional
/// steps on the converted input, each emitted only when the types require it:
///   1. reshape to the in-memory layout tensor type,
///   2. transpose by the inverse of the stride permutation,
///   3. slice down to the result shape (stride-0 dimensions sliced to size 1),
///   4. add to a zero tensor of the result type to broadcast.
LogicalResult AsLogicalShapeConverter::matchAndRewrite(
    migraphx::AsLogicalShapeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  MIXRShapedType shapedInType = op.getIn().getType();
  RankedTensorType logicalType = op.getOut().getType();
  Value current = adaptor.getIn();

  // First, expand ourselves back out to the N-D type that we're logically
  // working with in memory.
  RankedTensorType memoryLayoutType = shapedInType.asMemoryLayoutTensor();
  if (current.getType() != memoryLayoutType) {
    auto memoryShapeValue =
        tosa::getTosaConstShape(rewriter, loc, memoryLayoutType.getShape());
    current = tosa::ReshapeOp::create(rewriter, loc, current, memoryShapeValue);
  }

  // The stride permutation reorders the strides into standard shape, i.e. it
  // maps a standard shape onto its in-memory layout. Getting back to the
  // standard/logical shape therefore needs its inverse.
  SmallVector<int64_t, 4> layoutPermutation;
  shapedInType.getStridePermutation(layoutPermutation);
  SmallVector<int32_t> toLogicalPermutation;
  toLogicalPermutation.resize_for_overwrite(layoutPermutation.size());
  bool needsTranspose = false;
  for (auto [position, source] : llvm::enumerate(layoutPermutation)) {
    toLogicalPermutation[source] = position;
    needsTranspose |= (source != static_cast<int32_t>(position));
  }
  if (needsTranspose)
    current = getTransposeOp(loc, current, rewriter, toLogicalPermutation);

  auto transposedType = cast<RankedTensorType>(current.getType());
  if (transposedType == logicalType) {
    rewriter.replaceOp(op, current);
    return success();
  }

  // Dimensions with stride 0 are kept at size 1 in the slice.
  SmallVector<int64_t, 4> sliceShape(logicalType.getShape());
  for (auto [extent, stride] :
       llvm::zip_equal(sliceShape, shapedInType.getStrides())) {
    if (stride == 0)
      extent = 1;
  }

  if (transposedType.getShape() != ArrayRef(sliceShape)) {
    SmallVector<int64_t, 4> zeroOffsets(toLogicalPermutation.size(), 0);
    RankedTensorType slicedType = logicalType.clone(sliceShape);
    auto offsetsValue = tosa::getTosaConstShape(rewriter, loc, zeroOffsets);
    auto sliceShapeValue = tosa::getTosaConstShape(rewriter, loc, sliceShape);
    current = tosa::SliceOp::create(rewriter, loc, slicedType, current,
                                    offsetsValue, sliceShapeValue);
  }

  if (current.getType() != logicalType) {
    // We need a broadcast
    Value zeros = getZeroTensor(loc, logicalType, rewriter);
    current = tosa::AddOp::create(rewriter, loc, logicalType, zeros, current);
  }
  rewriter.replaceOp(op, current);
  return success();
}

/// Rewrites a migraphx::AsUnderlyingShapeOp as an optional transpose by the
/// result type's stride permutation, followed by an optional reshape to the
/// converted result tensor type.
///
/// Reports an op error when the result type cannot be converted to a ranked
/// tensor, or when the (possibly transposed) value does not match the
/// in-memory layout tensor type.
LogicalResult AsUnderlyingShapeConverter::matchAndRewrite(
    migraphx::AsUnderlyingShapeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  MIXRShapedType resultType = op.getOut().getType();
  RankedTensorType memoryLayoutType = resultType.asMemoryLayoutTensor();
  // Use a checked, null-tolerant cast: a plain cast<> would assert on a failed
  // conversion and make the diagnostic below unreachable.
  auto resultTensorType = dyn_cast_if_present<RankedTensorType>(
      getTypeConverter()->convertType(resultType));
  if (!resultTensorType)
    return op.emitOpError("unsupported conversion to underlying shape");
  Value in = adaptor.getIn();
  SmallVector<int64_t, 4> permutation;
  // This is the permutation that reorders strides into the order they'd be in
  // in a standard shape. So, applying it to a logically-shaped tensor gets
  // you the tensor in in-memory layout.
  resultType.getStridePermutation(permutation);
  // TOSA transpose takes i32
  SmallVector<int32_t, 4> permutationI32 =
      llvm::map_to_vector(permutation, [](int64_t val) -> int32_t {
        return static_cast<int32_t>(val);
      });

  // A sorted permutation is the identity, so no transpose is needed.
  Value transposed = in;
  if (!llvm::is_sorted(permutation))
    transposed = getTransposeOp(loc, in, rewriter, permutationI32);
  if (transposed.getType() != memoryLayoutType) {
    // Only clean up a transpose that this pattern created. When no transpose
    // was emitted, `transposed` is the input itself, whose producer must not
    // be erased (and which may be a block argument with no defining op).
    if (Operation *created = transposed.getDefiningOp();
        created && transposed != in)
      rewriter.eraseOp(created);
    return op.emitOpError(
        "writing to tensors with long strides or broadcasts is unsupported");
  }

  Value collapsed = transposed;
  if (transposed.getType() != resultTensorType) {
    auto shapeValue =
        tosa::getTosaConstShape(rewriter, loc, resultTensorType.getShape());
    collapsed = tosa::ReshapeOp::create(rewriter, loc, transposed, shapeValue);
  }
  rewriter.replaceOp(op, collapsed);
  return success();
}

/// Rebuilds an mhal::LaunchOp with converted result types and the adapted
/// operands, keeping the same callee. Fails when the result types cannot be
/// converted or do not convert one-to-one.
LogicalResult MHALLaunchConverter::matchAndRewrite(
    mhal::LaunchOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Convert the original function results.
  SmallVector<Type, 2> convertedResultTypes;
  const bool conversionFailed = failed(
      typeConverter->convertTypes(op.getResultTypes(), convertedResultTypes));

  // If this isn't a one-to-one type mapping, we don't know how to aggregate
  // the results.
  if (conversionFailed || op->getNumResults() != convertedResultTypes.size())
    return failure();

  // Substitute with the new result types from the corresponding FuncType
  // conversion.
  rewriter.replaceOpWithNewOp<mhal::LaunchOp>(
      op, op.getCalleeAttr(), convertedResultTypes, adaptor.getOperands());
  return success();
}

//===----------------------------------------------------------------------===//
// External interface
//===----------------------------------------------------------------------===//

/// Adds the MIGraphX-to-TOSA op conversion patterns to `patterns`, all
/// constructed with `typeConverter` and the pattern set's context. The
/// patterns are added in the same order as a single combined `add` call would
/// add them.
void migraphx::populateMIGraphXToTosaConversionPatterns(
    RewritePatternSet &patterns, TypeConverter &typeConverter) {
  MLIRContext *context = patterns.getContext();

  // Convolutions and dot products.
  patterns.add<ConvConverter<ConvolutionBwdDataOp>,
               ConvConverter<ConvolutionOp>, ConvConverter<QuantConvolutionOp>,
               DotConverter<DotOp>, DotConverter<QuantDotOp>>(typeConverter,
                                                              context);

  // Shape manipulation.
  patterns.add<BroadcastConverter, MultiBroadcastConverter, TransposeConverter,
               ReshapeConverter, SliceConverter>(typeConverter, context);

  // Reductions.
  patterns.add<ReduceMeanConverter,
               ReduceConverter<ReduceSumOp, tosa::ReduceSumOp>,
               ReduceConverter<ReduceMaxOp, tosa::ReduceMaxOp>>(typeConverter,
                                                                context);

  // Binary elementwise operations.
  patterns.add<TrivialConverter<AddOp, tosa::AddOp>,
               TrivialConverter<SubOp, tosa::SubOp>,
               TrivialConverter<PowOp, tosa::PowOp>, DivConverter,
               MulConverter>(typeConverter, context);

  // Unary elementwise operations with a direct TOSA counterpart.
  patterns.add<TrivialConverter<AbsOp, tosa::AbsOp>,
               TrivialConverter<CeilOp, tosa::CeilOp>,
               TrivialConverter<ErfOp, tosa::ErfOp>,
               TrivialConverter<ExpOp, tosa::ExpOp>,
               TrivialConverter<FloorOp, tosa::FloorOp>,
               TrivialConverter<LogOp, tosa::LogOp>,
               TrivialConverter<RecipOp, tosa::ReciprocalOp>,
               TrivialConverter<RsqrtOp, tosa::RsqrtOp>,
               TrivialConverter<SigmoidOp, tosa::SigmoidOp>,
               TrivialConverter<TanhOp, tosa::TanhOp>>(typeConverter, context);

  // Remaining unary and miscellaneous operations.
  patterns.add<QuantizeLinearConverter, DeQuantizeLinearConverter,
               ConvertConverter, NegConverter, ReluConverter, SoftmaxConverter,
               LiteralConverter, ClipConverter, WhereConverter,
               GreaterConverter>(typeConverter, context);
}

/// Adds the patterns that convert function-boundary ops (logical/underlying
/// shape ops, func.return, mhal.launch) to `patterns`, followed by the
/// upstream function-interface and call-op type conversion patterns.
void mlir::migraphx::populateMIGraphXFuncBoundaryToTosaConversionPatterns(
    RewritePatternSet &patterns, TypeConverter &typeConverter) {
  MLIRContext *context = patterns.getContext();
  patterns.add<AsLogicalShapeConverter, AsUnderlyingShapeConverter,
               TrivialConverter<func::ReturnOp, func::ReturnOp>,
               MHALLaunchConverter>(typeConverter, context);

  // Add upstream patterns that take care of func.func and its friends.
  populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter);
  populateCallOpTypeConversionPattern(patterns, typeConverter);
}
