Skip to content

Commit 99fde76

Browse files
committed
Update Key decompression logic to prevent register spills on XE2 and slightly modify the QK multiplication with respect to the platfrom
1 parent aa6ff61 commit 99fde76

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl

+23-6
Original file line numberDiff line numberDiff line change
@@ -224,17 +224,20 @@ KERNEL(pa_sdpa_opt)(
224224
#define KEY_BLOCK_UNCOMPRESSED MAKE_VECTOR_TYPE(INPUT0_TYPE, KEY_VEC_SIZE)
225225
#define TO_KEY_BLOCK_UNCOMPRESSED_TYPE(val) CAT(convert_, KEY_BLOCK_UNCOMPRESSED)(val)
226226

227-
KEY_BLOCK k_vals_packed = 0;
227+
#if IS_KV_COMPRESSED
228+
KEY_BLOCK_UNCOMPRESSED k_vals;
228229
unroll_for (uint i = 0; i < KEY_VEC_SIZE; i++) {
229-
k_vals_packed[i] = BLOCK_READN(INPUT1_TYPE, 1, key_cache, block_offset + qk_idx * SUBGROUP_SIZE * KEY_VEC_SIZE + i * SUBGROUP_SIZE);
230+
k_vals[i] = BLOCK_READN(INPUT1_TYPE, 1, key_cache, block_offset + qk_idx * SUBGROUP_SIZE * KEY_VEC_SIZE + i * SUBGROUP_SIZE);
231+
k_vals[i] = (k_vals[i] - comp_zp) * comp_scale;
230232
}
231-
232-
#if IS_KV_COMPRESSED
233-
KEY_BLOCK_UNCOMPRESSED k_vals = (TO_KEY_BLOCK_UNCOMPRESSED_TYPE(k_vals_packed) - comp_zp) * comp_scale;
234233
#else
235-
KEY_BLOCK k_vals = k_vals_packed;
234+
KEY_BLOCK k_vals = 0;
235+
unroll_for (uint i = 0; i < KEY_VEC_SIZE; i++) {
236+
k_vals[i] = BLOCK_READN(INPUT1_TYPE, 1, key_cache, block_offset + qk_idx * SUBGROUP_SIZE * KEY_VEC_SIZE + i * SUBGROUP_SIZE);
237+
}
236238
#endif
237239

240+
#if XE2_QK_MULTIPLICATION
238241
#if STORE_QUERY_TO_SLM
239242
MAKE_VECTOR_TYPE(INPUT0_TYPE, QUERIES_PER_WI) q_val;
240243
unroll_for (uint q_idx = 0; q_idx < QUERIES_PER_WI; q_idx++) {
@@ -249,6 +252,20 @@ KERNEL(pa_sdpa_opt)(
249252
qk_acc = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(q_val[qk_idx], i)), TO_SOFTMAX_ACCUMULATOR_TYPE(k_vals[i]), qk_acc);
250253
#endif
251254
}
255+
#else // !XE2_QK_MULTIPLICATION
256+
unroll_for (uint q_idx = 0; q_idx < QUERIES_PER_WI; q_idx++) {
257+
#if STORE_QUERY_TO_SLM
258+
SOFTMAX_ACCUMULATOR_TYPE q_val = slm_query[q_idx * HEAD_SIZE + qk_idx * KEY_VEC_SIZE + sglid];
259+
#endif
260+
unroll_for (uint i = 0; i < KEY_VEC_SIZE; i++) {
261+
#if STORE_QUERY_TO_SLM
262+
GET_VECTOR_ELEMENT(qk_acc, q_idx) = mad(sub_group_broadcast(q_val, i), TO_SOFTMAX_ACCUMULATOR_TYPE(k_vals[i]), GET_VECTOR_ELEMENT(qk_acc, q_idx));
263+
#else
264+
qk_acc = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(q_val[qk_idx], i)), TO_SOFTMAX_ACCUMULATOR_TYPE(k_vals[i]), qk_acc);
265+
#endif
266+
}
267+
}
268+
#endif // XE2_QK_MULTIPLICATION
252269
}
253270

254271
const uint token_idx = partition_idx * SEQ_LEN_PARTITION_SIZE + block_num * SUBGROUPS_PER_WG * SUBGROUP_SIZE + sgid * SUBGROUP_SIZE + sglid;

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
247247
jit.AddConstant(MakeJitConstant("SLIDING_WINDOW_SIZE", config.paged_attention_sliding_window));
248248
jit.AddConstant(MakeJitConstant("IS_KV_COMPRESSED", params.conf.is_kv_compressed));
249249
jit.AddConstant(MakeJitConstant("SG_SCALE_FACTOR", get_sg_number_scale_factor(params, config.head_size, kernel_idx)));
250+
jit.AddConstant(MakeJitConstant("XE2_QK_MULTIPLICATION", params.engineInfo.arch == gpu_arch::xe2));
250251

251252
if (params.conf.is_kv_compressed) {
252253
auto scales_zp_size = params.inputs[0].ElementSize() * 2; // scale + zp

0 commit comments

Comments
 (0)