Skip to content

Commit 488074e

Browse files
committed
[GPU] Fix strided slice ellipsis mode
1 parent 909535f commit 488074e

File tree

6 files changed

+226
-67
lines changed

6 files changed

+226
-67
lines changed

src/plugins/intel_gpu/src/graph/impls/ocl/strided_slice.cpp

+62-10
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,36 @@
1212

1313
namespace {
1414
template <typename T, typename DT, typename = typename std::enable_if<std::is_convertible<DT, T>::value>::type>
15-
std::vector<T>& pad_vector_to_size(std::vector<T>& data, size_t size, DT value) {
16-
for (size_t i = data.size(); i < size; ++i) {
17-
data.push_back(static_cast<T>(value));
15+
void pad_vector_to_size(std::vector<T>& data, size_t size, DT value, const std::vector<int64_t>& ellipsis_mask) {
16+
bool apply_ellipsis_mask = std::count(ellipsis_mask.begin(), ellipsis_mask.end(), 1) == 1;
17+
if (apply_ellipsis_mask) {
18+
std::vector<T> temp;
19+
size_t ellipsis_pos1 = 0;
20+
for (size_t i = 0; i < ellipsis_mask.size(); i++) {
21+
if (ellipsis_mask[i] == 1) {
22+
ellipsis_pos1 = i;
23+
break;
24+
}
25+
}
26+
27+
size_t dims_after = data.size() - ellipsis_pos1 - 1;
28+
size_t ellipsis_pos2 = size - dims_after - 1;;
29+
30+
for (size_t i = 0; i < ellipsis_pos1; i++)
31+
temp.push_back(data[i]);
32+
33+
for (size_t i = ellipsis_pos1; i < ellipsis_pos2 + 1; i++)
34+
temp.push_back(value);
35+
36+
for (size_t i = 1; i < size - ellipsis_pos2; i++)
37+
temp.push_back(data[i + ellipsis_pos1]);
38+
39+
data = temp;
40+
} else {
41+
for (size_t i = data.size(); i < size; ++i) {
42+
data.push_back(static_cast<T>(value));
43+
}
1844
}
19-
return data;
2045
}
2146

2247
template <typename T, typename MT>
@@ -64,6 +89,8 @@ struct strided_slice_impl : typed_primitive_impl_ocl<strided_slice> {
6489

6590
public:
6691
static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) {
92+
std::cout << "Get params for " << impl_param.desc->id << "\n";
93+
6794
const auto& prim = impl_param.typed_desc<strided_slice>();
6895
auto params = get_default_params<kernel_selector::strided_slice_params>(impl_param, is_shape_agnostic);
6996
const size_t dims_num = params.inputs[0].Dimentions();
@@ -72,9 +99,27 @@ struct strided_slice_impl : typed_primitive_impl_ocl<strided_slice> {
7299
std::vector<int32_t> end(prim->end.begin(), prim->end.end());
73100
std::vector<int32_t> strides(prim->strides.begin(), prim->strides.end());
74101

102+
auto print_vec = [](std::vector<int32_t>& vec) {
103+
std::stringstream ss;
104+
105+
for (size_t i = 0; i < vec.size(); i++) {
106+
ss << vec[i] << ", ";
107+
}
108+
109+
return ss.str();
110+
};
111+
112+
113+
std::cout << "Begin vals: " << print_vec(begin) << "\n";
114+
std::cout << "End vals: " << print_vec(end) << "\n";
115+
std::cout << "Strides vals: " << print_vec(strides) << "\n";
116+
75117
// Getting data from constant inputs. There are 3 args: Begin, End, Stride
76118
if (!begin.empty() && !params.has_dynamic_tensors()) {
77-
pad_vector_to_size(begin, dims_num, 0);
119+
std::cout << "Constant begin\n";
120+
std::cout << "Begin before " << print_vec(begin) << "\n";
121+
pad_vector_to_size(begin, dims_num, 0, prim->ellipsis_mask);
122+
std::cout << "Begin after " << print_vec(begin) << "\n";
78123
params.begin_type = kernel_selector::base_params::ArgType::Constant;
79124
params.striding_params.push_back(begin);
80125
} else {
@@ -91,7 +136,7 @@ struct strided_slice_impl : typed_primitive_impl_ocl<strided_slice> {
91136
return offset;
92137
};
93138
if (!end.empty() && !params.has_dynamic_tensors()) {
94-
pad_vector_to_size(end, dims_num, 1);
139+
pad_vector_to_size(end, dims_num, 1, prim->ellipsis_mask);
95140
params.end_type = kernel_selector::base_params::ArgType::Constant;
96141
params.striding_params.push_back(end);
97142
} else {
@@ -108,7 +153,9 @@ struct strided_slice_impl : typed_primitive_impl_ocl<strided_slice> {
108153
return offset;
109154
};
110155
if (!strides.empty() && !params.has_dynamic_tensors()) {
111-
pad_vector_to_size(strides, dims_num, 1);
156+
std::cout << "Stride before " << print_vec(strides) << "\n";
157+
pad_vector_to_size(strides, dims_num, 1, prim->ellipsis_mask);
158+
std::cout << "Stride after " << print_vec(strides) << "\n";
112159
params.stride_type = kernel_selector::base_params::ArgType::Constant;
113160
params.striding_params.push_back(strides);
114161
} else {
@@ -118,23 +165,27 @@ struct strided_slice_impl : typed_primitive_impl_ocl<strided_slice> {
118165
params.stride_dims = stride_layout.count();
119166
}
120167

168+
std::cout << "params.striding_params " << params.striding_params.size() << "\n";
121169
auto begin_mask_ = prim->begin_mask;
122170
auto end_mask_ = prim->end_mask;
123171
auto new_axis_mask_ = prim->new_axis_mask;
124172
auto shrink_axis_mask_ = prim->shrink_axis_mask;
173+
auto ellipsis_mask_ = prim->ellipsis_mask;
125174

126175
std::vector<uint8_t> begin_mask(begin_mask_.begin(), begin_mask_.end());
127176
std::vector<uint8_t> end_mask(end_mask_.begin(), end_mask_.end());
128177
std::vector<uint8_t> new_axis_mask(new_axis_mask_.begin(), new_axis_mask_.end());
129178
std::vector<uint8_t> shrink_axis_mask(shrink_axis_mask_.begin(), shrink_axis_mask_.end());
179+
std::vector<uint8_t> ellipsis_mask(ellipsis_mask_.begin(), ellipsis_mask_.end());
130180
params.end_mask = std::move(end_mask);
131-
pad_vector_to_size(params.end_mask, dims_num, 0);
181+
pad_vector_to_size(params.end_mask, dims_num, 0, prim->ellipsis_mask);
132182
params.begin_mask = std::move(begin_mask);
133-
pad_vector_to_size(params.begin_mask, dims_num, 0);
183+
pad_vector_to_size(params.begin_mask, dims_num, 0, prim->ellipsis_mask);
134184

135185
params.new_axis_mask = new_axis_mask;
136186
params.shrink_axis_mask = shrink_axis_mask;
137-
pad_vector_to_size(params.shrink_axis_mask, dims_num, 0);
187+
params.ellipsis_mask = ellipsis_mask;
188+
pad_vector_to_size(params.shrink_axis_mask, dims_num, 0, prim->ellipsis_mask);
138189

139190
std::vector<size_t> logical_dims = params.inputs[0].LogicalDims();
140191
std::reverse(logical_dims.begin(), logical_dims.end()); // get dims in bfyx order
@@ -202,6 +253,7 @@ struct strided_slice_impl : typed_primitive_impl_ocl<strided_slice> {
202253
}
203254

204255
void update_dispatch_data(const kernel_impl_params& impl_param) override {
256+
std::cout << "Update dispatch data\n";
205257
auto kernel_params = get_kernel_params(impl_param, true);
206258
(_kernel_data.update_dispatch_data_func)(kernel_params, _kernel_data);
207259
}

src/plugins/intel_gpu/src/graph/primitive_inst.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -1541,6 +1541,9 @@ primitive_inst::primitive_inst(network & network, program_node const& node, bool
15411541
_outputs = allocate_outputs();
15421542
}
15431543
}
1544+
if (_node) {
1545+
GPU_DEBUG_TRACE_DETAIL << _node->type()->to_string(*_node) << "\n";
1546+
}
15441547
if (_impl) {
15451548
_impl->set_node_params(node);
15461549
if (_impl->is_dynamic() && !_impl->is_cpu()) {

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/strided_slice_ref.cl

+97-57
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,20 @@ inline void FUNC(get_slice_step)(OPTIONAL_SHAPE_INFO_ARG
1010
int* step_batch, int* step_feature,
1111
int* step_w, int* step_z, int* step_y, int* step_x)
1212
{
13-
const uint batch_index = 0;
14-
const uint feature_index = 1;
13+
const uint batch_index = BATCH_DIM_IDX;
14+
const uint feature_index = FEATURE_DIM_IDX;
1515
#ifdef OUTPUT_LAYOUT_BFYX
16-
const uint y_index = 2;
17-
const uint x_index = 3;
16+
const uint y_index = Y_DIM_IDX;
17+
const uint x_index = X_DIM_IDX;
1818
#elif OUTPUT_LAYOUT_BFZYX
19-
const uint z_index = 2;
20-
const uint y_index = 3;
21-
const uint x_index = 4;
19+
const uint z_index = Z_DIM_IDX;
20+
const uint y_index = Y_DIM_IDX;
21+
const uint x_index = X_DIM_IDX;
2222
#elif OUTPUT_LAYOUT_BFWZYX
23-
const uint w_index = 2;
24-
const uint z_index = 3;
25-
const uint y_index = 4;
26-
const uint x_index = 5;
23+
const uint w_index = W_DIM_IDX;
24+
const uint z_index = Z_DIM_IDX;
25+
const uint y_index = Y_DIM_IDX;
26+
const uint x_index = X_DIM_IDX;
2727
#endif
2828

2929
*step_batch = batch_index < STRIDE_DIMS ? stride[batch_index] : 1;
@@ -55,20 +55,20 @@ inline void FUNC(get_slice_end)(OPTIONAL_SHAPE_INFO_ARG
5555
const uint out_z_num = INPUT0_SIZE_Z;
5656
const uint out_y_num = INPUT0_SIZE_Y;
5757
const uint out_x_num = INPUT0_SIZE_X;
58-
const uint batch_index = 0;
59-
const uint feature_index = 1;
58+
const uint batch_index = BATCH_DIM_IDX;
59+
const uint feature_index = FEATURE_DIM_IDX;
6060
#ifdef OUTPUT_LAYOUT_BFYX
61-
const uint y_index = 2;
62-
const uint x_index = 3;
61+
const uint y_index = Y_DIM_IDX;
62+
const uint x_index = X_DIM_IDX;
6363
#elif OUTPUT_LAYOUT_BFZYX
64-
const uint z_index = 2;
65-
const uint y_index = 3;
66-
const uint x_index = 4;
64+
const uint z_index = Z_DIM_IDX;
65+
const uint y_index = Y_DIM_IDX;
66+
const uint x_index = X_DIM_IDX;
6767
#elif OUTPUT_LAYOUT_BFWZYX
68-
const uint w_index = 2;
69-
const uint z_index = 3;
70-
const uint y_index = 4;
71-
const uint x_index = 5;
68+
const uint w_index = W_DIM_IDX;
69+
const uint z_index = Z_DIM_IDX;
70+
const uint y_index = Y_DIM_IDX;
71+
const uint x_index = X_DIM_IDX;
7272
#endif
7373
END_TYPE batch = batch_index < END_DIMS ? end[batch_index] : 0;
7474
END_TYPE feature = feature_index < END_DIMS ? end[feature_index] : 0;
@@ -100,20 +100,20 @@ inline void FUNC(get_slice_begin)(OPTIONAL_SHAPE_INFO_ARG
100100
int* begin_batch, int* begin_feature,
101101
int* begin_w, int* begin_z, int* begin_y, int* begin_x)
102102
{
103-
const uint batch_index = 0;
104-
const uint feature_index = 1;
103+
const uint batch_index = BATCH_DIM_IDX;
104+
const uint feature_index = FEATURE_DIM_IDX;
105105
#ifdef OUTPUT_LAYOUT_BFYX
106-
const uint y_index = 2;
107-
const uint x_index = 3;
106+
const uint y_index = Y_DIM_IDX;
107+
const uint x_index = X_DIM_IDX;
108108
#elif OUTPUT_LAYOUT_BFZYX
109-
const uint z_index = 2;
110-
const uint y_index = 3;
111-
const uint x_index = 4;
109+
const uint z_index = Z_DIM_IDX;
110+
const uint y_index = Y_DIM_IDX;
111+
const uint x_index = X_DIM_IDX;
112112
#elif OUTPUT_LAYOUT_BFWZYX
113-
const uint w_index = 2;
114-
const uint z_index = 3;
115-
const uint y_index = 4;
116-
const uint x_index = 5;
113+
const uint w_index = W_DIM_IDX;
114+
const uint z_index = Z_DIM_IDX;
115+
const uint y_index = Y_DIM_IDX;
116+
const uint x_index = X_DIM_IDX;
117117
#endif
118118

119119
BEGIN_TYPE batch = batch_index < BEGIN_DIMS ? begin[batch_index] : 0;
@@ -160,7 +160,7 @@ inline void FUNC(calculate_index)(int* step, int* begin_num, int* end_num, const
160160
{
161161
int real_begin = *begin_num < 0 ? *begin_num + out_num : *begin_num;
162162
int real_end = *end_num < 0 ? *end_num + out_num : *end_num;
163-
if (*step < 0) {
163+
if (*step < 0) {
164164
real_begin = max((int)(0), min((int)(out_num - 1), real_begin));
165165
real_end = max((int)(-1), min((int)out_num, real_end));
166166
if (real_begin < real_end) { // for reversing
@@ -239,6 +239,17 @@ KERNEL(strided_slice_ref)(OPTIONAL_SHAPE_INFO_ARG
239239
end_x = SLICE_END_X;
240240
#endif // END_TYPE
241241

242+
// if (step_feature == -1 && step_x == 1) {
243+
// step_feature = 1;
244+
// step_x = -1;
245+
// }
246+
247+
if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0) {
248+
printf("Step sizes (bfyx): %d %d %d %d. %d %d %d %d, %d\n", step_batch, step_feature, step_y, step_x, BATCH_DIM_IDX, FEATURE_DIM_IDX, Y_DIM_IDX, X_DIM_IDX, STRIDE_DIMS);
249+
printf("Begin sizes (bfyx): %d %d %d %d\n", begin_batch, begin_feature, begin_y, begin_x);
250+
printf("End sizes (bfyx): %d %d %d %d\n", end_batch, end_feature, end_y, end_x);
251+
}
252+
242253
#ifdef SHRINK_MODE
243254
FUNC_CALL(calculate_index)(&step_batch, &begin_batch, &end_batch, INPUT0_BATCH_NUM, SHRINK_BATCH);
244255
FUNC_CALL(calculate_index)(&step_feature, &begin_feature, &end_feature, INPUT0_FEATURE_NUM, SHRINK_FEATURE);
@@ -289,33 +300,62 @@ KERNEL(strided_slice_ref)(OPTIONAL_SHAPE_INFO_ARG
289300

290301
#if NEW_AXIS_MODE
291302
// If NEW_AXIS_MODE that just copy input to output
292-
#ifdef OUTPUT_LAYOUT_BFYX
303+
#ifdef INPUT0_LAYOUT_BFYX
304+
const uint index_in_batch = (feature * (uint)get_global_size(2) + (uint)get_global_id(2)) % (OUTPUT_SIZE_X * OUTPUT_SIZE_Y);
305+
const uint input_feature_id = (feature * (uint)get_global_size(2) + (uint)get_global_id(2)) / (OUTPUT_SIZE_X * OUTPUT_SIZE_Y);
293306
const uint w_input = 0;
294307
const uint z_input = 0;
295-
const uint y_input = (uint)get_global_id(2) / INPUT0_SIZE_X;
296-
const uint x_input = (uint)get_global_id(2) % INPUT0_SIZE_X;
297-
#elif OUTPUT_LAYOUT_BFZYX
308+
const uint y_input = index_in_batch / OUTPUT_SIZE_X;
309+
const uint x_input = index_in_batch % OUTPUT_SIZE_X;
310+
#elif INPUT0_LAYOUT_BFZYX
311+
const uint index_in_batch = (feature * (uint)get_global_size(2) + (uint)get_global_id(2)) % (OUTPUT_SIZE_X * OUTPUT_SIZE_Y * OUTPUT_SIZE_Z);
312+
const uint input_feature_id = (feature * (uint)get_global_size(2) + (uint)get_global_id(2)) / (OUTPUT_SIZE_X * OUTPUT_SIZE_Y * OUTPUT_SIZE_Z);
298313
const uint w_input = 0;
299-
const uint yx_input = (uint)get_global_id(2) % (INPUT0_SIZE_X * INPUT0_SIZE_Y);
300-
const uint z_input = (uint)get_global_id(2) / (INPUT0_SIZE_X * INPUT0_SIZE_Y);
301-
const uint y_input = yx_input / INPUT0_SIZE_X;
302-
const uint x_input = yx_input % INPUT0_SIZE_X;
303-
#elif OUTPUT_LAYOUT_BFWZYX
304-
const uint zyx_input = (uint)get_global_id(2) % (INPUT0_SIZE_X * INPUT0_SIZE_Y * INPUT0_SIZE_Z);
305-
const uint w_input = (uint)get_global_id(2) / (INPUT0_SIZE_X * INPUT0_SIZE_Y * INPUT0_SIZE_Z);
306-
const uint z_input = zyx_input / (INPUT0_SIZE_X * INPUT0_SIZE_Y);
307-
const uint yx_input = zyx_input % (INPUT0_SIZE_X * INPUT0_SIZE_Y);
308-
const uint y_input = yx_input / INPUT0_SIZE_X;
309-
const uint x_input = yx_input % INPUT0_SIZE_X;
314+
const uint yx_input = index_in_batch % (OUTPUT_SIZE_X * OUTPUT_SIZE_Y);
315+
const uint z_input = index_in_batch / (OUTPUT_SIZE_X * OUTPUT_SIZE_Y);
316+
const uint y_input = yx_input / OUTPUT_SIZE_X;
317+
const uint x_input = yx_input % OUTPUT_SIZE_X;
318+
#elif INPUT0_LAYOUT_BFWZYX
319+
const uint index_in_batch = (feature * (uint)get_global_size(2) + (uint)get_global_id(2)) % (OUTPUT_SIZE_X * OUTPUT_SIZE_Y * OUTPUT_SIZE_Z * OUTPUT_SIZE_W);
320+
const uint input_feature_id = (feature * (uint)get_global_size(2) + (uint)get_global_id(2)) / (OUTPUT_SIZE_X * OUTPUT_SIZE_Y * OUTPUT_SIZE_Z * OUTPUT_SIZE_W);
321+
const uint zyx_input = index_in_batch % (OUTPUT_SIZE_X * OUTPUT_SIZE_Y * OUTPUT_SIZE_Z);
322+
const uint w_input = index_in_batch / (OUTPUT_SIZE_X * OUTPUT_SIZE_Y * OUTPUT_SIZE_Z);
323+
const uint z_input = zyx_input / (OUTPUT_SIZE_X * OUTPUT_SIZE_Y);
324+
const uint yx_input = zyx_input % (OUTPUT_SIZE_X * OUTPUT_SIZE_Y);
325+
const uint y_input = yx_input / OUTPUT_SIZE_X;
326+
const uint x_input = yx_input % OUTPUT_SIZE_X;
310327
#endif
328+
311329
const uint input_index = INPUT0_OFFSET +
312330
batch * INPUT0_BATCH_PITCH +
313-
feature * INPUT0_FEATURE_PITCH +
314-
w_input * INPUT0_W_PITCH +
315-
z_input * INPUT0_Z_PITCH +
316-
y_input * INPUT0_Y_PITCH +
317-
x_input * INPUT0_X_PITCH;
318-
output[input_index] = input[input_index];
331+
input_feature_id * INPUT0_FEATURE_PITCH +
332+
w_input * OUTPUT_W_PITCH +
333+
z_input * OUTPUT_Z_PITCH +
334+
y_input * OUTPUT_Y_PITCH +
335+
x_input * OUTPUT_X_PITCH;
336+
337+
#ifdef OUTPUT_LAYOUT_BFYX
338+
const uint y = (uint)get_global_id(2) / OUTPUT_SIZE_X;
339+
const uint x = (uint)get_global_id(2) % OUTPUT_SIZE_X;
340+
const uint output_index = OUTPUT_GET_INDEX(batch, feature, y, x);
341+
#elif OUTPUT_LAYOUT_BFZYX
342+
const uint yx = (uint)get_global_id(2) % (OUTPUT_SIZE_X * OUTPUT_SIZE_Y);
343+
const uint z = (uint)get_global_id(2) / (OUTPUT_SIZE_X * OUTPUT_SIZE_Y);
344+
const uint y = yx / OUTPUT_SIZE_X;
345+
const uint x = yx % OUTPUT_SIZE_X;
346+
const uint output_index = OUTPUT_GET_INDEX(batch, feature, z, y, x);
347+
#elif OUTPUT_LAYOUT_BFWZYX
348+
const uint zyx = (uint)get_global_id(2) % (OUTPUT_SIZE_X * OUTPUT_SIZE_Y * OUTPUT_SIZE_Z);
349+
const uint w = (uint)get_global_id(2) / (OUTPUT_SIZE_X * OUTPUT_SIZE_Y * OUTPUT_SIZE_Z);
350+
const uint z = zyx / (OUTPUT_SIZE_X * OUTPUT_SIZE_Y);
351+
const uint yx = zyx % (OUTPUT_SIZE_X * OUTPUT_SIZE_Y);
352+
const uint y = yx / OUTPUT_SIZE_X;
353+
const uint x = yx % OUTPUT_SIZE_X;
354+
const uint output_index = OUTPUT_GET_INDEX(batch, feature, w, z, y, x);
355+
#endif
356+
357+
output[output_index] = input[input_index];
358+
319359
#else // NEW_AXIS_MODE
320360
#ifdef OUTPUT_LAYOUT_BFYX
321361
const uint w = 0;
@@ -359,7 +399,7 @@ KERNEL(strided_slice_ref)(OPTIONAL_SHAPE_INFO_ARG
359399
const uint input_index = INPUT0_OFFSET +
360400
(slice_begin_batch + batch * slice_steps_batch) * INPUT0_BATCH_PITCH +
361401
(slice_begin_feature + feature * slice_steps_feature) * INPUT0_FEATURE_PITCH +
362-
#if INPUT0_LAYOUT_BFWZYX
402+
#if INPUT0_LAYOUT_BFWZYX
363403
(slice_begin_w + w * slice_steps_w) * INPUT0_W_PITCH +
364404
(slice_begin_z + z * slice_steps_z) * INPUT0_Z_PITCH +
365405
(slice_begin_y + y * slice_steps_y) * INPUT0_Y_PITCH +
@@ -390,4 +430,4 @@ KERNEL(strided_slice_ref)(OPTIONAL_SHAPE_INFO_ARG
390430
output[output_index] = ACTIVATION(input[input_index], ACTIVATION_PARAMS);
391431
#endif
392432
#endif // NEW_AXIS_MODE
393-
}
433+
}

0 commit comments

Comments
 (0)