Skip to content

Commit 6675832

Browse files
committed
[GPU] Extend OptimizeSubsequentReshapes pass to support any cases with a single dynamic dimension
1 parent f671bfc commit 6675832

File tree

2 files changed

+65
-17
lines changed

2 files changed

+65
-17
lines changed

src/plugins/intel_gpu/src/plugin/transformations/optimize_subsequent_reshapes.cpp

+11-11
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ OptimizeSubsequentReshapes::OptimizeSubsequentReshapes() {
1818
using namespace ov::pass::pattern;
1919
using ov::pass::pattern::op::Or;
2020

21-
auto dynamic_batch_only = [](Output<Node> output) {
21+
auto single_dynamic_dim = [](Output<Node> output) {
2222
const auto& shape = output.get_partial_shape();
2323

2424
if (shape.rank().is_dynamic())
@@ -27,23 +27,23 @@ OptimizeSubsequentReshapes::OptimizeSubsequentReshapes() {
2727
if (shape.size() <= 1)
2828
return false;
2929

30-
if (shape[0].is_static())
31-
return false;
30+
auto dynamic_dims = 0;
31+
for (size_t i = 0; i < shape.size(); i++)
32+
dynamic_dims += shape[i].is_dynamic() ? 1 : 0;
3233

33-
for (size_t i = 1; i < shape.size(); i++)
34-
if (shape[i].is_dynamic())
35-
return false;
34+
if (dynamic_dims != 1)
35+
return false;
3636

3737
return true;
3838
};
3939

40-
auto first_reshape_data = any_input(dynamic_batch_only);
40+
auto first_reshape_data = any_input(single_dynamic_dim);
4141
auto first_reshape_pattern = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
4242
auto first_reshape = wrap_type<ov::op::v1::Reshape>({ first_reshape_data, first_reshape_pattern },
43-
dynamic_batch_only && ov::pass::pattern::consumers_count(1));
43+
single_dynamic_dim && ov::pass::pattern::consumers_count(1));
4444

4545
auto second_reshape_pattern = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
46-
auto second_reshape = wrap_type<ov::op::v1::Reshape>({ first_reshape, second_reshape_pattern }, dynamic_batch_only);
46+
auto second_reshape = wrap_type<ov::op::v1::Reshape>({ first_reshape, second_reshape_pattern }, single_dynamic_dim);
4747

4848
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
4949
const auto& pattern_map = m.get_pattern_value_map();
@@ -74,14 +74,14 @@ OptimizeSubsequentReshapes::OptimizeSubsequentReshapes() {
7474
std::vector<int32_t> new_pattern;
7575
for (auto& dim : second_reshape_ps) {
7676
if (dim.is_dynamic()) {
77-
new_pattern.push_back(0);
77+
new_pattern.push_back(-1);
7878
} else {
7979
new_pattern.push_back(dim.get_length());
8080
}
8181
}
8282

8383
auto new_pattern_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{new_pattern.size()}, new_pattern);
84-
auto new_reshape = std::make_shared<ov::op::v1::Reshape>(first_reshape_node->input(0).get_source_output(), new_pattern_const, true);
84+
auto new_reshape = std::make_shared<ov::op::v1::Reshape>(first_reshape_node->input(0).get_source_output(), new_pattern_const, false);
8585
new_reshape->set_friendly_name(second_reshape_node->get_friendly_name());
8686

8787
ov::replace_node(second_reshape_node, new_reshape);

src/plugins/intel_gpu/tests/unit/transformations/optimize_subsequent_reshapes_test.cpp

+54-6
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ TEST_F(TransformationTestsF, OptimizeSubsequentReshapes1) {
3939
}
4040
{
4141
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ -1, 1, 4096 });
42-
auto reshape_pattern = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{2}, std::vector<int32_t>{ 0, 4096 });
43-
auto reshape = std::make_shared<ov::op::v1::Reshape>(input, reshape_pattern, true);
42+
auto reshape_pattern = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{2}, std::vector<int32_t>{ -1, 4096 });
43+
auto reshape = std::make_shared<ov::op::v1::Reshape>(input, reshape_pattern, false);
4444
auto result = std::make_shared<ov::op::v0::Result>(reshape);
4545

