Skip to content

Commit ce8f5d4

Browse files
committed
[GPU] Add ellipsis mode support for strided slice
1 parent 909535f commit ce8f5d4

File tree

5 files changed

+124
-45
lines changed

5 files changed

+124
-45
lines changed

src/plugins/intel_gpu/src/graph/impls/ocl/strided_slice.cpp

+38-10
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,36 @@
1212

1313
namespace {
1414
template <typename T, typename DT, typename = typename std::enable_if<std::is_convertible<DT, T>::value>::type>
15-
std::vector<T>& pad_vector_to_size(std::vector<T>& data, size_t size, DT value) {
16-
for (size_t i = data.size(); i < size; ++i) {
17-
data.push_back(static_cast<T>(value));
15+
void pad_vector_to_size(std::vector<T>& data, size_t size, DT value, const std::vector<int64_t>& ellipsis_mask) {
16+
bool apply_ellipsis_mask = std::count(ellipsis_mask.begin(), ellipsis_mask.end(), 1) == 1;
17+
if (apply_ellipsis_mask && data.size() == ellipsis_mask.size()) {
18+
std::vector<T> temp;
19+
size_t ellipsis_pos1 = 0;
20+
for (size_t i = 0; i < ellipsis_mask.size(); i++) {
21+
if (ellipsis_mask[i] == 1) {
22+
ellipsis_pos1 = i;
23+
break;
24+
}
25+
}
26+
27+
size_t dims_after = data.size() - ellipsis_pos1 - 1;
28+
size_t ellipsis_pos2 = size - dims_after - 1;;
29+
30+
for (size_t i = 0; i < ellipsis_pos1; i++)
31+
temp.push_back(data[i]);
32+
33+
for (size_t i = ellipsis_pos1; i < ellipsis_pos2 + 1; i++)
34+
temp.push_back(value);
35+
36+
for (size_t i = 1; i < size - ellipsis_pos2; i++)
37+
temp.push_back(data[i + ellipsis_pos1]);
38+
39+
data = temp;
40+
} else {
41+
for (size_t i = data.size(); i < size; ++i) {
42+
data.push_back(static_cast<T>(value));
43+
}
1844
}
19-
return data;
2045
}
2146

2247
template <typename T, typename MT>
@@ -74,7 +99,7 @@ struct strided_slice_impl : typed_primitive_impl_ocl<strided_slice> {
7499

75100
// Getting data from constant inputs. There are 3 args: Begin, End, Stride
76101
if (!begin.empty() && !params.has_dynamic_tensors()) {
77-
pad_vector_to_size(begin, dims_num, 0);
102+
pad_vector_to_size(begin, dims_num, 0, prim->ellipsis_mask);
78103
params.begin_type = kernel_selector::base_params::ArgType::Constant;
79104
params.striding_params.push_back(begin);
80105
} else {
@@ -91,7 +116,7 @@ struct strided_slice_impl : typed_primitive_impl_ocl<strided_slice> {
91116
return offset;
92117
};
93118
if (!end.empty() && !params.has_dynamic_tensors()) {
94-
pad_vector_to_size(end, dims_num, 1);
119+
pad_vector_to_size(end, dims_num, 1, prim->ellipsis_mask);
95120
params.end_type = kernel_selector::base_params::ArgType::Constant;
96121
params.striding_params.push_back(end);
97122
} else {
@@ -108,7 +133,7 @@ struct strided_slice_impl : typed_primitive_impl_ocl<strided_slice> {
108133
return offset;
109134
};
110135
if (!strides.empty() && !params.has_dynamic_tensors()) {
111-
pad_vector_to_size(strides, dims_num, 1);
136+
pad_vector_to_size(strides, dims_num, 1, prim->ellipsis_mask);
112137
params.stride_type = kernel_selector::base_params::ArgType::Constant;
113138
params.striding_params.push_back(strides);
114139
} else {
@@ -122,19 +147,22 @@ struct strided_slice_impl : typed_primitive_impl_ocl<strided_slice> {
122147
auto end_mask_ = prim->end_mask;
123148
auto new_axis_mask_ = prim->new_axis_mask;
124149
auto shrink_axis_mask_ = prim->shrink_axis_mask;
150+
auto ellipsis_mask_ = prim->ellipsis_mask;
125151

126152
std::vector<uint8_t> begin_mask(begin_mask_.begin(), begin_mask_.end());
127153
std::vector<uint8_t> end_mask(end_mask_.begin(), end_mask_.end());
128154
std::vector<uint8_t> new_axis_mask(new_axis_mask_.begin(), new_axis_mask_.end());
129155
std::vector<uint8_t> shrink_axis_mask(shrink_axis_mask_.begin(), shrink_axis_mask_.end());
156+
std::vector<uint8_t> ellipsis_mask(ellipsis_mask_.begin(), ellipsis_mask_.end());
130157
params.end_mask = std::move(end_mask);
131-
pad_vector_to_size(params.end_mask, dims_num, 0);
158+
pad_vector_to_size(params.end_mask, dims_num, 0, prim->ellipsis_mask);
132159
params.begin_mask = std::move(begin_mask);
133-
pad_vector_to_size(params.begin_mask, dims_num, 0);
160+
pad_vector_to_size(params.begin_mask, dims_num, 0, prim->ellipsis_mask);
134161

135162
params.new_axis_mask = new_axis_mask;
136163
params.shrink_axis_mask = shrink_axis_mask;
137-
pad_vector_to_size(params.shrink_axis_mask, dims_num, 0);
164+
params.ellipsis_mask = ellipsis_mask;
165+
pad_vector_to_size(params.shrink_axis_mask, dims_num, 0, prim->ellipsis_mask);
138166

139167
std::vector<size_t> logical_dims = params.inputs[0].LogicalDims();
140168
std::reverse(logical_dims.begin(), logical_dims.end()); // get dims in bfyx order

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/strided_slice_ref.cl

+35-35
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,20 @@ inline void FUNC(get_slice_step)(OPTIONAL_SHAPE_INFO_ARG
1010
int* step_batch, int* step_feature,
1111
int* step_w, int* step_z, int* step_y, int* step_x)
1212
{
13-
const uint batch_index = 0;
14-
const uint feature_index = 1;
13+
const uint batch_index = DIM_IDX_BATCH;
14+
const uint feature_index = DIM_IDX_FEATURE;
1515
#ifdef OUTPUT_LAYOUT_BFYX
16-
const uint y_index = 2;
17-
const uint x_index = 3;
16+
const uint y_index = DIM_IDX_Y;
17+
const uint x_index = DIM_IDX_X;
1818
#elif OUTPUT_LAYOUT_BFZYX
19-
const uint z_index = 2;
20-
const uint y_index = 3;
21-
const uint x_index = 4;
19+
const uint z_index = DIM_IDX_Z;
20+
const uint y_index = DIM_IDX_Y;
21+
const uint x_index = DIM_IDX_X;
2222
#elif OUTPUT_LAYOUT_BFWZYX
23-
const uint w_index = 2;
24-
const uint z_index = 3;
25-
const uint y_index = 4;
26-
const uint x_index = 5;
23+
const uint w_index = DIM_IDX_W;
24+
const uint z_index = DIM_IDX_Z;
25+
const uint y_index = DIM_IDX_Y;
26+
const uint x_index = DIM_IDX_X;
2727
#endif
2828

2929
*step_batch = batch_index < STRIDE_DIMS ? stride[batch_index] : 1;
@@ -55,20 +55,20 @@ inline void FUNC(get_slice_end)(OPTIONAL_SHAPE_INFO_ARG
5555
const uint out_z_num = INPUT0_SIZE_Z;
5656
const uint out_y_num = INPUT0_SIZE_Y;
5757
const uint out_x_num = INPUT0_SIZE_X;
58-
const uint batch_index = 0;
59-
const uint feature_index = 1;
58+
const uint batch_index = DIM_IDX_BATCH;
59+
const uint feature_index = DIM_IDX_FEATURE;
6060
#ifdef OUTPUT_LAYOUT_BFYX
61-
const uint y_index = 2;
62-
const uint x_index = 3;
61+
const uint y_index = DIM_IDX_Y;
62+
const uint x_index = DIM_IDX_X;
6363
#elif OUTPUT_LAYOUT_BFZYX
64-
const uint z_index = 2;
65-
const uint y_index = 3;
66-
const uint x_index = 4;
64+
const uint z_index = DIM_IDX_Z;
65+
const uint y_index = DIM_IDX_Y;
66+
const uint x_index = DIM_IDX_X;
6767
#elif OUTPUT_LAYOUT_BFWZYX
68-
const uint w_index = 2;
69-
const uint z_index = 3;
70-
const uint y_index = 4;
71-
const uint x_index = 5;
68+
const uint w_index = DIM_IDX_W;
69+
const uint z_index = DIM_IDX_Z;
70+
const uint y_index = DIM_IDX_Y;
71+
const uint x_index = DIM_IDX_X;
7272
#endif
7373
END_TYPE batch = batch_index < END_DIMS ? end[batch_index] : 0;
7474
END_TYPE feature = feature_index < END_DIMS ? end[feature_index] : 0;
@@ -100,20 +100,20 @@ inline void FUNC(get_slice_begin)(OPTIONAL_SHAPE_INFO_ARG
100100
int* begin_batch, int* begin_feature,
101101
int* begin_w, int* begin_z, int* begin_y, int* begin_x)
102102
{
103-
const uint batch_index = 0;
104-
const uint feature_index = 1;
103+
const uint batch_index = DIM_IDX_BATCH;
104+
const uint feature_index = DIM_IDX_FEATURE;
105105
#ifdef OUTPUT_LAYOUT_BFYX
106-
const uint y_index = 2;
107-
const uint x_index = 3;
106+
const uint y_index = DIM_IDX_Y;
107+
const uint x_index = DIM_IDX_X;
108108
#elif OUTPUT_LAYOUT_BFZYX
109-
const uint z_index = 2;
110-
const uint y_index = 3;
111-
const uint x_index = 4;
109+
const uint z_index = DIM_IDX_Z;
110+
const uint y_index = DIM_IDX_Y;
111+
const uint x_index = DIM_IDX_X;
112112
#elif OUTPUT_LAYOUT_BFWZYX
113-
const uint w_index = 2;
114-
const uint z_index = 3;
115-
const uint y_index = 4;
116-
const uint x_index = 5;
113+
const uint w_index = DIM_IDX_W;
114+
const uint z_index = DIM_IDX_Z;
115+
const uint y_index = DIM_IDX_Y;
116+
const uint x_index = DIM_IDX_X;
117117
#endif
118118

119119
BEGIN_TYPE batch = batch_index < BEGIN_DIMS ? begin[batch_index] : 0;
@@ -160,7 +160,7 @@ inline void FUNC(calculate_index)(int* step, int* begin_num, int* end_num, const
160160
{
161161
int real_begin = *begin_num < 0 ? *begin_num + out_num : *begin_num;
162162
int real_end = *end_num < 0 ? *end_num + out_num : *end_num;
163-
if (*step < 0) {
163+
if (*step < 0) {
164164
real_begin = max((int)(0), min((int)(out_num - 1), real_begin));
165165
real_end = max((int)(-1), min((int)out_num, real_end));
166166
if (real_begin < real_end) { // for reversing
@@ -359,7 +359,7 @@ KERNEL(strided_slice_ref)(OPTIONAL_SHAPE_INFO_ARG
359359
const uint input_index = INPUT0_OFFSET +
360360
(slice_begin_batch + batch * slice_steps_batch) * INPUT0_BATCH_PITCH +
361361
(slice_begin_feature + feature * slice_steps_feature) * INPUT0_FEATURE_PITCH +
362-
#if INPUT0_LAYOUT_BFWZYX
362+
#if INPUT0_LAYOUT_BFWZYX
363363
(slice_begin_w + w * slice_steps_w) * INPUT0_W_PITCH +
364364
(slice_begin_z + z * slice_steps_z) * INPUT0_Z_PITCH +
365365
(slice_begin_y + y * slice_steps_y) * INPUT0_Y_PITCH +

src/plugins/intel_gpu/src/kernel_selector/kernels/strided_slice/strided_slice_kernel_ref.cpp

+32
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,38 @@ JitConstants StridedSliceKernelRef::GetJitConstants(const strided_slice_params&
165165
"NEW_AXIS_MODE",
166166
std::find(params.new_axis_mask.begin(), params.new_axis_mask.end(), 1) != params.new_axis_mask.end()));
167167

168+
std::vector<int> dims_indexes;
169+
bool ellipsis_mode = std::find(params.ellipsis_mask.begin(), params.ellipsis_mask.end(), 1) != params.ellipsis_mask.end();
170+
if (ellipsis_mode) {
171+
size_t ellipsis_pos1 = 0;
172+
for (size_t i = 0; i < params.ellipsis_mask.size(); i++) {
173+
if (params.ellipsis_mask[i] == 1) {
174+
ellipsis_pos1 = i;
175+
break;
176+
}
177+
}
178+
179+
const size_t output_rank = params.outputs[0].Dimentions();
180+
const size_t skip_dims_num = output_rank - params.ellipsis_mask.size() + 1;
181+
size_t dim_counter = 0;
182+
183+
for (size_t i = 0; i < ellipsis_pos1; i++)
184+
dims_indexes.push_back(dim_counter++);
185+
186+
for (size_t i = 0; i < skip_dims_num; i++)
187+
dims_indexes.push_back(-1);
188+
189+
dim_counter++;
190+
for (size_t i = 0; i < params.ellipsis_mask.size() - ellipsis_pos1 - 1; i++)
191+
dims_indexes.push_back(dim_counter++);
192+
193+
OPENVINO_ASSERT(dims_indexes.size() == output_rank, "[GPU] Number of indexes is expected to match with output rank");
194+
} else {
195+
dims_indexes.resize(params.outputs[0].Dimentions());
196+
std::iota(dims_indexes.begin(), dims_indexes.end(), 0);
197+
}
198+
makeJitConstForParam(jit, "DIM_IDX", dims_indexes);
199+
168200
bool shrink_mode = std::find(params.shrink_axis_mask.begin(), params.shrink_axis_mask.end(), 1) != params.shrink_axis_mask.end();
169201
if (shrink_mode) {
170202
jit.AddConstant(MakeJitConstant("SHRINK_MODE", true));

src/plugins/intel_gpu/tests/functional/shared_tests_instances/single_layer_tests/strided_slice.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,22 @@ std::vector<StridedSliceSpecificParams> ss_only_test_cases_fp32 = {
152152
{ 128, 1, 1024 }})),
153153
{ -1, 0, 0 }, { 0, 0, 0 }, { 1, 1, 1 },
154154
{ 0, 1, 1 }, { 0, 1, 1 }, { 0, 0, 0 }, { 1, 0, 0 }, { 0, 0, 0 } },
155+
StridedSliceSpecificParams{ ov::test::static_shapes_to_test_representation(std::vector<ov::Shape>({
156+
{ 10, 10 }})),
157+
{ -4, 1 }, { -8, 0 }, { -1, 1 },
158+
{ 0, 1 }, { 0, 1 }, { 0, 0 }, { 0, 0 }, { 0, 1 } },
159+
StridedSliceSpecificParams{ ov::test::static_shapes_to_test_representation(std::vector<ov::Shape>({
160+
{ 2, 2, 4, 1 }})),
161+
{ 0, 0 }, { 2, 2 }, { 1, 1 },
162+
{ 0, 0 }, { 0, 0 }, { 0, 0 }, { 0, 0 }, { 0, 1 } },
163+
StridedSliceSpecificParams{ ov::test::static_shapes_to_test_representation(std::vector<ov::Shape>({
164+
{ 2, 2, 4, 1 }})),
165+
{ 0, 0 }, { 4, 1 }, { 1, -1 },
166+
{ 0, 0 }, { 0, 0 }, { 0, 0 }, { 0, 0 }, { 1, 0 } },
167+
StridedSliceSpecificParams{ ov::test::static_shapes_to_test_representation(std::vector<ov::Shape>({
168+
{ 1, 5, 30, 30, 30 }})),
169+
{ 0, 0, 0 }, { 0, 29, 29 }, { 1, 1, 1 },
170+
{1, 1, 1}, {1, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 1, 0} },
155171
};
156172

157173
std::vector<StridedSliceSpecificParams> ss_only_test_cases_i64 = {

src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/strided_slice.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ const std::vector<StridedSliceParams> testCasesCommon4D = {
368368
StridedSliceParams{ { 0, 0, 0, 20 }, { 1, 2, 30, 30 }, { 1, 1, 2, 1 }, { 0, 0, 0, 1 }, { 0, 1, 0, 1 }, { }, { }, { } },
369369
StridedSliceParams{ { 0, 1, 2, 10 }, { 1, 5, 32, 18 }, { 1, 1, 1, 2 }, { 0, 0, 1, 0 }, { 0, 0, 0, 1 }, { }, { }, { } },
370370
StridedSliceParams{ { 0, 0, 2, 10 }, { 1, 8, 32, 18 }, { 1, 2, 1, 2 }, { 0, 0, 1, 0 }, { 0, 0, 0, 1 }, { }, { }, { } },
371+
StridedSliceParams{ { 2, 10 }, { 32, 18 }, { 1, 2 }, { 1, 0 }, { 0, 1 }, { }, { }, { 1, 0 } },
371372
};
372373

373374
const std::vector<InputShape> inputShapesDynamic4D = {
@@ -396,6 +397,7 @@ const std::vector<StridedSliceParams> testCasesCommon5D = {
396397
StridedSliceParams{ { 0, 0, 0, 20 }, { 1, 2, 30, 30 }, { 1, 1, 2, 1 }, { 0, 0, 0, 1 }, { 0, 1, 0, 1 }, { }, { }, { } },
397398
StridedSliceParams{ { 0, 1, 2, 10 }, { 1, 5, 32, 18 }, { 1, 1, 1, 2 }, { 0, 0, 1, 0 }, { 0, 0, 0, 1 }, { }, { }, { } },
398399
StridedSliceParams{ { 0, 0, 2, 10 }, { 1, 8, 32, 18 }, { 1, 2, 1, 2 }, { 0, 0, 1, 0 }, { 0, 0, 0, 1 }, { }, { }, { } },
400+
StridedSliceParams{ { 0, 0, 2 }, { 1, 8, 32 }, { 1, 2, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { }, { }, { 0, 1, 0} },
399401
};
400402

401403
const std::vector<InputShape> inputShapesDynamic5D = {
@@ -421,6 +423,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Common_Dynamic_5D, StridedSliceLa
421423
const std::vector<StridedSliceParams> testCasesCommon6D = {
422424
StridedSliceParams{ { 0, 2, 5, 4 }, { 1, 4, 28, 27 }, { 1, 1, 1, 1 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 }, { }, { }, { } },
423425
StridedSliceParams{ { 0, 0, 10, 20 }, { 1, 5, 28, 26 }, { 1, 1, 1, 2 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 }, { }, { }, { } },
426+
StridedSliceParams{ { 0, 0, 0 }, { 0, 0, 0 }, { 1, 1, 1 }, { 0, 1, 0 }, { 0, 1, 0 }, { }, { }, { 0, 0, 1 } },
424427
};
425428

426429
const std::vector<InputShape> inputShapesDynamic6D = {

0 commit comments

Comments
 (0)