From 0ac0f51e66c4c1ad82f1a8a322a425f4909a7bbf Mon Sep 17 00:00:00 2001 From: Tiago Trevisan Jost Date: Fri, 5 Jan 2024 15:12:07 +0000 Subject: [PATCH] fix: align decomposition of Hardswish with torch-mlir implementation. --- src/Transform/ONNX/Decompose.cpp | 2 +- test/mlir/onnx/onnx_decompose.mlir | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Transform/ONNX/Decompose.cpp b/src/Transform/ONNX/Decompose.cpp index bebfdadced..d9c3a98820 100644 --- a/src/Transform/ONNX/Decompose.cpp +++ b/src/Transform/ONNX/Decompose.cpp @@ -571,7 +571,7 @@ struct DecomposeHardSwishPattern : public ConversionPattern { hardSwishOp.getType(), input, rewriter.getF32FloatAttr(1.0 / 6.0), rewriter.getF32FloatAttr(0.5)); rewriter.replaceOpWithNewOp( - op, hardSwishOp.getType(), input, hardSigmoid); + op, hardSwishOp.getType(), hardSigmoid, input); return success(); } }; diff --git a/test/mlir/onnx/onnx_decompose.mlir b/test/mlir/onnx/onnx_decompose.mlir index 44ec1bb52a..ab826baf36 100644 --- a/test/mlir/onnx/onnx_decompose.mlir +++ b/test/mlir/onnx/onnx_decompose.mlir @@ -489,6 +489,6 @@ func.func @test_hardswish_f32(%arg0: tensor) -> tensor { // CHECK-LABEL: func @test_hardswish_f32 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor // CHECK: [[VAR_0_:%.+]] = "onnx.HardSigmoid"([[PARAM_0_]]) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor) -> tensor -// CHECK: [[VAR_1_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[VAR_0_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_1_:%.+]] = "onnx.Mul"([[VAR_0_]], [[PARAM_0_]]) : (tensor, tensor) -> tensor // CHECK: return [[VAR_1_]] : tensor } \ No newline at end of file