Skip to content

Commit 55ea9f8

Browse files
committed
Allow arbitrary BLOCK_SIZE
1 parent 8e764a2 commit 55ea9f8

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_ref.cl

+4-3
Original file line numberDiff line numberDiff line change
@@ -249,14 +249,15 @@ KERNEL(pa_sdpa_ref)(
249249

250250
OUTPUT_TYPE qk = qk_offset < context_len ? qk_vals[qk_offset] : OUTPUT_VAL_ZERO;
251251

252-
const uint block_idx = block_tables[batch_idx * blocks_num + (qk_idx / SUB_GROUP_SIZE)];
252+
const uint block_idx = block_tables[batch_idx * blocks_num + (qk_idx / BLOCK_SIZE)];
253253
if (block_idx == 0)
254254
continue;
255255

256256
const uint value_cache_offset = block_idx * KV_CACHE_BLOCK_STRIDE +
257257
(head_num_idx / NUM_QUERIES_PER_KV_HEAD) * (HEAD_SIZE * BLOCK_SIZE) +
258258
sgid * (SUB_GROUP_SIZE * BLOCK_SIZE) +
259-
sglid * BLOCK_SIZE;
259+
sglid * BLOCK_SIZE +
260+
((qk_idx / SUB_GROUP_SIZE) % (BLOCK_SIZE / SUB_GROUP_SIZE)) * SUB_GROUP_SIZE;
260261

261262
#define VALUE_VEC_TYPE MAKE_VECTOR_TYPE(OUTPUT_TYPE, BLOCK_SIZE)
262263
#define VALUE_VLOAD(offset, ptr) CAT(vload, BLOCK_SIZE)(offset, ptr)
@@ -271,7 +272,7 @@ KERNEL(pa_sdpa_ref)(
271272
// seq_idx, head_num_idx, sgid, sglid, block_idx, qk_idx, qk_offset, value_cache_offset - (block_idx * KV_CACHE_BLOCK_STRIDE), block_idx * KV_CACHE_BLOCK_STRIDE, *tmp_print);
272273
// }
273274

274-
for (uint token = 0; token < BLOCK_SIZE; token++) {
275+
for (uint token = 0; token < SUB_GROUP_SIZE; token++) {
275276
OUTPUT_TYPE qk_tmp = sub_group_broadcast(qk, token);
276277
if (qk_idx + token < context_len) {
277278
acc = mad(qk_tmp, v[token], acc);

0 commit comments

Comments
 (0)