@@ -10,20 +10,20 @@ inline void FUNC(get_slice_step)(OPTIONAL_SHAPE_INFO_ARG
10
10
int * step_batch , int * step_feature ,
11
11
int * step_w , int * step_z , int * step_y , int * step_x )
12
12
{
13
- const uint batch_index = 0 ;
14
- const uint feature_index = 1 ;
13
+ const uint batch_index = BATCH_DIM_IDX ;
14
+ const uint feature_index = FEATURE_DIM_IDX ;
15
15
#ifdef OUTPUT_LAYOUT_BFYX
16
- const uint y_index = 2 ;
17
- const uint x_index = 3 ;
16
+ const uint y_index = Y_DIM_IDX ;
17
+ const uint x_index = X_DIM_IDX ;
18
18
#elif OUTPUT_LAYOUT_BFZYX
19
- const uint z_index = 2 ;
20
- const uint y_index = 3 ;
21
- const uint x_index = 4 ;
19
+ const uint z_index = Z_DIM_IDX ;
20
+ const uint y_index = Y_DIM_IDX ;
21
+ const uint x_index = X_DIM_IDX ;
22
22
#elif OUTPUT_LAYOUT_BFWZYX
23
- const uint w_index = 2 ;
24
- const uint z_index = 3 ;
25
- const uint y_index = 4 ;
26
- const uint x_index = 5 ;
23
+ const uint w_index = W_DIM_IDX ;
24
+ const uint z_index = Z_DIM_IDX ;
25
+ const uint y_index = Y_DIM_IDX ;
26
+ const uint x_index = X_DIM_IDX ;
27
27
#endif
28
28
29
29
* step_batch = batch_index < STRIDE_DIMS ? stride [batch_index ] : 1 ;
@@ -55,20 +55,20 @@ inline void FUNC(get_slice_end)(OPTIONAL_SHAPE_INFO_ARG
55
55
const uint out_z_num = INPUT0_SIZE_Z ;
56
56
const uint out_y_num = INPUT0_SIZE_Y ;
57
57
const uint out_x_num = INPUT0_SIZE_X ;
58
- const uint batch_index = 0 ;
59
- const uint feature_index = 1 ;
58
+ const uint batch_index = BATCH_DIM_IDX ;
59
+ const uint feature_index = FEATURE_DIM_IDX ;
60
60
#ifdef OUTPUT_LAYOUT_BFYX
61
- const uint y_index = 2 ;
62
- const uint x_index = 3 ;
61
+ const uint y_index = Y_DIM_IDX ;
62
+ const uint x_index = X_DIM_IDX ;
63
63
#elif OUTPUT_LAYOUT_BFZYX
64
- const uint z_index = 2 ;
65
- const uint y_index = 3 ;
66
- const uint x_index = 4 ;
64
+ const uint z_index = Z_DIM_IDX ;
65
+ const uint y_index = Y_DIM_IDX ;
66
+ const uint x_index = X_DIM_IDX ;
67
67
#elif OUTPUT_LAYOUT_BFWZYX
68
- const uint w_index = 2 ;
69
- const uint z_index = 3 ;
70
- const uint y_index = 4 ;
71
- const uint x_index = 5 ;
68
+ const uint w_index = W_DIM_IDX ;
69
+ const uint z_index = Z_DIM_IDX ;
70
+ const uint y_index = Y_DIM_IDX ;
71
+ const uint x_index = X_DIM_IDX ;
72
72
#endif
73
73
END_TYPE batch = batch_index < END_DIMS ? end [batch_index ] : 0 ;
74
74
END_TYPE feature = feature_index < END_DIMS ? end [feature_index ] : 0 ;
@@ -100,20 +100,20 @@ inline void FUNC(get_slice_begin)(OPTIONAL_SHAPE_INFO_ARG
100
100
int * begin_batch , int * begin_feature ,
101
101
int * begin_w , int * begin_z , int * begin_y , int * begin_x )
102
102
{
103
- const uint batch_index = 0 ;
104
- const uint feature_index = 1 ;
103
+ const uint batch_index = BATCH_DIM_IDX ;
104
+ const uint feature_index = FEATURE_DIM_IDX ;
105
105
#ifdef OUTPUT_LAYOUT_BFYX
106
- const uint y_index = 2 ;
107
- const uint x_index = 3 ;
106
+ const uint y_index = Y_DIM_IDX ;
107
+ const uint x_index = X_DIM_IDX ;
108
108
#elif OUTPUT_LAYOUT_BFZYX
109
- const uint z_index = 2 ;
110
- const uint y_index = 3 ;
111
- const uint x_index = 4 ;
109
+ const uint z_index = Z_DIM_IDX ;
110
+ const uint y_index = Y_DIM_IDX ;
111
+ const uint x_index = X_DIM_IDX ;
112
112
#elif OUTPUT_LAYOUT_BFWZYX
113
- const uint w_index = 2 ;
114
- const uint z_index = 3 ;
115
- const uint y_index = 4 ;
116
- const uint x_index = 5 ;
113
+ const uint w_index = W_DIM_IDX ;
114
+ const uint z_index = Z_DIM_IDX ;
115
+ const uint y_index = Y_DIM_IDX ;
116
+ const uint x_index = X_DIM_IDX ;
117
117
#endif
118
118
119
119
BEGIN_TYPE batch = batch_index < BEGIN_DIMS ? begin [batch_index ] : 0 ;
@@ -160,7 +160,7 @@ inline void FUNC(calculate_index)(int* step, int* begin_num, int* end_num, const
160
160
{
161
161
int real_begin = * begin_num < 0 ? * begin_num + out_num : * begin_num ;
162
162
int real_end = * end_num < 0 ? * end_num + out_num : * end_num ;
163
- if (* step < 0 ) {
163
+ if (* step < 0 ) {
164
164
real_begin = max ((int )(0 ), min ((int )(out_num - 1 ), real_begin ));
165
165
real_end = max ((int )(-1 ), min ((int )out_num , real_end ));
166
166
if (real_begin < real_end ) { // for reversing
@@ -239,6 +239,17 @@ KERNEL(strided_slice_ref)(OPTIONAL_SHAPE_INFO_ARG
239
239
end_x = SLICE_END_X ;
240
240
#endif // END_TYPE
241
241
242
+ // if (step_feature == -1 && step_x == 1) {
243
+ // step_feature = 1;
244
+ // step_x = -1;
245
+ // }
246
+
247
+ if (get_global_id (0 ) == 0 && get_global_id (1 ) == 0 && get_global_id (2 ) == 0 ) {
248
+ printf ("Step sizes (bfyx): %d %d %d %d. %d %d %d %d, %d\n" , step_batch , step_feature , step_y , step_x , BATCH_DIM_IDX , FEATURE_DIM_IDX , Y_DIM_IDX , X_DIM_IDX , STRIDE_DIMS );
249
+ printf ("Begin sizes (bfyx): %d %d %d %d\n" , begin_batch , begin_feature , begin_y , begin_x );
250
+ printf ("End sizes (bfyx): %d %d %d %d\n" , end_batch , end_feature , end_y , end_x );
251
+ }
252
+
242
253
#ifdef SHRINK_MODE
243
254
FUNC_CALL (calculate_index )(& step_batch , & begin_batch , & end_batch , INPUT0_BATCH_NUM , SHRINK_BATCH );
244
255
FUNC_CALL (calculate_index )(& step_feature , & begin_feature , & end_feature , INPUT0_FEATURE_NUM , SHRINK_FEATURE );
@@ -289,33 +300,62 @@ KERNEL(strided_slice_ref)(OPTIONAL_SHAPE_INFO_ARG
289
300
290
301
#if NEW_AXIS_MODE
291
302
// If NEW_AXIS_MODE that just copy input to output
292
- #ifdef OUTPUT_LAYOUT_BFYX
303
+ #ifdef INPUT0_LAYOUT_BFYX
304
+ const uint index_in_batch = (feature * (uint )get_global_size (2 ) + (uint )get_global_id (2 )) % (OUTPUT_SIZE_X * OUTPUT_SIZE_Y );
305
+ const uint input_feature_id = (feature * (uint )get_global_size (2 ) + (uint )get_global_id (2 )) / (OUTPUT_SIZE_X * OUTPUT_SIZE_Y );
293
306
const uint w_input = 0 ;
294
307
const uint z_input = 0 ;
295
- const uint y_input = (uint )get_global_id (2 ) / INPUT0_SIZE_X ;
296
- const uint x_input = (uint )get_global_id (2 ) % INPUT0_SIZE_X ;
297
- #elif OUTPUT_LAYOUT_BFZYX
308
+ const uint y_input = index_in_batch / OUTPUT_SIZE_X ;
309
+ const uint x_input = index_in_batch % OUTPUT_SIZE_X ;
310
+ #elif INPUT0_LAYOUT_BFZYX
311
+ const uint index_in_batch = (feature * (uint )get_global_size (2 ) + (uint )get_global_id (2 )) % (OUTPUT_SIZE_X * OUTPUT_SIZE_Y * OUTPUT_SIZE_Z );
312
+ const uint input_feature_id = (feature * (uint )get_global_size (2 ) + (uint )get_global_id (2 )) / (OUTPUT_SIZE_X * OUTPUT_SIZE_Y * OUTPUT_SIZE_Z );
298
313
const uint w_input = 0 ;
299
- const uint yx_input = (uint )get_global_id (2 ) % (INPUT0_SIZE_X * INPUT0_SIZE_Y );
300
- const uint z_input = (uint )get_global_id (2 ) / (INPUT0_SIZE_X * INPUT0_SIZE_Y );
301
- const uint y_input = yx_input / INPUT0_SIZE_X ;
302
- const uint x_input = yx_input % INPUT0_SIZE_X ;
303
- #elif OUTPUT_LAYOUT_BFWZYX
304
- const uint zyx_input = (uint )get_global_id (2 ) % (INPUT0_SIZE_X * INPUT0_SIZE_Y * INPUT0_SIZE_Z );
305
- const uint w_input = (uint )get_global_id (2 ) / (INPUT0_SIZE_X * INPUT0_SIZE_Y * INPUT0_SIZE_Z );
306
- const uint z_input = zyx_input / (INPUT0_SIZE_X * INPUT0_SIZE_Y );
307
- const uint yx_input = zyx_input % (INPUT0_SIZE_X * INPUT0_SIZE_Y );
308
- const uint y_input = yx_input / INPUT0_SIZE_X ;
309
- const uint x_input = yx_input % INPUT0_SIZE_X ;
314
+ const uint yx_input = index_in_batch % (OUTPUT_SIZE_X * OUTPUT_SIZE_Y );
315
+ const uint z_input = index_in_batch / (OUTPUT_SIZE_X * OUTPUT_SIZE_Y );
316
+ const uint y_input = yx_input / OUTPUT_SIZE_X ;
317
+ const uint x_input = yx_input % OUTPUT_SIZE_X ;
318
+ #elif INPUT0_LAYOUT_BFWZYX
319
+ const uint index_in_batch = (feature * (uint )get_global_size (2 ) + (uint )get_global_id (2 )) % (OUTPUT_SIZE_X * OUTPUT_SIZE_Y * OUTPUT_SIZE_Z * OUTPUT_SIZE_W );
320
+ const uint input_feature_id = (feature * (uint )get_global_size (2 ) + (uint )get_global_id (2 )) / (OUTPUT_SIZE_X * OUTPUT_SIZE_Y * OUTPUT_SIZE_Z * OUTPUT_SIZE_W );
321
+ const uint zyx_input = index_in_batch % (OUTPUT_SIZE_X * OUTPUT_SIZE_Y * OUTPUT_SIZE_Z );
322
+ const uint w_input = index_in_batch / (OUTPUT_SIZE_X * OUTPUT_SIZE_Y * OUTPUT_SIZE_Z );
323
+ const uint z_input = zyx_input / (OUTPUT_SIZE_X * OUTPUT_SIZE_Y );
324
+ const uint yx_input = zyx_input % (OUTPUT_SIZE_X * OUTPUT_SIZE_Y );
325
+ const uint y_input = yx_input / OUTPUT_SIZE_X ;
326
+ const uint x_input = yx_input % OUTPUT_SIZE_X ;
310
327
#endif
328
+
311
329
const uint input_index = INPUT0_OFFSET +
312
330
batch * INPUT0_BATCH_PITCH +
313
- feature * INPUT0_FEATURE_PITCH +
314
- w_input * INPUT0_W_PITCH +
315
- z_input * INPUT0_Z_PITCH +
316
- y_input * INPUT0_Y_PITCH +
317
- x_input * INPUT0_X_PITCH ;
318
- output [input_index ] = input [input_index ];
331
+ input_feature_id * INPUT0_FEATURE_PITCH +
332
+ w_input * OUTPUT_W_PITCH +
333
+ z_input * OUTPUT_Z_PITCH +
334
+ y_input * OUTPUT_Y_PITCH +
335
+ x_input * OUTPUT_X_PITCH ;
336
+
337
+ #ifdef OUTPUT_LAYOUT_BFYX
338
+ const uint y = (uint )get_global_id (2 ) / OUTPUT_SIZE_X ;
339
+ const uint x = (uint )get_global_id (2 ) % OUTPUT_SIZE_X ;
340
+ const uint output_index = OUTPUT_GET_INDEX (batch , feature , y , x );
341
+ #elif OUTPUT_LAYOUT_BFZYX
342
+ const uint yx = (uint )get_global_id (2 ) % (OUTPUT_SIZE_X * OUTPUT_SIZE_Y );
343
+ const uint z = (uint )get_global_id (2 ) / (OUTPUT_SIZE_X * OUTPUT_SIZE_Y );
344
+ const uint y = yx / OUTPUT_SIZE_X ;
345
+ const uint x = yx % OUTPUT_SIZE_X ;
346
+ const uint output_index = OUTPUT_GET_INDEX (batch , feature , z , y , x );
347
+ #elif OUTPUT_LAYOUT_BFWZYX
348
+ const uint zyx = (uint )get_global_id (2 ) % (OUTPUT_SIZE_X * OUTPUT_SIZE_Y * OUTPUT_SIZE_Z );
349
+ const uint w = (uint )get_global_id (2 ) / (OUTPUT_SIZE_X * OUTPUT_SIZE_Y * OUTPUT_SIZE_Z );
350
+ const uint z = zyx / (OUTPUT_SIZE_X * OUTPUT_SIZE_Y );
351
+ const uint yx = zyx % (OUTPUT_SIZE_X * OUTPUT_SIZE_Y );
352
+ const uint y = yx / OUTPUT_SIZE_X ;
353
+ const uint x = yx % OUTPUT_SIZE_X ;
354
+ const uint output_index = OUTPUT_GET_INDEX (batch , feature , w , z , y , x );
355
+ #endif
356
+
357
+ output [output_index ] = input [input_index ];
358
+
319
359
#else // NEW_AXIS_MODE
320
360
#ifdef OUTPUT_LAYOUT_BFYX
321
361
const uint w = 0 ;
@@ -359,7 +399,7 @@ KERNEL(strided_slice_ref)(OPTIONAL_SHAPE_INFO_ARG
359
399
const uint input_index = INPUT0_OFFSET +
360
400
(slice_begin_batch + batch * slice_steps_batch ) * INPUT0_BATCH_PITCH +
361
401
(slice_begin_feature + feature * slice_steps_feature ) * INPUT0_FEATURE_PITCH +
362
- #if INPUT0_LAYOUT_BFWZYX
402
+ #if INPUT0_LAYOUT_BFWZYX
363
403
(slice_begin_w + w * slice_steps_w ) * INPUT0_W_PITCH +
364
404
(slice_begin_z + z * slice_steps_z ) * INPUT0_Z_PITCH +
365
405
(slice_begin_y + y * slice_steps_y ) * INPUT0_Y_PITCH +
@@ -390,4 +430,4 @@ KERNEL(strided_slice_ref)(OPTIONAL_SHAPE_INFO_ARG
390
430
output [output_index ] = ACTIVATION (input [input_index ], ACTIVATION_PARAMS );
391
431
#endif
392
432
#endif // NEW_AXIS_MODE
393
- }
433
+ }
0 commit comments