Skip to content

Commit aad89fb

Browse files
authored
Support dyn shapes in BatchNormDecomposition transformation (openvinotoolkit#23290)
### Details: Support dyn shapes in BatchNormDecomposition transformation ### Tickets: - *CVS-133609*
1 parent 2c0efbb commit aad89fb

File tree

2 files changed

+102
-14
lines changed

2 files changed

+102
-14
lines changed

src/common/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp

+27-14
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,16 @@ using namespace ov;
2828

2929
ov::pass::BatchNormDecomposition::BatchNormDecomposition() {
3030
MATCHER_SCOPE(BatchNormDecomposition);
31-
auto bn_1 = pattern::wrap_type<ov::op::v0::BatchNormInference>({pattern::any_input(pattern::has_static_shape()),
32-
pattern::any_input(pattern::has_static_shape()),
31+
auto bn_1 = pattern::wrap_type<ov::op::v0::BatchNormInference>({pattern::any_input(),
32+
pattern::any_input(),
3333
pattern::any_input(pattern::has_static_rank()),
34-
pattern::any_input(pattern::has_static_shape()),
35-
pattern::any_input(pattern::has_static_shape())});
34+
pattern::any_input(),
35+
pattern::any_input()});
3636
auto bn_5 = pattern::wrap_type<ov::op::v5::BatchNormInference>({pattern::any_input(pattern::has_static_rank()),
37-
pattern::any_input(pattern::has_static_shape()),
38-
pattern::any_input(pattern::has_static_shape()),
39-
pattern::any_input(pattern::has_static_shape()),
40-
pattern::any_input(pattern::has_static_shape())});
37+
pattern::any_input(),
38+
pattern::any_input(),
39+
pattern::any_input(),
40+
pattern::any_input()});
4141
auto bn = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{bn_1, bn_5});
4242

