Skip to content

Commit

Permalink
Merge pull request #32 from Xilinx/tiagot.align_hardswish_decomp_with…
Browse files Browse the repository at this point in the history
…_torch_mlir

fix: align decomposition of Hardswish with torch-mlir implementation.
  • Loading branch information
ttjost authored Jan 8, 2024
2 parents 4357bc1 + 0ac0f51 commit 3950d13
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/Transform/ONNX/Decompose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ struct DecomposeHardSwishPattern : public ConversionPattern {
hardSwishOp.getType(), input, rewriter.getF32FloatAttr(1.0 / 6.0),
rewriter.getF32FloatAttr(0.5));
rewriter.replaceOpWithNewOp<ONNXMulOp>(
op, hardSwishOp.getType(), input, hardSigmoid);
op, hardSwishOp.getType(), hardSigmoid, input);
return success();
}
};
Expand Down
2 changes: 1 addition & 1 deletion test/mlir/onnx/onnx_decompose.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,6 @@ func.func @test_hardswish_f32(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
// CHECK-LABEL: func @test_hardswish_f32
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: [[VAR_0_:%.+]] = "onnx.HardSigmoid"([[PARAM_0_]]) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: [[VAR_1_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[VAR_0_]]) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: [[VAR_1_:%.+]] = "onnx.Mul"([[VAR_0_]], [[PARAM_0_]]) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: return [[VAR_1_]] : tensor<?x?x?xf32>
}

0 comments on commit 3950d13

Please sign in to comment.