Merge branch 'feature/onnx-to-tosa' into tiagot.fix_conv2d_onnx_to_tosa_no_bias
ttjost authored Feb 24, 2025
2 parents e6146f1 + 177704c commit 7737733
Showing 15 changed files with 241 additions and 45 deletions.
5 changes: 3 additions & 2 deletions src/Conversion/ONNXToTOSA/NN/DequantizeLinear.cpp
@@ -63,9 +63,10 @@ class ONNXDequantizeLinearOpLoweringToTOSA

     int64_t axis = op.getAxis();
     // See https://github.com/onnx/onnx/issues/6067
-    if (axis == 1 && resultType.getRank() == 1)
+    if (axis == 1 && (resultType.getRank() == 1 || resultType.getRank() == 0))
       axis = 0;
-    if (axis < -resultType.getRank() || axis >= resultType.getRank()) {
+    if (resultType.getRank() != 0 &&
+        (axis < -resultType.getRank() || axis >= resultType.getRank())) {
       return rewriter.notifyMatchFailure(loc, "axis is invalid");
     }
     if (axis < 0)
5 changes: 3 additions & 2 deletions src/Conversion/ONNXToTOSA/NN/QuantizeLinear.cpp
@@ -60,9 +60,10 @@ class ONNXQuantizeLinearOpLoweringToTOSA

     int64_t axis = op.getAxis();
     // See https://github.com/onnx/onnx/issues/6067
-    if (axis == 1 && resultType.getRank() == 1)
+    if (axis == 1 && (resultType.getRank() == 1 || resultType.getRank() == 0))
       axis = 0;
-    if (axis < -resultType.getRank() || axis >= resultType.getRank()) {
+    if (resultType.getRank() != 0 &&
+        (axis < -resultType.getRank() || axis >= resultType.getRank())) {
       return rewriter.notifyMatchFailure(loc, "axis is invalid");
     }
     if (axis < 0)
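The guard added above is identical in DequantizeLinear.cpp and QuantizeLinear.cpp: a rank-0 (scalar) result now tolerates ONNX's default axis of 1 and skips the range check entirely. A minimal standalone sketch of that logic follows; the helper name and std::optional error handling are illustrative stand-ins for the rewriter's notifyMatchFailure, and the final negative-axis wrap is an assumption about the code elided below the hunk.

#include <cstdint>
#include <optional>

// Illustrative helper (not part of the patch) mirroring the axis handling in
// both lowerings above.
std::optional<int64_t> normalizeQuantAxis(int64_t axis, int64_t rank) {
  // See https://github.com/onnx/onnx/issues/6067: exporters often emit the
  // default axis = 1 even for 1-d and scalar (rank-0) tensors.
  if (axis == 1 && (rank == 1 || rank == 0))
    axis = 0;
  if (rank != 0 && (axis < -rank || axis >= rank))
    return std::nullopt; // the lowering reports "axis is invalid" here
  if (axis < 0)
    axis += rank; // assumption: the elided code wraps negative axes like this
  return axis;
}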
7 changes: 7 additions & 0 deletions src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp
@@ -135,6 +135,13 @@ mlir::Value buildOnnxToTosaPaddingConstOp(mlir::PatternRewriter &rewriter,
 mlir::Value expandShape(mlir::PatternRewriter &rewriter, mlir::Location loc,
     mlir::Value tensor, size_t axis, size_t rank) {
   auto inTy = cast<ShapedType>(tensor.getType());
+  if (rank == 0) {
+    // target rank is a scalar
+    llvm::SmallVector<int64_t> newShape;
+    return rewriter.createOrFold<mlir::tosa::ReshapeOp>(loc,
+        RankedTensorType::get(newShape, inTy.getElementType()), tensor,
+        newShape);
+  }
   llvm::SmallVector<int64_t> newShape(rank, 1);
   newShape[axis] = inTy.getNumElements();
   auto resultTy = RankedTensorType::get(newShape, inTy.getElementType());
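With the new rank == 0 path, a per-tensor scale or zero-point is folded to a true scalar tensor (e.g. tensor<f32>) rather than a rank-`rank` tensor of ones. A dependency-free sketch of the shape computation; the function name below is illustrative only and not part of the patch.

#include <cstddef>
#include <cstdint>
#include <vector>

// Mirrors the shape expandShape builds: rank == 0 yields an empty shape (a
// scalar); otherwise the element count k lands at position `axis`, e.g.
// expandedShape(5, 1, 4) == {1, 5, 1, 1}.
std::vector<int64_t> expandedShape(int64_t numElements, size_t axis, size_t rank) {
  if (rank == 0)
    return {};
  std::vector<int64_t> shape(rank, 1);
  shape[axis] = numElements;
  return shape;
}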
6 changes: 3 additions & 3 deletions src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp
@@ -53,9 +53,9 @@ T getValueFromTosaConst(mlir::Value &val) {
 // This function is made to work with both onnx.const and tosa.const
 mlir::ElementsAttr getElementsAttrFromConst(mlir::Value &val);

-// Takes a 1-d `tensor` with k elements and reshapes it into an `rank`-d tensor
-// with shape {1, ..., 1, k, 1, ..., 1 }
-// where `k` it at position `axis`.
+// Takes a 1-d `tensor` with k elements and reshapes it into an `rank`-d or
+// scalar tensor with shape {1, ..., 1, k, 1, ..., 1 } where `k` it at position
+// `axis`.
 mlir::Value expandShape(mlir::PatternRewriter &rewriter, mlir::Location loc,
     mlir::Value tensor, size_t axis, size_t rank);

10 changes: 6 additions & 4 deletions src/Dialect/ONNX/ONNXOps.td.inc
@@ -1882,9 +1882,10 @@ def ONNXDequantizeLinearOp:ONNX_Op<"DequantizeLinear",
     `zero-point` is usually not used in the case of float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz quantization,
     but the dequantization formula remains the same for consistency and 'x_scale' still determines the output type.
   }];
-  let arguments = (ins AnyTypeOf<[TensorOf<[I8]>, TensorOf<[UI8]>, TensorOf<[I32]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$x,
+  // AMD: Manual addition of uint16
+  let arguments = (ins AnyTypeOf<[TensorOf<[I8]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[I32]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$x,
     AnyTypeOf<[TensorOf<[F32]>, TensorOf<[F16]>, TensorOf<[BF16]>]>:$x_scale,
-    AnyTypeOf<[TensorOf<[I8]>, TensorOf<[UI8]>, TensorOf<[I32]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, NoneType]>:$x_zero_point,
+    AnyTypeOf<[TensorOf<[I8]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[I32]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, NoneType]>:$x_zero_point,
     DefaultValuedAttr<SI64Attr, "1">:$axis);
   let results = (outs AnyTypeOf<[TensorOf<[F32]>, TensorOf<[F16]>, TensorOf<[BF16]>]>:$y);
   let extraClassDeclaration = [{
@@ -6129,12 +6130,13 @@ def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear",
     but the quantization formula remains the same for consistency and
     the type of the attribute 'y_zero_point' still determines the quantization type.
   }];
+  // AMD: Manual addition of uint16
   let arguments = (ins AnyTypeOf<[TensorOf<[F32]>, TensorOf<[F16]>, TensorOf<[BF16]>, TensorOf<[I32]>]>:$x,
     AnyTypeOf<[TensorOf<[F32]>, TensorOf<[F16]>, TensorOf<[BF16]>, TensorOf<[I32]>]>:$y_scale,
-    AnyTypeOf<[TensorOf<[I8]>, TensorOf<[UI8]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, NoneType]>:$y_zero_point,
+    AnyTypeOf<[TensorOf<[I8]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, NoneType]>:$y_zero_point,
     DefaultValuedAttr<SI64Attr, "1">:$axis,
     DefaultValuedAttr<SI64Attr, "1">:$saturate);
-  let results = (outs AnyTypeOf<[TensorOf<[I8]>, TensorOf<[UI8]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$y);
+  let results = (outs AnyTypeOf<[TensorOf<[I8]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$y);
   let extraClassDeclaration = [{
     static int getNumberOfOperands() {
       return 3;
11 changes: 11 additions & 0 deletions src/Dialect/ONNX/ONNXOps/OpHelper.cpp
@@ -760,6 +760,17 @@ bool isScalarTensor(Value v) {
           (getRank(v.getType()) == 1 && getShape(v.getType())[0] == 1)));
 }

+IgnoreDiagnostic::IgnoreDiagnostic(DiagnosticEngine &diagEngine)
+    : diagEngine(diagEngine) {
+  id = diagEngine.registerHandler(
+      [](mlir::Diagnostic & /*diag*/) { return success(); });
+}
+
+IgnoreDiagnostic::~IgnoreDiagnostic() {
+  // Reset to the previous state.
+  diagEngine.eraseHandler(id);
+}
+
 bool hasIntegerPowerExponent(ONNXPowOp *op, int64_t &exponentValue) {
   Value exponent = op->getY();
   ElementsAttr elementAttr = getElementAttributeFromONNXValue(exponent);
11 changes: 11 additions & 0 deletions src/Dialect/ONNX/ONNXOps/OpHelper.hpp
@@ -272,6 +272,17 @@ int64_t mlirTypeToOnnxType(mlir::Type elemType);
 /// Check if a value is a scalar tensor.
 bool isScalarTensor(mlir::Value v);

+class IgnoreDiagnostic {
+public:
+  IgnoreDiagnostic(mlir::DiagnosticEngine &diagEngine);
+
+  ~IgnoreDiagnostic();
+
+private:
+  mlir::DiagnosticEngine &diagEngine;
+  mlir::DiagnosticEngine::HandlerID id;
+};
+
 bool hasIntegerPowerExponent(mlir::ONNXPowOp *op, int64_t &exponentValue);

 //===----------------------------------------------------------------------===//
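IgnoreDiagnostic is an RAII guard: its constructor registers a handler that marks every diagnostic as handled, and its destructor erases that handler, restoring normal reporting. A minimal usage sketch, assuming the declaration above is in scope; the function name and wrapper below are illustrative and not code from this commit. Decompose.cpp below applies the same idiom around speculative op verification.

#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Support/LogicalResult.h"

// Sketch only: verify an op without emitting verifier errors to the user.
bool verifiesQuietly(mlir::Operation *op) {
  // While `silence` is alive, diagnostics are swallowed by the handler
  // installed in the IgnoreDiagnostic constructor.
  IgnoreDiagnostic silence(op->getContext()->getDiagEngine());
  return mlir::succeeded(mlir::verify(op));
} // ~IgnoreDiagnostic erases the handler here.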
77 changes: 77 additions & 0 deletions src/Dialect/ONNX/Transforms/Decompose.cpp
@@ -29,6 +29,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeRange.h"
 #include "mlir/IR/Value.h"
+#include "mlir/IR/Verifier.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -1581,6 +1582,80 @@ struct CustomOpFuseMatMulPattern : public OpRewritePattern<ONNXCustomOp> {
   }
 };

+namespace {
+
+[[nodiscard]] bool isCustomMicrosoftOp(
+    ONNXCustomOp customOp, StringRef expectedName) {
+  if (!customOp.getFunctionName().equals_insensitive(expectedName)) {
+    return false;
+  }
+
+  const auto domAttr = customOp->getAttrOfType<StringAttr>("domain_name");
+  return domAttr && domAttr.getValue().equals_insensitive("com.microsoft");
+}
+
+} // namespace
+
+template <typename OpToCreate>
+struct CustomOpMicrosoftQDuantizeLinear {
+  LogicalResult matchAndRewriteImpl(ONNXCustomOp customOp,
+      PatternRewriter &rewriter, StringRef expectedName) const {
+    using namespace onnx_mlir;
+
+    if (!isCustomMicrosoftOp(customOp, expectedName))
+      return failure();
+    if (customOp->getNumOperands() != 3) {
+      return failure();
+    }
+
+    const auto scale = customOp->getOperand(1);
+    const auto zeroPoint = customOp->getOperand(2);
+    if (!isScalarTensor(scale) || !isScalarTensor(zeroPoint)) {
+      return rewriter.notifyMatchFailure(
+          customOp, "Only supports per-tensor quantization for now");
+    }
+    // Axis is ignored if scale and zeroPoint are scalars
+
+    auto newOp = rewriter.create<OpToCreate>(customOp->getLoc(),
+        customOp.getResult(0).getType(), customOp->getOperand(0), scale,
+        zeroPoint);
+
+    IgnoreDiagnostic diag(customOp->getContext()->getDiagEngine());
+    bool isNewOpValid;
+    if (auto info = newOp->getName().getRegisteredInfo()) {
+      isNewOpValid = succeeded(info->verifyInvariants(newOp));
+    } else {
+      isNewOpValid = succeeded(mlir::verify(newOp));
+    }
+    if (!isNewOpValid) {
+      rewriter.eraseOp(newOp);
+      return rewriter.notifyMatchFailure(customOp, "Failed verification");
+    }
+    rewriter.replaceOp(customOp, newOp);
+    return success();
+  }
+};
+
+struct CustomOpMicrosoftQuantizeLinear
+    : public OpRewritePattern<ONNXCustomOp>,
+      public CustomOpMicrosoftQDuantizeLinear<ONNXQuantizeLinearOp> {
+  using OpRewritePattern<ONNXCustomOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(
+      ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
+    return matchAndRewriteImpl(customOp, rewriter, "QuantizeLinear");
+  }
+};
+
+struct CustomOpMicrosoftDequantizeLinear
+    : public OpRewritePattern<ONNXCustomOp>,
+      CustomOpMicrosoftQDuantizeLinear<ONNXDequantizeLinearOp> {
+  using OpRewritePattern<ONNXCustomOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(
+      ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
+    return matchAndRewriteImpl(customOp, rewriter, "DequantizeLinear");
+  }
+};
+
 // Transform InstanceNormalization into LayerNormalization
 struct InstanceNormIntoLayerNormPattern
     : public OpRewritePattern<ONNXInstanceNormalizationOp> {
@@ -2039,6 +2114,8 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
   // Decompose CustomOp FusedMatMul introduced by onnxruntime:
   // https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul
   patterns.insert<CustomOpFuseMatMulPattern>(context);
+  patterns.insert<CustomOpMicrosoftQuantizeLinear>(context);
+  patterns.insert<CustomOpMicrosoftDequantizeLinear>(context);
   patterns.insert<InstanceNormIntoLayerNormPattern>(context);
   patterns.insert<GroupNormIntoLayerNormPattern1>(context);
   patterns.insert<GroupNormIntoLayerNormPattern2>(context);
24 changes: 13 additions & 11 deletions src/Dialect/ONNX/Transforms/Recompose.cpp
@@ -815,16 +815,17 @@ void RecomposeONNXToONNXPass::runOnOperation() {
     return true;
   });

-  // Recompose QLinearMatMul, starting from QuantizeLinear.
-  // Pattern: DequanizeLinear + MatMul + QuantizeLinear.
-  target.addDynamicallyLegalOp<ONNXQuantizeLinearOp>(
-      [](ONNXQuantizeLinearOp op) {
-        Value a, aScale, aZeroPoint, b, bScale, bZeroPoint, outScale,
-            outZeroPoint;
-        return !RecomposeQLinearMatMulFromQuantizeLinearPattern::
-            matchQLinearMatMulPattern(op, a, aScale, aZeroPoint, b, bScale,
-                bZeroPoint, outScale, outZeroPoint);
-      });
+  // AMD Disabled
+  // // Recompose QLinearMatMul, starting from QuantizeLinear.
+  // // Pattern: DequanizeLinear + MatMul + QuantizeLinear.
+  // target.addDynamicallyLegalOp<ONNXQuantizeLinearOp>(
+  //     [](ONNXQuantizeLinearOp op) {
+  //       Value a, aScale, aZeroPoint, b, bScale, bZeroPoint, outScale,
+  //           outZeroPoint;
+  //       return !RecomposeQLinearMatMulFromQuantizeLinearPattern::
+  //           matchQLinearMatMulPattern(op, a, aScale, aZeroPoint, b, bScale,
+  //               bZeroPoint, outScale, outZeroPoint);
+  //     });

   RewritePatternSet patterns(context);
   onnx_mlir::getRecomposeONNXToONNXPatterns(patterns);
@@ -840,7 +841,8 @@ void onnx_mlir::getRecomposeONNXToONNXPatterns(
   MLIRContext *context = patterns.getContext();
   patterns.insert<RecomposeGeluFromMulPattern>(context);
   patterns.insert<RecomposeLayerNormFromMulPattern>(context);
-  patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
+  // AMD Disabled as downstream has no special support for it
+  // patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
 }

 /*!
13 changes: 13 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/NN/DequantizeLinear.mlir
@@ -94,3 +94,16 @@ func.func @f8E4M3FN(%arg0: tensor<5xf8E4M3FN>, %arg1: tensor<f32>) -> tensor<5xf
 // CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<1xf32>
 // CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_2]], %[[VAL_3]] {shift = 0 : i8} : (tensor<5xf32>, tensor<1xf32>) -> tensor<5xf32>
 // CHECK: return %[[VAL_4]] : tensor<5xf32>
+
+// -----
+
+func.func @all_scalar(%arg0 : tensor<i8>) -> tensor<f32> {
+  %0 = onnx.Constant dense<3.125000e-02> : tensor<f32>
+  %1 = onnx.Constant dense<0> : tensor<i8>
+  %2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64} : (tensor<i8>, tensor<f32>, tensor<i8>) -> tensor<f32>
+  return %2 : tensor<f32>
+}
+
+// CHECK-LABEL: all_scalar
+// CHECK-NOT: onnx.DequantizeLinear
13 changes: 13 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/NN/QuantizeLinear.mlir
@@ -109,3 +109,16 @@ func.func @default_axis(%arg0 : tensor<32xf32>) -> tensor<32xi8> {

 // CHECK-LABEL: default_axis
 // CHECK-NOT: onnx.QuantizeLinear
+
+// -----
+
+func.func @all_scalar(%arg0 : tensor<f32>) -> tensor<i8> {
+  %0 = onnx.Constant dense<3.125000e-02> : tensor<f32>
+  %1 = onnx.Constant dense<0> : tensor<i8>
+  %2 = "onnx.QuantizeLinear"(%arg0, %0, %1) {axis = 1 : si64} : (tensor<f32>, tensor<f32>, tensor<i8>) -> tensor<i8>
+  return %2 : tensor<i8>
+}
+
+// CHECK-LABEL: all_scalar
+// CHECK-NOT: onnx.QuantizeLinear
13 changes: 7 additions & 6 deletions test/mlir/driver/static_quantization.mlir
@@ -14,11 +14,12 @@ module {
   }
   "onnx.EntryPoint"() {func = @main_graph} : () -> ()

+// COM: AMD Disabled
 // CHECK-LABEL: func.func @qlinear_matmul
-// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x768xf32>, [[PARAM_1_:%.+]]: tensor<f32>, [[PARAM_2_:%.+]]: tensor<i8>, [[PARAM_3_:%.+]]: tensor<768x768xi8>, [[PARAM_4_:%.+]]: tensor<f32>, [[PARAM_5_:%.+]]: tensor<i8>, [[PARAM_6_:%.+]]: tensor<f32>, [[PARAM_7_:%.+]]: tensor<i8>) -> tensor<?x?x768xi8> {
-// CHECK: [[VAR_0_:%.+]] = "onnx.QuantizeLinear"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 1 : si64, onnx_node_name = "onnx.QuantizeLinear_0", saturate = 1 : si64} : (tensor<?x?x768xf32>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
-// CHECK: [[VAR_1_:%.+]] = "onnx.QLinearMatMul"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]], [[PARAM_5_]], [[PARAM_6_]], [[PARAM_7_]]) {onnx_node_name = "onnx.QLinearMatMul_1"} : (tensor<?x?x768xi8>, tensor<f32>, tensor<i8>, tensor<768x768xi8>, tensor<f32>, tensor<i8>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
-// CHECK: return [[VAR_1_]] : tensor<?x?x768xi8>
-// CHECK: }
-// CHECK: "onnx.EntryPoint"() {func = @main_graph} : () -> ()
+// DISABLED-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x768xf32>, [[PARAM_1_:%.+]]: tensor<f32>, [[PARAM_2_:%.+]]: tensor<i8>, [[PARAM_3_:%.+]]: tensor<768x768xi8>, [[PARAM_4_:%.+]]: tensor<f32>, [[PARAM_5_:%.+]]: tensor<i8>, [[PARAM_6_:%.+]]: tensor<f32>, [[PARAM_7_:%.+]]: tensor<i8>) -> tensor<?x?x768xi8> {
+// DISABLED: [[VAR_0_:%.+]] = "onnx.QuantizeLinear"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 1 : si64, onnx_node_name = "onnx.QuantizeLinear_0", saturate = 1 : si64} : (tensor<?x?x768xf32>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
+// DISABLED: [[VAR_1_:%.+]] = "onnx.QLinearMatMul"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]], [[PARAM_5_]], [[PARAM_6_]], [[PARAM_7_]]) {onnx_node_name = "onnx.QLinearMatMul_1"} : (tensor<?x?x768xi8>, tensor<f32>, tensor<i8>, tensor<768x768xi8>, tensor<f32>, tensor<i8>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
+// DISABLED: return [[VAR_1_]] : tensor<?x?x768xi8>
+// DISABLED: }
+// DISABLED: "onnx.EntryPoint"() {func = @main_graph} : () -> ()
}
56 changes: 56 additions & 0 deletions test/mlir/onnx/onnx_decompose_customop.mlir
@@ -109,3 +109,59 @@ func.func @customop_fusedmatmul_not_rewrite_no_alpha(%arg0: tensor<*xf32>, %arg1
 // CHECK: onnx.Return [[VAR_0_]] : tensor<*xf32>
 // CHECK: }
 }
+
+// -----
+
+func.func @customop_quantize(%arg0: tensor<*xf32>, %arg1: tensor<f32>, %arg2: tensor<ui16>) -> tensor<*xui16> {
+  %1 = "onnx.Custom"(%arg0, %arg1, %arg2) {domain_name = "com.microsoft", function_name = "QuantizeLinear"} : (tensor<*xf32>, tensor<f32>, tensor<ui16>) -> tensor<*xui16>
+  onnx.Return %1: tensor<*xui16>
+
+// CHECK-LABEL: func.func @customop_quantize
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<f32>, [[PARAM_2_:%.+]]: tensor<ui16>) -> tensor<*xui16> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.QuantizeLinear"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 1 : si64, saturate = 1 : si64} : (tensor<*xf32>, tensor<f32>, tensor<ui16>) -> tensor<*xui16>
+// CHECK: onnx.Return [[VAR_0_]] : tensor<*xui16>
+// CHECK: }
+}
+
+// -----
+
+// COM: Do not recompose per axis quantization (for now)
+func.func @customop_quantize_axis(%arg0: tensor<*xf32>, %arg1: tensor<5xf32>, %arg2: tensor<5xui16>) -> tensor<*xui16> {
+  %1 = "onnx.Custom"(%arg0, %arg1, %arg2) {domain_name = "com.microsoft", function_name = "QuantizeLinear"} : (tensor<*xf32>, tensor<5xf32>, tensor<5xui16>) -> tensor<*xui16>
+  onnx.Return %1: tensor<*xui16>

+// CHECK-LABEL: func.func @customop_quantize_axis
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<5xf32>, [[PARAM_2_:%.+]]: tensor<5xui16>) -> tensor<*xui16> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Custom"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {domain_name = "com.microsoft", function_name = "QuantizeLinear"} : (tensor<*xf32>, tensor<5xf32>, tensor<5xui16>) -> tensor<*xui16>
+// CHECK: onnx.Return [[VAR_0_]] : tensor<*xui16>
+// CHECK: }
+}
+
+// -----
+
+func.func @customop_dequantize(%arg0: tensor<*xui16>, %arg1: tensor<f32>, %arg2: tensor<ui16>) -> tensor<*xf32> {
+  %1 = "onnx.Custom"(%arg0, %arg1, %arg2) {domain_name = "com.microsoft", function_name = "DequantizeLinear"} : (tensor<*xui16>, tensor<f32>, tensor<ui16>) -> tensor<*xf32>
+  onnx.Return %1: tensor<*xf32>
+
+// CHECK-LABEL: func.func @customop_dequantize
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xui16>, [[PARAM_1_:%.+]]: tensor<f32>, [[PARAM_2_:%.+]]: tensor<ui16>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.DequantizeLinear"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 1 : si64} : (tensor<*xui16>, tensor<f32>, tensor<ui16>) -> tensor<*xf32>
+// CHECK: onnx.Return [[VAR_0_]] : tensor<*xf32>
+// CHECK: }
+}
+
+// -----
+
+// COM: Do not recompose per axis quantization (for now)
+func.func @customop_dequantize_axis(%arg0: tensor<*xui16>, %arg1: tensor<5xf32>, %arg2: tensor<5xui16>) -> tensor<*xf32> {
+  %1 = "onnx.Custom"(%arg0, %arg1, %arg2) {domain_name = "com.microsoft", function_name = "DequantizeLinear"} : (tensor<*xui16>, tensor<5xf32>, tensor<5xui16>) -> tensor<*xf32>
+  onnx.Return %1: tensor<*xf32>
+
+// CHECK-LABEL: func.func @customop_dequantize_axis
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xui16>, [[PARAM_1_:%.+]]: tensor<5xf32>, [[PARAM_2_:%.+]]: tensor<5xui16>) -> tensor<*xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Custom"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {domain_name = "com.microsoft", function_name = "DequantizeLinear"} : (tensor<*xui16>, tensor<5xf32>, tensor<5xui16>) -> tensor<*xf32>
+// CHECK: onnx.Return [[VAR_0_]] : tensor<*xf32>
+// CHECK: }
+}