4343
matcher_pass_callback callback = [this](ov::pass::pattern::Matcher& m) {
@@ -83,9 +83,8 @@ ov::pass::BatchNormDecomposition::BatchNormDecomposition() {
8383
std::make_shared<ov::op::v1::Reshape>(gamma_div_scale, new_shape, true);
8484
std::shared_ptr<Node> beta_aligned = std::make_shared<ov::op::v1::Reshape>(m_beta, new_shape, true);
8585
std::shared_ptr<Node> mean_aligned = std::make_shared<ov::op::v1::Reshape>(m_mean, new_shape, true);
86-
std::shared_ptr<Node> mean_negative = std::make_shared<ov::op::v1::Multiply>(
87-
mean_aligned,
88-
ov::op::v0::Constant::create(mean_aligned->get_output_element_type(0), Shape{}, {-1}));
86+
auto mul_const = ov::op::v0::Constant::create(mean_aligned->get_output_element_type(0), Shape{}, {-1});
87+
std::shared_ptr<Node> mean_negative = std::make_shared<ov::op::v1::Multiply>(mean_aligned, mul_const);
8988

9089
if (auto constant = ov::util::get_constant_from_source(beta_aligned))
9190
beta_aligned = constant;
@@ -103,9 +102,23 @@ ov::pass::BatchNormDecomposition::BatchNormDecomposition() {
103102

104103
add->set_friendly_name(m_bn->get_friendly_name());
105104

106-
copy_runtime_info(
107-
m_bn,
108-
{scale_add, scale, gamma_div_scale, gamma_div_scale_aligned, beta_aligned, input_sub_mean, mul, add});
105+
copy_runtime_info(m_bn,
106+
{scale_add,
107+
scale,
108+
gamma_div_scale,
109+
gamma_div_scale_aligned,
110+
beta_aligned,
111+
input_sub_mean,
112+
mul,
113+
add,
114+
mean_negative,
115+
mean_aligned,
116+
new_shape,
117+
tail_shape,
118+
tail_shape_rank,
119+
one,
120+
mul_const,
121+
C_dim});
109122

110123
replace_node(m_bn, add);
111124

src/common/transformations/tests/op_conversions/batch_norm_decomposition_test.cpp

+75
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,45 @@
1818
using namespace ov;
1919
using namespace testing;
2020

21+
std::shared_ptr<ov::Model> get_ref_model_with_dyn_shapes(ov::element::Type precision, const PartialShape& input_shape) {
22+
auto input = std::make_shared<opset1::Parameter>(precision, input_shape);
23+
auto gamma = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
24+
auto beta = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
25+
auto mean = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
26+
auto var = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
27+
// scale_add = variance + eps
28+
auto scale_add = std::make_shared<ov::op::v1::Add>(var, ov::op::v0::Constant::create(precision, Shape{}, {0.001}));
29+
// scale = sqrt(variance + eps)
30+
auto scale = std::make_shared<ov::op::v0::Sqrt>(scale_add);
31+
// Divide `gamma` by `sqrt(variance + eps)`
32+
auto gamma_div_scale = std::make_shared<ov::op::v1::Divide>(gamma, scale);
33+
34+
int64_t dims_to_add = input->get_partial_shape().rank().get_length() - 2;
35+
const auto one = ov::op::v0::Constant::create(element::i64, Shape{1}, {1});
36+
const auto tail_shape_rank = ov::op::v0::Constant::create(element::i64, Shape{1}, {dims_to_add});
37+
const auto tail_shape = std::make_shared<ov::op::v3::Broadcast>(one, tail_shape_rank);
38+
const auto C_dim = std::make_shared<ov::op::v3::ShapeOf>(gamma);
39+
// create new shape [1, C, 1, 1, ...]
40+
const auto new_shape = std::make_shared<ov::op::v0::Concat>(OutputVector{one, C_dim, tail_shape}, 0);
41+
42+
std::shared_ptr<Node> gamma_div_scale_aligned =
43+
std::make_shared<ov::op::v1::Reshape>(gamma_div_scale, new_shape, true);
44+
std::shared_ptr<Node> beta_aligned = std::make_shared<ov::op::v1::Reshape>(beta, new_shape, true);
45+
std::shared_ptr<Node> mean_aligned = std::make_shared<ov::op::v1::Reshape>(mean, new_shape, true);
46+
std::shared_ptr<Node> mean_negative = std::make_shared<ov::op::v1::Multiply>(
47+
mean_aligned,
48+
ov::op::v0::Constant::create(mean_aligned->get_output_element_type(0), Shape{}, {-1}));
49+
50+
// input_sub_mean = input + mean * -1
51+
auto input_sub_mean = std::make_shared<ov::op::v1::Add>(input, mean_negative);
52+
// Multiply `input - mean` and `gamma / sqrt(variance + eps)`
53+
auto mul = std::make_shared<ov::op::v1::Multiply>(input_sub_mean, gamma_div_scale_aligned);
54+
// Add `(input - mean) * gamma / sqrt(variance + eps)` and `beta`
55+
auto add = std::make_shared<ov::op::v1::Add>(mul, beta_aligned);
56+
57+
return std::make_shared<ov::Model>(NodeVector{add}, ParameterVector{input, gamma, beta, mean, var});
58+
}
59+
2160
TEST_F(TransformationTestsF, BatchNormDecompositionStaticRankOpset1) {
2261
const PartialShape input_shape{-1, -1, -1, -1};
2362
const auto precision = element::f32;
@@ -74,6 +113,42 @@ TEST_F(TransformationTestsF, BatchNormDecompositionStaticRankOpset5) {
74113
}
75114
}
76115

116+
TEST_F(TransformationTestsF, BatchNormDecompositionDynamicShapesOpset1) {
117+
const PartialShape input_shape{-1, -1, -1, -1};
118+
const auto precision = element::f32;
119+
{
120+
auto input = std::make_shared<opset1::Parameter>(precision, input_shape);
121+
auto gamma = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
122+
auto beta = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
123+
auto mean = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
124+
auto var = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
125+
auto batch_norm = std::make_shared<opset1::BatchNormInference>(input, gamma, beta, mean, var, 0.001);
126+
127+
model = std::make_shared<ov::Model>(NodeVector{batch_norm}, ParameterVector{input, gamma, beta, mean, var});
128+
manager.register_pass<ov::pass::BatchNormDecomposition>();
129+
comparator.enable(FunctionsComparator::CONST_VALUES);
130+
}
131+
{ model_ref = get_ref_model_with_dyn_shapes(precision, input_shape); }
132+
}
133+
134+
TEST_F(TransformationTestsF, BatchNormDecompositionDynamicShapesOpset5) {
135+
const PartialShape input_shape{-1, -1, -1, -1};
136+
const auto precision = element::f32;
137+
{
138+
auto input = std::make_shared<opset1::Parameter>(precision, input_shape);
139+
auto gamma = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
140+
auto beta = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
141+
auto mean = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
142+
auto var = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
143+
auto batch_norm = std::make_shared<opset5::BatchNormInference>(input, gamma, beta, mean, var, 0.001);
144+
145+
model = std::make_shared<ov::Model>(NodeVector{batch_norm}, ParameterVector{input, gamma, beta, mean, var});
146+
manager.register_pass<ov::pass::BatchNormDecomposition>();
147+
comparator.enable(FunctionsComparator::CONST_VALUES);
148+
}
149+
{ model_ref = get_ref_model_with_dyn_shapes(precision, input_shape); }
150+
}
151+
77152
TEST_F(TransformationTestsF, BatchNormDecompositionDynamicRank) {
78153
{
79154
auto input = std::make_shared<opset1::Parameter>(element::f32, PartialShape::dynamic());

0 commit comments

Comments
 (0)