4646
model_ref = std::make_shared<ov::Model>(ov::NodeVector{ result }, ov::ParameterVector{ input });
@@ -63,8 +63,8 @@ TEST_F(TransformationTestsF, OptimizeSubsequentReshapes2) {
6363
}
6464
{
6565
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ -1, 1, 4096 });
66-
auto reshape_pattern = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{4}, std::vector<int32_t>{ 0, 32, 1, 128 });
67-
auto reshape = std::make_shared<ov::op::v1::Reshape>(input, reshape_pattern, true);
66+
auto reshape_pattern = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{4}, std::vector<int32_t>{ -1, 32, 1, 128 });
67+
auto reshape = std::make_shared<ov::op::v1::Reshape>(input, reshape_pattern, false);
6868
auto result = std::make_shared<ov::op::v0::Result>(reshape);
6969

7070
model_ref = std::make_shared<ov::Model>(ov::NodeVector{ result }, ov::ParameterVector{ input });
@@ -87,8 +87,56 @@ TEST_F(TransformationTestsF, OptimizeSubsequentReshapes3) {
8787
}
8888
{
8989
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ -1, 32, 1, 128 });
90-
auto reshape_pattern = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{2}, std::vector<int32_t>{ 0, 4096 });
91-
auto reshape = std::make_shared<ov::op::v1::Reshape>(input, reshape_pattern, true);
90+
auto reshape_pattern = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{2}, std::vector<int32_t>{ -1, 4096 });
91+
auto reshape = std::make_shared<ov::op::v1::Reshape>(input, reshape_pattern, false);
92+
auto result = std::make_shared<ov::op::v0::Result>(reshape);
93+
94+
model_ref = std::make_shared<ov::Model>(ov::NodeVector{ result }, ov::ParameterVector{ input });
95+
}
96+
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
97+
}
98+
99+
TEST_F(TransformationTestsF, OptimizeSubsequentReshapes4) {
100+
{
101+
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ 1, -1, 256 });
102+
auto first_reshape_pattern = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{4}, std::vector<int32_t>{ 0, 0, 2, 128 });
103+
auto first_reshape = std::make_shared<ov::op::v1::Reshape>(input, first_reshape_pattern, true);
104+
105+
auto second_reshape_pattern = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{2}, std::vector<int32_t>{ -1, 256 });
106+
auto second_reshape = std::make_shared<ov::op::v1::Reshape>(first_reshape, second_reshape_pattern, false);
107+
auto result = std::make_shared<ov::op::v0::Result>(second_reshape);
108+
109+
model = std::make_shared<ov::Model>(ov::NodeVector{ result }, ov::ParameterVector{ input });
110+
manager.register_pass<OptimizeSubsequentReshapes>();
111+
}
112+
{
113+
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ 1, -1, 256 });
114+
auto reshape_pattern = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{2}, std::vector<int32_t>{ -1, 256 });
115+
auto reshape = std::make_shared<ov::op::v1::Reshape>(input, reshape_pattern, false);
116+
auto result = std::make_shared<ov::op::v0::Result>(reshape);
117+
118+
model_ref = std::make_shared<ov::Model>(ov::NodeVector{ result }, ov::ParameterVector{ input });
119+
}
120+
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
121+
}
122+
123+
TEST_F(TransformationTestsF, OptimizeSubsequentReshapes5) {
124+
{
125+
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ 1, 256, -1 });
126+
auto first_reshape_pattern = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{4}, std::vector<int32_t>{ 0, 64, 4, -1 });
127+
auto first_reshape = std::make_shared<ov::op::v1::Reshape>(input, first_reshape_pattern, true);
128+
129+
auto second_reshape_pattern = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{4}, std::vector<int32_t>{ -1, 32, 2, 4 });
130+
auto second_reshape = std::make_shared<ov::op::v1::Reshape>(first_reshape, second_reshape_pattern, true);
131+
auto result = std::make_shared<ov::op::v0::Result>(second_reshape);
132+
133+
model = std::make_shared<ov::Model>(ov::NodeVector{ result }, ov::ParameterVector{ input });
134+
manager.register_pass<OptimizeSubsequentReshapes>();
135+
}
136+
{
137+
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ 1, 256, -1 });
138+
auto reshape_pattern = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{4}, std::vector<int32_t>{ -1, 32, 2, 4 });
139+
auto reshape = std::make_shared<ov::op::v1::Reshape>(input, reshape_pattern, false);
92140
auto result = std::make_shared<ov::op::v0::Result>(reshape);
93141

94142
model_ref = std::make_shared<ov::Model>(ov::NodeVector{ result }, ov::ParameterVector{ input });

0 commit comments

Comments
 (0)