|
18 | 18 | using namespace ov;
|
19 | 19 | using namespace testing;
|
20 | 20 |
|
| 21 | +std::shared_ptr<ov::Model> get_ref_model_with_dyn_shapes(ov::element::Type precision, const PartialShape& input_shape) { |
| 22 | + auto input = std::make_shared<opset1::Parameter>(precision, input_shape); |
| 23 | + auto gamma = std::make_shared<opset1::Parameter>(precision, PartialShape{-1}); |
| 24 | + auto beta = std::make_shared<opset1::Parameter>(precision, PartialShape{-1}); |
| 25 | + auto mean = std::make_shared<opset1::Parameter>(precision, PartialShape{-1}); |
| 26 | + auto var = std::make_shared<opset1::Parameter>(precision, PartialShape{-1}); |
| 27 | + // scale_add = variance + eps |
| 28 | + auto scale_add = std::make_shared<ov::op::v1::Add>(var, ov::op::v0::Constant::create(precision, Shape{}, {0.001})); |
| 29 | + // scale = sqrt(variance + eps) |
| 30 | + auto scale = std::make_shared<ov::op::v0::Sqrt>(scale_add); |
| 31 | + // Divide `gamma` by `sqrt(variance + eps)` |
| 32 | + auto gamma_div_scale = std::make_shared<ov::op::v1::Divide>(gamma, scale); |
| 33 | + |
| 34 | + int64_t dims_to_add = input->get_partial_shape().rank().get_length() - 2; |
| 35 | + const auto one = ov::op::v0::Constant::create(element::i64, Shape{1}, {1}); |
| 36 | + const auto tail_shape_rank = ov::op::v0::Constant::create(element::i64, Shape{1}, {dims_to_add}); |
| 37 | + const auto tail_shape = std::make_shared<ov::op::v3::Broadcast>(one, tail_shape_rank); |
| 38 | + const auto C_dim = std::make_shared<ov::op::v3::ShapeOf>(gamma); |
| 39 | + // create new shape [1, C, 1, 1, ...] |
| 40 | + const auto new_shape = std::make_shared<ov::op::v0::Concat>(OutputVector{one, C_dim, tail_shape}, 0); |
| 41 | + |
| 42 | + std::shared_ptr<Node> gamma_div_scale_aligned = |
| 43 | + std::make_shared<ov::op::v1::Reshape>(gamma_div_scale, new_shape, true); |
| 44 | + std::shared_ptr<Node> beta_aligned = std::make_shared<ov::op::v1::Reshape>(beta, new_shape, true); |
| 45 | + std::shared_ptr<Node> mean_aligned = std::make_shared<ov::op::v1::Reshape>(mean, new_shape, true); |
| 46 | + std::shared_ptr<Node> mean_negative = std::make_shared<ov::op::v1::Multiply>( |
| 47 | + mean_aligned, |
| 48 | + ov::op::v0::Constant::create(mean_aligned->get_output_element_type(0), Shape{}, {-1})); |
| 49 | + |
| 50 | + // input_sub_mean = input + mean * -1 |
| 51 | + auto input_sub_mean = std::make_shared<ov::op::v1::Add>(input, mean_negative); |
| 52 | + // Multiply `input - mean` and `gamma / sqrt(variance + eps)` |
| 53 | + auto mul = std::make_shared<ov::op::v1::Multiply>(input_sub_mean, gamma_div_scale_aligned); |
| 54 | + // Add `(input - mean) * gamma / sqrt(variance + eps)` and `beta` |
| 55 | + auto add = std::make_shared<ov::op::v1::Add>(mul, beta_aligned); |
| 56 | + |
| 57 | + return std::make_shared<ov::Model>(NodeVector{add}, ParameterVector{input, gamma, beta, mean, var}); |
| 58 | +} |
| 59 | + |
21 | 60 | TEST_F(TransformationTestsF, BatchNormDecompositionStaticRankOpset1) {
|
22 | 61 | const PartialShape input_shape{-1, -1, -1, -1};
|
23 | 62 | const auto precision = element::f32;
|
@@ -74,6 +113,42 @@ TEST_F(TransformationTestsF, BatchNormDecompositionStaticRankOpset5) {
|
74 | 113 | }
|
75 | 114 | }
|
76 | 115 |
|
| 116 | +TEST_F(TransformationTestsF, BatchNormDecompositionDynamicShapesOpset1) { |
| 117 | + const PartialShape input_shape{-1, -1, -1, -1}; |
| 118 | + const auto precision = element::f32; |
| 119 | + { |
| 120 | + auto input = std::make_shared<opset1::Parameter>(precision, input_shape); |
| 121 | + auto gamma = std::make_shared<opset1::Parameter>(precision, PartialShape{-1}); |
| 122 | + auto beta = std::make_shared<opset1::Parameter>(precision, PartialShape{-1}); |
| 123 | + auto mean = std::make_shared<opset1::Parameter>(precision, PartialShape{-1}); |
| 124 | + auto var = std::make_shared<opset1::Parameter>(precision, PartialShape{-1}); |
| 125 | + auto batch_norm = std::make_shared<opset1::BatchNormInference>(input, gamma, beta, mean, var, 0.001); |
| 126 | + |
| 127 | + model = std::make_shared<ov::Model>(NodeVector{batch_norm}, ParameterVector{input, gamma, beta, mean, var}); |
| 128 | + manager.register_pass<ov::pass::BatchNormDecomposition>(); |
| 129 | + comparator.enable(FunctionsComparator::CONST_VALUES); |
| 130 | + } |
| 131 | + { model_ref = get_ref_model_with_dyn_shapes(precision, input_shape); } |
| 132 | +} |
| 133 | + |
| 134 | +TEST_F(TransformationTestsF, BatchNormDecompositionDynamicShapesOpset5) { |
| 135 | + const PartialShape input_shape{-1, -1, -1, -1}; |
| 136 | + const auto precision = element::f32; |
| 137 | + { |
| 138 | + auto input = std::make_shared<opset1::Parameter>(precision, input_shape); |
| 139 | + auto gamma = std::make_shared<opset1::Parameter>(precision, PartialShape{-1}); |
| 140 | + auto beta = std::make_shared<opset1::Parameter>(precision, PartialShape{-1}); |
| 141 | + auto mean = std::make_shared<opset1::Parameter>(precision, PartialShape{-1}); |
| 142 | + auto var = std::make_shared<opset1::Parameter>(precision, PartialShape{-1}); |
| 143 | + auto batch_norm = std::make_shared<opset5::BatchNormInference>(input, gamma, beta, mean, var, 0.001); |
| 144 | + |
| 145 | + model = std::make_shared<ov::Model>(NodeVector{batch_norm}, ParameterVector{input, gamma, beta, mean, var}); |
| 146 | + manager.register_pass<ov::pass::BatchNormDecomposition>(); |
| 147 | + comparator.enable(FunctionsComparator::CONST_VALUES); |
| 148 | + } |
| 149 | + { model_ref = get_ref_model_with_dyn_shapes(precision, input_shape); } |
| 150 | +} |
| 151 | + |
77 | 152 | TEST_F(TransformationTestsF, BatchNormDecompositionDynamicRank) {
|
78 | 153 | {
|
79 | 154 | auto input = std::make_shared<opset1::Parameter>(element::f32, PartialShape::dynamic());
|
|
0 commit comments