Skip to content

Commit 2a0f595

Browse files
authored
[TRANSFORMATIONS][GPU] Add GroupNormalization fusion to common optimizations (openvinotoolkit#28387)
### Details: - Added GroupNormalization fusion pass that can handle a pattern observed in many customer models that were exported via ONNX in a way that uses InstanceNormalization as a proxy for GroupNormalization. It also covers more traditional cases without the additional instance-norm-related parameters. - Per suggestion from @vladimir-paramuzov, the GroupNormalization fusion is for now enabled only for the GPU plugin. Once it is verified that it doesn't cause regressions in other backends, we can enable it for them as well. ### Tickets: - 160436
1 parent d9a5e6b commit 2a0f595

File tree

12 files changed

+1182
-132
lines changed

12 files changed

+1182
-132
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "openvino/pass/graph_rewrite.hpp"
8+
#include "transformations_visibility.hpp"
9+
10+
namespace ov {
11+
namespace pass {
12+
13+
class TRANSFORMATIONS_API GroupNormalizationFusion;
14+
15+
} // namespace pass
16+
} // namespace ov
17+
18+
/**
19+
* @ingroup ov_transformation_common_api
20+
* @brief GroupNormalizationFusion transformation replaces
21+
* following pattern with fused GroupNormalization op:
22+
* group_norm_gamma * (instance_norm_gamma * MVN(x) + instance_norm_beta) + group_norm_beta
23+
* note that instance norm related parameters are optional:
24+
* - instance_norm_gamma is assumed to be filled with ones if not present in the graph
25+
* - instance_norm_beta is assumed to be filled with zeros if not present in the graph
26+
*/
27+
28+
class ov::pass::GroupNormalizationFusion : public ov::pass::MatcherPass {
29+
public:
30+
OPENVINO_MATCHER_PASS_RTTI("GroupNormalizationFusion");
31+
GroupNormalizationFusion();
32+
};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "transformations/common_optimizations/group_normalization_fusion.hpp"
6+
7+
#include "itt.hpp"
8+
#include "openvino/core/rt_info.hpp"
9+
#include "openvino/op/add.hpp"
10+
#include "openvino/op/constant.hpp"
11+
#include "openvino/op/gather.hpp"
12+
#include "openvino/op/group_normalization.hpp"
13+
#include "openvino/op/multiply.hpp"
14+
#include "openvino/op/mvn.hpp"
15+
#include "openvino/op/reshape.hpp"
16+
#include "openvino/op/squeeze.hpp"
17+
#include "openvino/pass/pattern/op/optional.hpp"
18+
#include "openvino/pass/pattern/op/wrap_type.hpp"
19+
#include "transformations/utils/utils.hpp"
20+
21+
using namespace ov::pass::pattern;
22+
23+
// Builds the matcher for the GroupNormalization pattern:
//   group_norm_gamma * (instance_norm_gamma * MVN(reshape(x)) + instance_norm_beta) + group_norm_beta
// (instance-norm gamma/beta are optional) and a callback that validates the match
// and replaces the subgraph with a single ov::op::v12::GroupNormalization node.
ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
    MATCHER_SCOPE(GroupNormalizationFusion);

    // Predicate: tensor element type is a non-quantized floating-point type.
    auto has_real_not_quantized_type = [](const ov::Output<ov::Node>& output) -> bool {
        const auto& T = output.get_element_type();
        return (T.is_real() && (!T.is_quantized()));
    };

    // Predicate: rank is static and at least 2, so axis 1 can be treated as channels.
    auto has_at_least_2d_shape = [](const ov::Output<ov::Node>& output) -> bool {
        const auto& output_ps = output.get_partial_shape();
        return (output_ps.rank().is_static()) && (output_ps.rank().get_length() >= 2);
    };

    // Pattern input: float tensor, rank >= 2, static channel dimension (axis 1).
    auto input_m = any_input(all_of({has_real_not_quantized_type, has_at_least_2d_shape, has_static_dim(1)}));

    // Pre-MVN reshape to a rank-3 tensor with a static dim 1 (the groups axis);
    // the target shape is expected to be a 1D constant with a static length.
    auto pre_mvn_shape_const_m = wrap_type<ov::op::v0::Constant>(all_of({rank_equals(1), has_static_dim(0)}));
    auto pre_mvn_reshape_m =
        wrap_type<ov::op::v1::Reshape>({input_m, pre_mvn_shape_const_m},
                                       all_of({has_real_not_quantized_type, rank_equals(3), has_static_dim(1)}));

    // MVN over the reshaped tensor; the reduction axes come from a 1D constant.
    auto mvn_reduction_axes_const_m = wrap_type<ov::op::v0::Constant>(all_of({rank_equals(1), has_static_dim(0)}));
    auto mvn_m = wrap_type<ov::op::v6::MVN>({pre_mvn_reshape_m, mvn_reduction_axes_const_m});

    // Optional instance-norm scale: MVN output * instance_norm_gamma.
    auto instance_norm_gamma_m = any_input(all_of({has_real_not_quantized_type, has_static_shape()}));
    auto instance_norm_opt_gamma_m = optional<ov::op::v1::Multiply>({mvn_m, instance_norm_gamma_m});

    // Optional instance-norm shift: (...) + instance_norm_beta.
    auto instance_norm_beta_m = any_input(all_of({has_real_not_quantized_type, has_static_shape()}));
    auto instance_norm_opt_gamma_opt_beta_m =
        optional<ov::op::v1::Add>({instance_norm_opt_gamma_m, instance_norm_beta_m});

    // Reshape back to the original (batch, channels, spatial...) layout.
    auto post_instance_norm_shape_m = any_input(all_of({rank_equals(1), has_static_dim(0)}));
    auto post_instance_norm_reshape_m =
        wrap_type<ov::op::v1::Reshape>({instance_norm_opt_gamma_opt_beta_m, post_instance_norm_shape_m},
                                       all_of({has_real_not_quantized_type, has_at_least_2d_shape, has_static_dim(1)}));

    // Mandatory group-norm affine transform: * group_norm_gamma, + group_norm_beta.
    auto group_norm_gamma_m = any_input(all_of({has_real_not_quantized_type, has_static_shape()}));
    auto group_norm_gamma_multiply_m =
        wrap_type<ov::op::v1::Multiply>({post_instance_norm_reshape_m, group_norm_gamma_m});

    auto group_norm_beta_m = any_input(all_of({has_real_not_quantized_type, has_static_shape()}));
    auto group_norm_beta_add_m = wrap_type<ov::op::v1::Add>({group_norm_gamma_multiply_m, group_norm_beta_m});

    // Callback: validate that the matched constants/shapes really describe a
    // group normalization, then build the fused op. Returns false (no rewrite)
    // on any mismatch.
    ov::matcher_pass_callback callback = [=](Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();

        const auto& input = pattern_map.at(input_m);
        const auto& input_ps = input.get_partial_shape();

        // Element type of the pattern input; all affine parameters must match it.
        const auto& T = input.get_element_type();

        const auto& pre_mvn_reshape_out_ps = pattern_map.at(pre_mvn_reshape_m).get_partial_shape();

        // Both dims are guaranteed static by the pattern predicates, so
        // get_max_length() yields the actual extent here.
        const size_t num_channels = static_cast<size_t>(input_ps[1].get_max_length());
        const size_t num_groups = static_cast<size_t>(pre_mvn_reshape_out_ps[1].get_max_length());

        // we expect to reshape input in a way that would merge all spatial dimensions
        // but leave batch and channel dimensions untouched
        const auto& pre_mvn_shape = pattern_map.at(pre_mvn_shape_const_m);
        const auto& pre_mvn_shape_const =
            ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(pre_mvn_shape_const_m).get_node_shared_ptr());
        const auto& pre_mvn_shape_out_ps = pre_mvn_shape.get_shape();
        // Target shape must have exactly 3 elements: (batch, num_groups, -1).
        if (pre_mvn_shape_out_ps[0] != 3)
            return false;

        // Checks the pre-MVN reshape target values:
        //   [0]: 0 (copy batch) or the literal static batch size,
        //   [1]: 0 or the literal number of groups,
        //   [2]: -1 (merge all remaining dims).
        auto pre_mvn_shape_vals_correct = [](const std::vector<int64_t>& pre_mvn_shape_vals,
                                             const ov::PartialShape& input_ps,
                                             const ov::Dimension::value_type num_groups) -> bool {
            bool res = true;
            if (input_ps[0].is_dynamic()) {
                // dynamic batch can only be expressed via the "copy input dim" marker 0
                if (pre_mvn_shape_vals[0] != 0ll)
                    res = false;
            } else {
                if ((pre_mvn_shape_vals[0] != 0ll) &&
                    (pre_mvn_shape_vals[0] != static_cast<long long>(input_ps[0].get_max_length())))
                    res = false;
            }
            if ((pre_mvn_shape_vals[1] != 0ll) && (pre_mvn_shape_vals[1] != static_cast<long long>(num_groups)))
                res = false;
            if (pre_mvn_shape_vals[2] != -1ll)
                res = false;
            return res;
        };

        if (!pre_mvn_shape_vals_correct(pre_mvn_shape_const->cast_vector<int64_t>(), input_ps, num_groups))
            return false;

        // number of channels has to be divisible by number of groups
        // NOTE(review): assumes num_groups > 0; a zero-sized static group dim
        // would divide by zero here — confirm such graphs are rejected upstream.
        if (num_channels % num_groups != 0)
            return false;

        // first dimension of MVN input (batch_size) has to be the same
        // as in pattern input
        if (input_ps[0].get_max_length() != pre_mvn_reshape_out_ps[0].get_max_length())
            return false;

        // we expect to execute normalization over last dimension of MVN input
        const auto& mvn_reduction_axes = pattern_map.at(mvn_reduction_axes_const_m);
        const auto& mvn_reduction_axes_const =
            ov::as_type_ptr<ov::op::v0::Constant>(mvn_reduction_axes.get_node_shared_ptr());
        const auto& mvn_reduction_axes_out_shape = mvn_reduction_axes.get_shape();
        // exactly one reduction axis is expected
        if (mvn_reduction_axes_out_shape[0] != 1)
            return false;

        // The single axis must address the last (merged-spatial) dim of the
        // rank-3 MVN input: either 2 or its negative form -1.
        auto mvn_reduction_axes_correct = [](const std::vector<int64_t>& mvn_reduction_axes) -> bool {
            bool res = true;
            if ((mvn_reduction_axes[0] != 2ll) && (mvn_reduction_axes[0] != -1ll))
                return false;
            return res;
        };

        if (!mvn_reduction_axes_correct(mvn_reduction_axes_const->cast_vector<int64_t>()))
            return false;

        const auto& post_instance_norm_reshape_out_ps =
            pattern_map.at(post_instance_norm_reshape_m).get_partial_shape();
        // post instance norm shape has to be same as in pattern input
        if (post_instance_norm_reshape_out_ps != input_ps)
            return false;

        // group_norm_gamma must match the input type and hold one value per channel
        const auto& group_norm_gamma = pattern_map.at(group_norm_gamma_m);
        if (group_norm_gamma.get_element_type() != T)
            return false;
        if (ov::shape_size(group_norm_gamma.get_shape()) != num_channels)
            return false;

        // group_norm_beta must match the input type and hold one value per channel
        const auto& group_norm_beta = pattern_map.at(group_norm_beta_m);
        if (group_norm_beta.get_element_type() != T)
            return false;
        if (ov::shape_size(group_norm_beta.get_shape()) != num_channels)
            return false;

        // collects all newly created nodes so runtime info can be copied onto them
        ov::NodeVector nodes;

        // Squeeze gamma to the canonical 1D (num_channels) parameter shape.
        std::shared_ptr<ov::Node> group_norm_gamma_1d_m = std::make_shared<ov::op::v0::Squeeze>(group_norm_gamma);
        nodes.push_back(group_norm_gamma_1d_m);
        const auto& group_norm_gamma_1d_out_ps = group_norm_gamma_1d_m->get_output_partial_shape(0);

        auto expected_param_shape = ov::PartialShape({static_cast<ov::Dimension>(num_channels)});
        if (group_norm_gamma_1d_out_ps != expected_param_shape)
            return false;

        // Same 1D canonicalization for beta.
        std::shared_ptr<ov::Node> group_norm_beta_1d_m = std::make_shared<ov::op::v0::Squeeze>(group_norm_beta);
        nodes.push_back(group_norm_beta_1d_m);
        const auto& group_norm_beta_1d_out_ps = group_norm_beta_1d_m->get_output_partial_shape(0);

        if (group_norm_beta_1d_out_ps != expected_param_shape)
            return false;

        // Gather indices that broadcast per-group values to per-channel layout:
        // each group index repeated (num_channels / num_groups) times.
        auto gather_axis_const_m = op::v0::Constant::create(element::i64, Shape{1}, {0});
        nodes.push_back(gather_axis_const_m);
        auto gather_indices_vals = std::vector<int64_t>();
        for (auto i = 0ull; i < num_groups; i++)
            gather_indices_vals.insert(gather_indices_vals.end(), num_channels / num_groups, i);
        auto gather_indices_const_m = op::v0::Constant::create(element::i64, Shape{num_channels}, gather_indices_vals);
        nodes.push_back(gather_indices_const_m);

        // Fold the optional instance_norm_beta into group-norm parameters.
        if (pattern_map.count(instance_norm_beta_m) > 0) {
            const auto& instance_norm_beta = pattern_map.at(instance_norm_beta_m);
            if (instance_norm_beta.get_element_type() != T)
                return false;
            // instance-norm parameters carry one value per group
            if (ov::shape_size(instance_norm_beta.get_shape()) != num_groups)
                return false;

            // ensure that instance_norm_beta will have shape compatible
            // with group_norm parameters, i.e. 1D vector of shape (num_channels)
            std::shared_ptr<ov::Node> instance_norm_beta_1d_m = nullptr;
            if (ov::shape_size(instance_norm_beta.get_shape()) == 1) {
                // scalar-like: Squeeze would produce rank 0, so reshape to shape {1} instead
                auto shape_1d_const_m = op::v0::Constant::create(element::i64, Shape{1}, {1});
                nodes.push_back(shape_1d_const_m);
                instance_norm_beta_1d_m =
                    std::make_shared<ov::op::v1::Reshape>(instance_norm_beta, shape_1d_const_m, true);
                nodes.push_back(instance_norm_beta_1d_m);
            } else {
                instance_norm_beta_1d_m = std::make_shared<ov::op::v0::Squeeze>(instance_norm_beta);
                nodes.push_back(instance_norm_beta_1d_m);
            }

            // expand per-group values to per-channel via the precomputed Gather indices
            instance_norm_beta_1d_m = std::make_shared<ov::op::v8::Gather>(instance_norm_beta_1d_m,
                                                                           gather_indices_const_m,
                                                                           gather_axis_const_m);
            nodes.push_back(instance_norm_beta_1d_m);

            const auto& instance_norm_beta_1d_ps = instance_norm_beta_1d_m->get_output_partial_shape(0);
            if (instance_norm_beta_1d_ps != expected_param_shape)
                return false;

            // group_norm_beta = group_norm_gamma * instance_norm_beta + group_norm_beta
            auto group_norm_beta_corr_multiply_m =
                std::make_shared<ov::op::v1::Multiply>(group_norm_gamma_1d_m, instance_norm_beta_1d_m);
            nodes.push_back(group_norm_beta_corr_multiply_m);
            group_norm_beta_1d_m =
                std::make_shared<ov::op::v1::Add>(group_norm_beta_corr_multiply_m, group_norm_beta_1d_m);
            nodes.push_back(group_norm_beta_1d_m);
        }

        // Fold the optional instance_norm_gamma into group_norm_gamma
        // (mirrors the beta handling above).
        if (pattern_map.count(instance_norm_gamma_m) > 0) {
            const auto& instance_norm_gamma = pattern_map.at(instance_norm_gamma_m);
            if (instance_norm_gamma.get_element_type() != T)
                return false;
            if (ov::shape_size(instance_norm_gamma.get_shape()) != num_groups)
                return false;

            // ensure that instance_norm_gamma will have shape compatible
            // with group_norm parameters, i.e. 1D vector of shape (num_channels)
            std::shared_ptr<ov::Node> instance_norm_gamma_1d_m = nullptr;
            if (ov::shape_size(instance_norm_gamma.get_shape()) == 1) {
                auto shape_1d_const_m = op::v0::Constant::create(element::i64, Shape{1}, {1});
                nodes.push_back(shape_1d_const_m);
                instance_norm_gamma_1d_m =
                    std::make_shared<ov::op::v1::Reshape>(instance_norm_gamma, shape_1d_const_m, true);
                nodes.push_back(instance_norm_gamma_1d_m);
            } else {
                instance_norm_gamma_1d_m = std::make_shared<ov::op::v0::Squeeze>(instance_norm_gamma);
                nodes.push_back(instance_norm_gamma_1d_m);
            }

            instance_norm_gamma_1d_m = std::make_shared<ov::op::v8::Gather>(instance_norm_gamma_1d_m,
                                                                            gather_indices_const_m,
                                                                            gather_axis_const_m);
            nodes.push_back(instance_norm_gamma_1d_m);

            const auto& instance_norm_gamma_1d_ps = instance_norm_gamma_1d_m->get_output_partial_shape(0);
            if (instance_norm_gamma_1d_ps != expected_param_shape)
                return false;

            // group_norm_gamma *= instance_norm_gamma
            group_norm_gamma_1d_m =
                std::make_shared<ov::op::v1::Multiply>(group_norm_gamma_1d_m, instance_norm_gamma_1d_m);
            nodes.push_back(group_norm_gamma_1d_m);
        }

        // we need to cast mvn to MVN layer type in order to read actual epsilon value
        const auto& mvn_out = pattern_map.at(mvn_m);
        const auto& mvn = ov::as_type_ptr<ov::op::v6::MVN>(mvn_out.get_node_shared_ptr());
        const auto& epsilon = mvn->get_eps();

        // reuse original friendly names for gamma and beta inputs
        // NOTE(review): these copy the friendly names of the *pattern* nodes
        // (group_norm_gamma_m / group_norm_beta_m), not of the matched graph
        // nodes — confirm this is intended rather than
        // pattern_map.at(...).get_node_shared_ptr()->get_friendly_name().
        group_norm_gamma_1d_m->set_friendly_name(group_norm_gamma_m->get_friendly_name());
        group_norm_beta_1d_m->set_friendly_name(group_norm_beta_m->get_friendly_name());

        // we can finally create GroupNormalization op
        std::shared_ptr<ov::Node> group_norm = std::make_shared<ov::op::v12::GroupNormalization>(input,
                                                                                                 group_norm_gamma_1d_m,
                                                                                                 group_norm_beta_1d_m,
                                                                                                 num_groups,
                                                                                                 epsilon);
        nodes.push_back(group_norm);

        // and do actual graph substitution
        group_norm->set_friendly_name(m.get_match_root()->get_friendly_name());
        ov::copy_runtime_info(m.get_matched_nodes(), nodes);
        ov::replace_node(m.get_match_root(), group_norm);
        return true;
    };

    // Anchor the matcher at the final Add (group_norm_beta) of the pattern.
    auto m = std::make_shared<Matcher>(group_norm_beta_add_m, matcher_name);
    this->register_matcher(m, callback);
}

0 commit comments

Comments
 (0)