Skip to content

Commit c62be51

Browse files
authored
[Transformations] Enable dynamic decomposition BTS and STB ops (openvinotoolkit#15179)
* Add dynamism for BatchToSpace conversion * Extend dynamism for BatchToSpace conversion Only block input needs to be const 'cause axes_order_const is freaky * Enhance dynamism for BatchToSpace conversion Block input need not be const now. * Add dynamism for STB by elements conversion * Remove const need for crops for BTS by_elements * temp for review * Try to fix output tensor overwrite * Make test to reproduce invalid shape inference * Reproduce the error with template plugin * Fix code style * Fix 0D inputs issue * Remove 0D shape parts before Concat * Apply nested namespaces * Enable non-constant STB Block input * Fix BTS runtime info * Fix STB by elems runtime info * Add dynamism for STB conversion * Add BTS dynamic data test * Add STB dynamic data test * Reduce STB concats * Add tests naming * Edit * style * Consider other block element types * Enhance type test * Use opset10 only * Check block shape
1 parent fac03ee commit c62be51

File tree

3 files changed

+411
-262
lines changed

3 files changed

+411
-262
lines changed

src/common/transformations/src/transformations/op_conversions/convert_batch_to_space.cpp

+137-145
Original file line numberDiff line numberDiff line change
@@ -4,224 +4,216 @@
44

55
#include "transformations/op_conversions/convert_batch_to_space.hpp"
66

7+
#include <algorithm>
8+
#include <climits>
79
#include <memory>
810
#include <ngraph/pattern/op/wrap_type.hpp>
911
#include <ngraph/rt_info.hpp>
10-
#include <openvino/opsets/opset3.hpp>
12+
#include <openvino/opsets/opset10.hpp>
1113
#include <vector>
1214

1315
#include "itt.hpp"
1416

17+
using namespace std;
18+
using namespace ov::opset10;
19+
using namespace ov::element;
20+
1521
void ov::pass::ConvertBatchToSpace::convert_batch_to_space() {
1622
MATCHER_SCOPE(ConvertBatchToSpace_convert_batch_to_space);
17-
auto batch_to_space = ngraph::pattern::wrap_type<ov::opset3::BatchToSpace>();
18-
matcher_pass_callback callback = [](pattern::Matcher& m) {
19-
auto batch_to_space = std::dynamic_pointer_cast<ov::opset3::BatchToSpace>(m.get_match_root());
20-
if (!batch_to_space) {
23+
const auto batch_to_space = pattern::wrap_type<BatchToSpace>();
24+
matcher_pass_callback callback = [this](pattern::Matcher& m) {
25+
const auto batch_to_space = dynamic_pointer_cast<BatchToSpace>(m.get_match_root());
26+
if (!batch_to_space || transformation_callback(batch_to_space)) {
2127
return false;
2228
}
2329

24-
NodeVector new_ops;
25-
auto data = batch_to_space->input_value(0);
26-
auto block = batch_to_space->input_value(1);
27-
auto crops_begin = batch_to_space->input_value(2);
28-
auto crops_end = batch_to_space->input_value(3);
30+
NodeRegistry rg;
31+
const auto data = batch_to_space->input_value(0);
32+
const auto block = batch_to_space->input_value(1);
33+
const auto crops_begin = batch_to_space->input_value(2);
34+
const auto crops_end = batch_to_space->input_value(3);
2935

30-
if (data.get_partial_shape().is_dynamic()) {
31-
return false;
36+
const auto data_shape_rank = data.get_partial_shape().rank();
37+
if (data_shape_rank.is_dynamic()) {
38+
return false; // because StridedSlice masks are std::vector
3239
}
33-
const auto& data_shape = data.get_shape();
3440

35-
const auto block_const = std::dynamic_pointer_cast<opset3::Constant>(block.get_node_shared_ptr());
36-
const auto crops_begin_const = std::dynamic_pointer_cast<opset3::Constant>(crops_begin.get_node_shared_ptr());
37-
const auto crops_end_const = std::dynamic_pointer_cast<opset3::Constant>(crops_end.get_node_shared_ptr());
38-
39-
if (!block_const || !crops_begin_const || !crops_end_const) {
41+
if (block.get_partial_shape().is_dynamic() || block.get_shape().size() == 0) {
4042
return false;
4143
}
42-
43-
const std::vector<int64_t>& block_values = block_const->cast_vector<int64_t>();
44-
const std::vector<int64_t>& crops_end_values = crops_end_const->cast_vector<int64_t>();
44+
const auto block_length = static_cast<int64_t>(block.get_shape()[0]);
4545

4646
// First we have to disperse the data from batch, then rearrange them
4747
// so as appropriate chunks of data where close to their destination place.
48-
// Finally squeeze data from respective dimensions.ss
49-
std::vector<int64_t> dispersed_shape;
50-
int64_t b_dim_divider = 1;
51-
for (const auto& el : block_values) {
52-
b_dim_divider *= el;
53-
}
48+
// Finally squeeze data from respective dimensions
49+
50+
const auto zero = rg.make<Constant>(i64, Shape{1}, 0);
51+
const auto shape_of_data = rg.make<ShapeOf>(data, block.get_element_type());
52+
const auto batch = rg.make<Gather>(shape_of_data, zero, zero);
53+
const auto block_prod = rg.make<ReduceProd>(block, zero);
54+
const auto batch_div = rg.make<Divide>(batch, block_prod);
5455

5556
// note: B_0 is expected to be 1.
5657
// x' = reshape(`data`, [B_1, ..., B_{N - 1}, batch / (B_1 * ... B_{N - 1}), D_1, D_2, ...,
5758
// D_{N - 1}]),
5859
// where B_i = block_shape[i]
59-
dispersed_shape.insert(dispersed_shape.begin(), block_values.begin() + 1, block_values.end());
60-
dispersed_shape.push_back(data_shape.at(0) / b_dim_divider);
61-
for (size_t i = 1; i < data_shape.size(); ++i) {
62-
dispersed_shape.push_back(data_shape.at(i));
63-
}
64-
65-
const auto out_pattern_1 =
66-
opset3::Constant::create(element::i64, Shape{dispersed_shape.size()}, dispersed_shape);
60+
const auto one = rg.make<Constant>(i64, Shape{1}, 1);
61+
const auto end = rg.make<Constant>(i64, Shape{1}, block_length);
62+
const auto block_tail = rg.make<Slice>(block, one, end, one);
63+
const auto data_shape_tail = rg.make<Slice>(shape_of_data, one, end, one);
64+
const auto dispersed_shape = rg.make<Concat>(OutputVector{block_tail, batch_div, data_shape_tail}, 0);
6765
const bool special_zero = false;
68-
std::shared_ptr<Node> flat_node = std::make_shared<ov::opset3::Reshape>(data, out_pattern_1, special_zero);
69-
new_ops.push_back(flat_node);
66+
shared_ptr<Node> flat_node = rg.make<Reshape>(data, dispersed_shape, special_zero);
67+
7068
// calculate axes to transpose
7169
// x'' = transpose(x', [N, N + 1, 0, N + 2, 1, ..., N + N - 1, N - 1])
72-
std::vector<size_t> axes_order{block_values.size() - 1};
73-
for (size_t i = 0; i < block_values.size() - 1; ++i) {
74-
axes_order.push_back(i + block_values.size());
70+
vector<int64_t> axes_order{block_length - 1};
71+
for (int64_t i = 0; i < block_length - 1; ++i) {
72+
axes_order.push_back(i + block_length);
7573
axes_order.push_back(i);
7674
}
75+
const auto axes_order_const = rg.make<Constant>(i64, Shape{axes_order.size()}, axes_order);
76+
flat_node = rg.make<Transpose>(flat_node, axes_order_const);
7777

78-
const auto axes_order_const =
79-
opset3::Constant::create(element::i64,
80-
Shape{axes_order.size()},
81-
std::vector<int64_t>(axes_order.begin(), axes_order.end()));
82-
flat_node = std::make_shared<ov::opset3::Transpose>(flat_node, axes_order_const);
83-
new_ops.push_back(flat_node);
8478
// x''' = reshape(x'', [batch / (B_1 * ... * B_{N - 1}), D_1 * B_1, D_2 * B_2, ... , D_{N - 1}
8579
// * B_{N - 1}])
86-
std::vector<int64_t> squeezed_shape;
87-
squeezed_shape.push_back(data_shape.at(0) / b_dim_divider);
88-
for (size_t i = 1; i < block_values.size(); ++i) {
89-
squeezed_shape.push_back(data_shape.at(i) * block_values.at(i));
90-
}
91-
92-
const auto out_pattern_2 = opset3::Constant::create(element::i64, Shape{squeezed_shape.size()}, squeezed_shape);
93-
flat_node = std::make_shared<opset3::Reshape>(flat_node, out_pattern_2, special_zero);
94-
new_ops.push_back(flat_node);
80+
const auto squeezed_shape_tail = rg.make<Multiply>(block_tail, data_shape_tail);
81+
const auto squeezed_shape = rg.make<Concat>(OutputVector{batch_div, squeezed_shape_tail}, 0);
82+
flat_node = rg.make<Reshape>(flat_node, squeezed_shape, special_zero);
9583

9684
// Crop the start and end of dimensions according to `crops_begin`, `crops_end` to produce
9785
// the output of shape:
9886
// note: `crops_begin[0], crops_end[0]` are expected to be 0.
9987
// `y = [batch / (B_1 * ... * B_{N - 1}), crop(D_1 * B_1, crops_begin[1], crops_end[1]),
10088
// crop(D_2 * B_2, crops_begin[2], crops_end[2]), ... ,
10189
// crop(D_{N - 1} * B_{N - 1}, crops_begin[N - 1], crops_end[N - 1])]`
102-
std::vector<int64_t> upperbounds_values;
103-
auto flat_node_shape = flat_node->get_shape();
104-
for (size_t i = 0; i < flat_node_shape.size(); ++i) {
105-
upperbounds_values.push_back(flat_node_shape.at(i) - crops_end_values.at(i));
106-
}
107-
108-
const auto upperbounds = opset3::Constant::create(crops_end.get_element_type(),
109-
Shape{upperbounds_values.size()},
110-
upperbounds_values);
90+
const auto shape_of_flat_node = rg.make<ShapeOf>(flat_node, crops_end.get_element_type());
91+
const auto upperbounds = rg.make<Subtract>(shape_of_flat_node, crops_end);
11192

112-
std::vector<int64_t> begin_mask(data_shape.size(), 0);
113-
std::vector<int64_t> end_mask(data_shape.size(), 0);
114-
flat_node =
115-
std::make_shared<opset3::StridedSlice>(flat_node, crops_begin_const, upperbounds, begin_mask, end_mask);
116-
new_ops.push_back(flat_node);
93+
const auto begin_mask = vector<int64_t>(data_shape_rank.get_length(), 0);
94+
const auto& end_mask = begin_mask;
95+
flat_node = rg.make<StridedSlice>(flat_node, crops_begin, upperbounds, begin_mask, end_mask);
11796

11897
flat_node->set_friendly_name(batch_to_space->get_friendly_name());
119-
ngraph::copy_runtime_info(batch_to_space, new_ops);
120-
ngraph::replace_node(batch_to_space, flat_node);
98+
copy_runtime_info(batch_to_space, rg.get());
99+
replace_node(batch_to_space, flat_node);
121100
return true;
122101
};
123102

124-
auto m = std::make_shared<ngraph::pattern::Matcher>(batch_to_space, matcher_name);
103+
const auto m = make_shared<pattern::Matcher>(batch_to_space, matcher_name);
125104
this->register_matcher(m, callback);
126105
}
127106

128107
void ov::pass::ConvertBatchToSpace::convert_batch_to_space_by_elements() {
129108
MATCHER_SCOPE(ConvertBatchToSpace_convert_batch_to_space_by_elements);
130-
auto batch_to_space = ngraph::pattern::wrap_type<ov::opset3::BatchToSpace>();
109+
const auto batch_to_space = pattern::wrap_type<BatchToSpace>();
131110
matcher_pass_callback callback = [this](pattern::Matcher& m) {
132-
auto batch_to_space = std::dynamic_pointer_cast<ov::opset3::BatchToSpace>(m.get_match_root());
133-
if (!batch_to_space) {
111+
const auto batch_to_space = dynamic_pointer_cast<BatchToSpace>(m.get_match_root());
112+
if (!batch_to_space || transformation_callback(batch_to_space)) {
134113
return false;
135114
}
136115

137-
auto data = batch_to_space->input_value(0);
116+
const auto data = batch_to_space->input_value(0);
138117

139-
if (data.get_partial_shape().is_dynamic()) {
140-
return false;
118+
const auto data_shape_rank = data.get_partial_shape().rank();
119+
if (data_shape_rank.is_dynamic()) {
120+
return false; // because StridedSlice masks are std::vector
141121
}
142-
auto data_shape = data.get_shape();
143122

144-
if (transformation_callback(batch_to_space) && (data_shape.size() == 4 || data_shape.size() == 5)) {
145-
return false;
146-
}
147-
auto block = batch_to_space->input_value(1);
148-
auto crops_begin = batch_to_space->input_value(2);
149-
auto crops_end = batch_to_space->input_value(3);
150-
151-
const auto block_const = ov::as_type_ptr<opset3::Constant>(block.get_node_shared_ptr());
152-
const auto crops_begin_const = ov::as_type_ptr<opset3::Constant>(crops_begin.get_node_shared_ptr());
153-
const auto crops_end_const = ov::as_type_ptr<opset3::Constant>(crops_end.get_node_shared_ptr());
154-
155-
const std::vector<int64_t>& block_values = block_const->cast_vector<int64_t>();
156-
const std::vector<int64_t>& crops_end_values = crops_end_const->cast_vector<int64_t>();
157-
158-
std::vector<int64_t> dispersed_shape(1);
159-
dispersed_shape.insert(dispersed_shape.end(), data_shape.begin(), data_shape.end());
160-
std::vector<size_t> axes_order(block_values.size() + 1);
161-
std::vector<int64_t> squeezed_shape(data_shape.begin(), data_shape.end());
162-
if (squeezed_shape.size() > block_values.size()) {
123+
const auto block = batch_to_space->input_value(1);
124+
const auto crops_begin = batch_to_space->input_value(2);
125+
const auto crops_end = batch_to_space->input_value(3);
126+
127+
if (block.get_partial_shape().is_dynamic() || block.get_shape().size() == 0) {
163128
return false;
164129
}
165-
166-
NodeVector new_ops;
167-
168-
std::shared_ptr<Node> flat_node = data.get_node_shared_ptr();
169-
for (size_t block_idx = 1; block_idx < block_values.size(); ++block_idx) {
170-
dispersed_shape[0] = block_values[block_idx];
171-
dispersed_shape[1] /= block_values[block_idx];
172-
const auto out_pattern_1 =
173-
opset3::Constant::create(element::i64, Shape{dispersed_shape.size()}, dispersed_shape);
174-
const bool special_zero = false;
175-
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern_1, special_zero);
176-
new_ops.push_back(flat_node);
177-
178-
size_t val = 1;
179-
for (size_t axis_idx = 0; axis_idx <= block_values.size(); ++axis_idx) {
180-
if ((block_idx + 1) == axis_idx) {
130+
const auto block_length = static_cast<int64_t>(block.get_shape()[0]);
131+
132+
NodeRegistry rg;
133+
const auto zero = rg.make<Constant>(i64, Shape{1}, 0);
134+
const auto one = rg.make<Constant>(i64, Shape{1}, 1);
135+
const auto two = rg.make<Constant>(i64, Shape{1}, 2);
136+
const auto int_max = rg.make<Constant>(i64, Shape{1}, INT_MAX);
137+
138+
const auto shape_of_data = rg.make<ShapeOf>(data, block.get_element_type());
139+
const auto et_zero = rg.make<Constant>(block.get_element_type(), Shape{1}, 0);
140+
shared_ptr<Node> dispersed_shape = rg.make<Concat>(OutputVector{et_zero, shape_of_data}, 0);
141+
shared_ptr<Node> squeezed_shape = shape_of_data;
142+
143+
shared_ptr<Node> flat_node = data.get_node_shared_ptr();
144+
145+
const auto make_concat = [&](OutputVector nodes) {
146+
nodes.erase(remove_if(nodes.begin(),
147+
nodes.end(),
148+
[](const Output<Node>& n) {
149+
return n.get_partial_shape().is_static() && n.get_shape().size() > 0 &&
150+
n.get_shape()[0] == 0;
151+
}),
152+
nodes.end());
153+
return rg.make<Concat>(nodes, 0);
154+
};
155+
156+
shared_ptr<Node> div;
157+
for (int64_t b_idx = 1; b_idx < block_length; ++b_idx) {
158+
const auto block_index = rg.make<Constant>(i64, Shape{1}, b_idx);
159+
const auto block_index_next = rg.make<Constant>(i64, Shape{1}, b_idx + 1);
160+
const auto block_value = rg.make<Gather>(block, block_index, zero);
161+
162+
// dispersed_shape[0] = block[b_idx];
163+
// dispersed_shape[1] /= block[b_idx];
164+
if (!div) {
165+
const auto batch = rg.make<Gather>(shape_of_data, zero, zero);
166+
div = rg.make<Divide>(batch, block_value);
167+
} else {
168+
div = rg.make<Divide>(div, block_value);
169+
}
170+
auto ds_tail = rg.make<Slice>(dispersed_shape, two, int_max, one);
171+
dispersed_shape = make_concat({block_value, div, ds_tail});
172+
constexpr auto special_zero = false;
173+
flat_node = rg.make<Reshape>(flat_node, dispersed_shape, special_zero);
174+
175+
vector<int64_t> axes_order(block_length + 1);
176+
int64_t val = 1;
177+
for (int64_t axis_idx = 0; axis_idx <= block_length; ++axis_idx) {
178+
if ((b_idx + 1) == axis_idx) {
181179
axes_order[axis_idx] = 0;
182180
} else {
183181
axes_order[axis_idx] = val;
184182
val++;
185183
}
186184
}
187-
188-
const auto axes_order_const =
189-
ov::opset3::Constant::create(element::i64,
190-
Shape{axes_order.size()},
191-
std::vector<int64_t>(axes_order.begin(), axes_order.end()));
192-
flat_node = std::make_shared<ov::opset3::Transpose>(flat_node, axes_order_const);
193-
new_ops.push_back(flat_node);
194-
195-
squeezed_shape[0] = dispersed_shape[1];
196-
squeezed_shape[block_idx] *= block_values[block_idx];
197-
dispersed_shape[block_idx + 1] = squeezed_shape[block_idx];
198-
const auto out_pattern_2 =
199-
opset3::Constant::create(element::i64, Shape{squeezed_shape.size()}, squeezed_shape);
200-
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern_2, special_zero);
201-
new_ops.push_back(flat_node);
185+
const auto axes_order_const = rg.make<Constant>(i64, Shape{axes_order.size()}, axes_order);
186+
flat_node = rg.make<Transpose>(flat_node, axes_order_const);
187+
188+
// squeezed_shape[0] = dispersed_shape[1];
189+
// squeezed_shape[b_idx] *= block[b_idx];
190+
const auto sq_slice = rg.make<Slice>(squeezed_shape, one, block_index, one);
191+
const auto sq_bidx_dim = rg.make<Gather>(squeezed_shape, block_index, zero);
192+
const auto sq_mul = rg.make<Multiply>(sq_bidx_dim, block_value);
193+
const auto sq_shape_tail = rg.make<Slice>(squeezed_shape, block_index_next, int_max, one);
194+
squeezed_shape.reset();
195+
squeezed_shape = make_concat({div, sq_slice, sq_mul, sq_shape_tail});
196+
flat_node = rg.make<Reshape>(flat_node, squeezed_shape, special_zero);
197+
198+
// dispersed_shape[b_idx + 1] = squeezed_shape[b_idx];
199+
const auto ds_front = rg.make<Slice>(dispersed_shape, zero, block_index_next, one);
200+
ds_tail = rg.make<Slice>(dispersed_shape, rg.make<Constant>(i64, Shape{1}, b_idx + 2), int_max, one);
201+
dispersed_shape = make_concat({ds_front, sq_mul, ds_tail});
202202
}
203203

204-
std::vector<int64_t> upperbounds_values;
205-
auto flat_node_shape = flat_node->get_shape();
206-
for (size_t i = 0; i < flat_node_shape.size(); ++i) {
207-
upperbounds_values.push_back(flat_node_shape.at(i) - crops_end_values.at(i));
208-
}
209-
const auto upperbounds = opset3::Constant::create(crops_end.get_element_type(),
210-
Shape{upperbounds_values.size()},
211-
upperbounds_values);
204+
const auto shape_of_flat_node = rg.make<ShapeOf>(flat_node, crops_end.get_element_type());
205+
const auto upperbounds = rg.make<Subtract>(shape_of_flat_node, crops_end);
212206

213-
std::vector<int64_t> begin_mask(data_shape.size(), 0);
214-
std::vector<int64_t> end_mask(data_shape.size(), 0);
215-
flat_node =
216-
std::make_shared<opset3::StridedSlice>(flat_node, crops_begin_const, upperbounds, begin_mask, end_mask);
217-
new_ops.push_back(flat_node);
207+
const auto begin_mask = vector<int64_t>(data_shape_rank.get_length(), 0);
208+
const auto& end_mask = begin_mask;
209+
flat_node = rg.make<StridedSlice>(flat_node, crops_begin, upperbounds, begin_mask, end_mask);
218210

219211
flat_node->set_friendly_name(batch_to_space->get_friendly_name());
220-
ngraph::copy_runtime_info(batch_to_space, new_ops);
221-
ngraph::replace_node(batch_to_space, flat_node);
212+
copy_runtime_info(batch_to_space, rg.get());
213+
replace_node(batch_to_space, flat_node);
222214
return true;
223215
};
224216

225-
auto m = std::make_shared<ngraph::pattern::Matcher>(batch_to_space, matcher_name);
217+
const auto m = make_shared<pattern::Matcher>(batch_to_space, matcher_name);
226218
this->register_matcher(m, callback);
227219
}

0 commit comments

Comments
 (0)