diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp index 58c475cded..e0a41ac4df 100644 --- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp @@ -179,8 +179,9 @@ class ONNXMulOpLoweringToTosa : public OpConversionPattern { Value lhs = adaptor.getA(); Value rhs = adaptor.getB(); - rewriter.replaceOpWithNewOp( - op, op.getType(), lhs, rhs, /*shift =*/0); + TosaBuilder tosaBuilder(rewriter, op->getLoc()); + Value mulOp = tosaBuilder.mul(lhs, rhs); + rewriter.replaceOp(op, {mulOp}); return success(); } diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir index 61e67799a9..eb97bb45c1 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir @@ -99,6 +99,28 @@ func.func @test_mul(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> t // ----- +func.func @test_mul_rank_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<21x1xf32>) -> tensor<13x21x1xf32> { + %0 = "onnx.Mul"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<21x1xf32>) -> tensor<13x21x1xf32> + "func.return"(%0) : (tensor<13x21x1xf32>) -> () +// CHECK-LABEL: func @test_mul_rank_broadcast +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<21x1xf32>) -> tensor<13x21x1xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_1_]]) <{new_shape = array}> : (tensor<21x1xf32>) -> tensor<1x21x1xf32> +// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.mul"([[PARAM_0_]], [[VAR_0_]]) <{shift = 0 : i32}> : (tensor<13x21x1xf32>, tensor<1x21x1xf32>) -> tensor<13x21x1xf32> +} + +// ----- + +func.func @test_mul_rank_broadcast2(%arg0: tensor<21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { + %0 = "onnx.Mul"(%arg0, %arg1) : (tensor<21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> + "func.return"(%0) : (tensor<13x21x1xf32>) -> () +// CHECK-LABEL: func @test_mul_rank_broadcast2 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.reshape"([[PARAM_0_]]) <{new_shape = array}> : (tensor<21x1xf32>) -> tensor<1x21x1xf32> +// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.mul"([[VAR_0_]], [[PARAM_1_]]) <{shift = 0 : i32}> : (tensor<1x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> +} + +// ----- + func.func @test_div(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x1xi32> { %0 = "onnx.Div"(%arg0, %arg1) : (tensor<13x21x1xi32>, tensor<13x21x1xi32>) -> tensor<13x21x1xi32> "func.return"(%0) : (tensor<13x21x1xi32>) -> ()