@@ -412,16 +412,12 @@ KERNEL(gemm_tiled_opt)(
412
412
c_tile [dot_id ] = mad ((INPUT0_TYPE )(sub_group_broadcast (a_read [subtile_k_id ], simd_local_id )),
413
413
b_tile [subtile_k_id * SIMD_WIDTH + simd_local_id ], c_tile [dot_id ]);
414
414
#else // TILE_K > SIMD_WIDTH
415
- #if IS_DYNAMIC && B_VEC_SIZE > 1
416
- #if TRANSPOSE_INPUT1 == TRANSPOSE_Y_LAST
415
+ #if B_VEC_SIZE > 1 && TRANSPOSE_INPUT1 == TRANSPOSE_Y_LAST
417
416
MAKE_VECTOR_TYPE (INPUT1_TYPE , B_VEC_SIZE ) b_tile_tmp ;
418
417
unroll_for (uint b_elem = 0 ; b_elem < B_VEC_SIZE ; ++ b_elem ) {
419
418
b_tile_tmp [b_elem ] = b_tile [b_elem ][simd_local_id ];
420
419
}
421
420
c_tile [dot_id ] = mad ((INPUT0_TYPE )(sub_group_broadcast (a_read , simd_local_id )), b_tile_tmp , c_tile [dot_id ]);
422
- #else
423
- c_tile [dot_id ] = mad ((INPUT0_TYPE )(sub_group_broadcast (a_read , simd_local_id )), b_tile [simd_local_id ], c_tile [dot_id ]);
424
- #endif
425
421
#else
426
422
c_tile [dot_id ] = mad ((INPUT0_TYPE )(sub_group_broadcast (a_read , simd_local_id )), b_tile [simd_local_id ], c_tile [dot_id ]);
427
423
#endif
@@ -464,7 +460,15 @@ KERNEL(gemm_tiled_opt)(
464
460
// Tile C calculation for TN, TT cases
465
461
unroll_for (uint dot_id = 0 ; dot_id < tile_m_iterations ; dot_id ++ ) {
466
462
unroll_for (uint simd_local_id = 0 ; simd_local_id < SIMD_WIDTH ; simd_local_id ++ ) {
467
- c_tile [dot_id ] = mad ((INPUT0_TYPE )(sub_group_broadcast (a_tile [dot_id ], simd_local_id )), b_tile [simd_local_id ], c_tile [dot_id ]);
463
+ #if B_VEC_SIZE > 1 && TRANSPOSE_INPUT1 == TRANSPOSE_Y_LAST
464
+ MAKE_VECTOR_TYPE (INPUT1_TYPE , B_VEC_SIZE ) b_tile_tmp ;
465
+ unroll_for (uint b_elem = 0 ; b_elem < B_VEC_SIZE ; ++ b_elem ) {
466
+ b_tile_tmp [b_elem ] = b_tile [b_elem ][simd_local_id ];
467
+ }
468
+ c_tile [dot_id ] = mad ((INPUT0_TYPE )(sub_group_broadcast (a_tile [dot_id ], simd_local_id )), b_tile_tmp , c_tile [dot_id ]);
469
+ #else
470
+ c_tile [dot_id ] = mad ((INPUT0_TYPE )(sub_group_broadcast (a_tile [dot_id ], simd_local_id )), b_tile [simd_local_id ], c_tile [dot_id ]);
471
+ #endif
468
472
}
469
473
} // Tile C calculation for TN, TT cases end
470
474
#endif // !TRANSPOSE_INPUT0
0 commit comments