@@ -306,11 +306,33 @@ KERNEL(gemm_tiled_opt)(
306
306
#if INDIRECT_INPUT1
307
307
if (do_indirect_load )
308
308
{
309
+ #if INPUT1_SIZE_X == 128 && INPUT1_FEATURE_NUM == 32 && defined(INPUT2_TYPE ) && 0
310
+ const __global INPUT1_TYPE * b_ptr_new = input1 ;
311
+ uint b_new = beam_table [FUNC_CALL (get_bt_index )(OPTIONAL_SHAPE_INFO_TENSOR b , f , w , z , (k * TILE_K ), x )];
312
+ uint load_idx = FUNC_CALL (get_input1_index )(OPTIONAL_SHAPE_INFO_TENSOR b_new , f , w , z , (k * TILE_K ), x );
313
+ b_ptr_new += load_idx ;
314
+ b_tile = (N > b_raw_global_id ) ? VLOAD (0 , b_ptr_new ) : 0 ;
315
+ #elif INPUT1_SIZE_X == 128 && INPUT1_FEATURE_NUM == 32 && defined(INPUT2_TYPE ) && 2
316
+ const __global INPUT1_TYPE * b_ptr_new = input1 ;
317
+ unroll_for (uint tile_n_load_idx = 0 ; tile_n_load_idx < TILE_N ; tile_n_load_idx ++ ) {
318
+ if (tile_n_offset + tile_n_load_idx >= N ) {
319
+ b_tile [tile_n_load_idx ] = 0 ;
320
+ } else {
321
+ // uint b_new = beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, (k * TILE_K), tile_n_offset + tile_n_load_idx)];
322
+ // uint load_idx = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b_new, f, w, z, (k * TILE_K), tile_n_offset + tile_n_load_idx);
323
+ uint load_idx = FUNC_CALL (get_input1_indirect_index )(OPTIONAL_SHAPE_INFO_TENSOR b , f , w , z , (k * TILE_K ) + sglid , tile_n_offset + tile_n_load_idx , beam_table );
324
+ // b_tile[tile_n_load_idx] = BLOCK_READ_B(b_ptr_new + load_idx, 0);
325
+ b_tile [tile_n_load_idx ] = b_ptr_new [load_idx ];
326
+ // b_tile[tile_n_load_idx] = b_ptr_new[load_idx + sglid];
327
+ }
328
+ }
329
+ #else
309
330
unroll_for (uint b_load_id = 0 ; b_load_id < TILE_K ; b_load_id ++ ) {
310
331
uint b_load_offset = (k * TILE_K ) + b_load_id ;
311
332
uint b_idx = FUNC_CALL (get_input1_indirect_index )(OPTIONAL_SHAPE_INFO_TENSOR b , f , w , z , b_load_offset , x , beam_table );
312
333
b_tile [b_load_id ] = b_raw_global_id >= N ? 0 : input1 [b_idx ];
313
334
}
335
+ #endif
314
336
}
315
337
else
316
338
#endif
@@ -354,7 +376,14 @@ KERNEL(gemm_tiled_opt)(
354
376
c_tile [dot_id ] = mad ((INPUT0_TYPE )(sub_group_broadcast (a_read [subtile_k_id ], simd_local_id )),
355
377
b_tile [subtile_k_id * SIMD_WIDTH + simd_local_id ], c_tile [dot_id ]);
356
378
#else // TILE_K > SIMD_WIDTH
379
+ #if INPUT1_SIZE_X == 128 && INPUT1_FEATURE_NUM == 32 && defined(INPUT2_TYPE ) && 2
380
+ INPUT0_TYPE tmp = a_read * b_tile [simd_local_id ];
381
+ INPUT0_TYPE res = sub_group_reduce_add (tmp );
382
+ if (sglid == simd_local_id )
383
+ c_tile [dot_id ] = res + c_tile [dot_id ];
384
+ #else
357
385
c_tile [dot_id ] = mad ((INPUT0_TYPE )(sub_group_broadcast (a_read , simd_local_id )), b_tile [simd_local_id ], c_tile [dot_id ]);
386
+ #endif
358
387
#endif // TILE_K > SIMD_WIDTH
359
388
}
360
389
}
0 commit comments