@@ -224,17 +224,20 @@ KERNEL(pa_sdpa_opt)(
224
224
#define KEY_BLOCK_UNCOMPRESSED MAKE_VECTOR_TYPE(INPUT0_TYPE, KEY_VEC_SIZE)
225
225
#define TO_KEY_BLOCK_UNCOMPRESSED_TYPE (val ) CAT(convert_, KEY_BLOCK_UNCOMPRESSED)(val)
226
226
227
- KEY_BLOCK k_vals_packed = 0 ;
227
+ #if IS_KV_COMPRESSED
228
+ KEY_BLOCK_UNCOMPRESSED k_vals ;
228
229
unroll_for (uint i = 0 ; i < KEY_VEC_SIZE ; i ++ ) {
229
- k_vals_packed [i ] = BLOCK_READN (INPUT1_TYPE , 1 , key_cache , block_offset + qk_idx * SUBGROUP_SIZE * KEY_VEC_SIZE + i * SUBGROUP_SIZE );
230
+ k_vals [i ] = BLOCK_READN (INPUT1_TYPE , 1 , key_cache , block_offset + qk_idx * SUBGROUP_SIZE * KEY_VEC_SIZE + i * SUBGROUP_SIZE );
231
+ k_vals [i ] = (k_vals [i ] - comp_zp ) * comp_scale ;
230
232
}
231
-
232
- #if IS_KV_COMPRESSED
233
- KEY_BLOCK_UNCOMPRESSED k_vals = (TO_KEY_BLOCK_UNCOMPRESSED_TYPE (k_vals_packed ) - comp_zp ) * comp_scale ;
234
233
#else
235
- KEY_BLOCK k_vals = k_vals_packed ;
234
+ KEY_BLOCK k_vals = 0 ;
235
+ unroll_for (uint i = 0 ; i < KEY_VEC_SIZE ; i ++ ) {
236
+ k_vals [i ] = BLOCK_READN (INPUT1_TYPE , 1 , key_cache , block_offset + qk_idx * SUBGROUP_SIZE * KEY_VEC_SIZE + i * SUBGROUP_SIZE );
237
+ }
236
238
#endif
237
239
240
+ #if XE2_QK_MULTIPLICATION
238
241
#if STORE_QUERY_TO_SLM
239
242
MAKE_VECTOR_TYPE (INPUT0_TYPE , QUERIES_PER_WI ) q_val ;
240
243
unroll_for (uint q_idx = 0 ; q_idx < QUERIES_PER_WI ; q_idx ++ ) {
@@ -249,6 +252,20 @@ KERNEL(pa_sdpa_opt)(
249
252
qk_acc = mad (TO_SOFTMAX_ACCUMULATOR_TYPE (sub_group_broadcast (q_val [qk_idx ], i )), TO_SOFTMAX_ACCUMULATOR_TYPE (k_vals [i ]), qk_acc );
250
253
#endif
251
254
}
255
+ #else // !XE2_QK_MULTIPLICATION
256
+ unroll_for (uint q_idx = 0 ; q_idx < QUERIES_PER_WI ; q_idx ++ ) {
257
+ #if STORE_QUERY_TO_SLM
258
+ SOFTMAX_ACCUMULATOR_TYPE q_val = slm_query [q_idx * HEAD_SIZE + qk_idx * KEY_VEC_SIZE + sglid ];
259
+ #endif
260
+ unroll_for (uint i = 0 ; i < KEY_VEC_SIZE ; i ++ ) {
261
+ #if STORE_QUERY_TO_SLM
262
+ GET_VECTOR_ELEMENT (qk_acc , q_idx ) = mad (sub_group_broadcast (q_val , i ), TO_SOFTMAX_ACCUMULATOR_TYPE (k_vals [i ]), GET_VECTOR_ELEMENT (qk_acc , q_idx ));
263
+ #else
264
+ qk_acc = mad (TO_SOFTMAX_ACCUMULATOR_TYPE (sub_group_broadcast (q_val [qk_idx ], i )), TO_SOFTMAX_ACCUMULATOR_TYPE (k_vals [i ]), qk_acc );
265
+ #endif
266
+ }
267
+ }
268
+ #endif // XE2_QK_MULTIPLICATION
252
269
}
253
270
254
271
const uint token_idx = partition_idx * SEQ_LEN_PARTITION_SIZE + block_num * SUBGROUPS_PER_WG * SUBGROUP_SIZE + sgid * SUBGROUP_SIZE + sglid ;
0 commit comments