Commit 8451b0e

DRAFT: [GPU] Optimize PA GQA second token
1 parent 9fa105e commit 8451b0e

5 files changed: +191 −68

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp (+4 −0)

@@ -1009,6 +1009,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         impl->use_micro_sdpa = true;
     }
 
+    std::cout << "use_micro=" << impl->use_micro_sdpa << " KV-cache layouts=["
+              << impl_param.get_input_layout(3).to_short_string() << ", "
+              << impl_param.get_input_layout(4).to_short_string() << "]\n";
+
     return impl;
 }
src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl (+146 −66)
@@ -24,6 +24,11 @@
 #error pa_sdpa_opt.cl
 #endif
 
+#if HEADS_PER_REQUEST > 1
+#define STORE_QUERY_TO_SLM 1
+#define TO_SOFTMAX_ACCUMULATOR_TYPE_VEC CAT(convert_float, HEADS_PER_REQUEST)
+#endif
+
 REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
 __attribute__((reqd_work_group_size(1, 1, HEAD_SIZE * SG_SCALE_FACTOR)))
 KERNEL(pa_sdpa_opt)(
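The block above enables SLM query storage and defines a vectorized float conversion whenever one work group handles several query heads: CAT(convert_float, HEADS_PER_REQUEST) expands to the standard OpenCL built-in, e.g. convert_float4 for HEADS_PER_REQUEST=4. The host-side derivation of HEADS_PER_REQUEST, REQUESTS_PER_KV_HEAD, and HEADS_PROCESSING_LEFTOVER is not part of the visible diff; the following is a minimal sketch, assuming BROADCAST_GROUP_SIZE is the GQA group size (query heads per KV head) and that the heads handled per work group are capped:

    // Hypothetical host-side JIT-constant derivation; names mirror the kernel
    // macros, but this function is not part of the commit.
    #include <algorithm>
    #include <cstdint>

    struct GqaJitConstants {
        uint32_t heads_per_request;         // query heads handled by one work group
        uint32_t requests_per_kv_head;      // work groups per GQA broadcast group
        uint32_t heads_processing_leftover; // unused head slots in the last chunk
    };

    GqaJitConstants make_gqa_constants(uint32_t heads_num, uint32_t kv_heads_num,
                                       uint32_t max_heads_per_request) {
        const uint32_t group = heads_num / kv_heads_num;       // BROADCAST_GROUP_SIZE
        const uint32_t per_req = std::min(group, max_heads_per_request);
        const uint32_t reqs = (group + per_req - 1) / per_req; // ceil-div
        return {per_req, reqs, reqs * per_req - group};
    }

For example, make_gqa_constants(28, 4, 4) yields {4, 2, 1}, the values used in the worked example after the next hunk.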
@@ -74,7 +79,18 @@ KERNEL(pa_sdpa_opt)(
     // tmp_out: [sequences_num, HEADS_NUM, total_partitions_num, HEAD_SIZE]
 
     const uint seq_idx = get_global_id(0);
+#if HEADS_PER_REQUEST > 1
+    const uint heads_group_idx = get_global_id(1);
+    const uint head_num_idx = heads_group_idx * HEADS_PER_REQUEST - ((heads_group_idx / REQUESTS_PER_KV_HEAD) * HEADS_PROCESSING_LEFTOVER);
+    const uint iter_heads_num = min(BROADCAST_GROUP_SIZE - ((heads_group_idx % REQUESTS_PER_KV_HEAD) * HEADS_PER_REQUEST), (uint)HEADS_PER_REQUEST);
+
+    // if (get_global_id(0) == 0 && get_global_id(2) == 0) {
+    //     printf("id=%d, head_num_idx=%d, iter_heads_num=%d\n", heads_group_idx, head_num_idx, iter_heads_num);
+    // }
+
+#else
     const uint head_num_idx = get_global_id(1);
+#endif
     const uint sglid = get_sub_group_local_id();
     const uint sgid = get_sub_group_id();
     const uint total_partitions_num = get_num_groups(2);
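The index arithmetic above splits each GQA broadcast group into REQUESTS_PER_KV_HEAD chunks of at most HEADS_PER_REQUEST query heads and shrinks the last chunk when the group size is not an exact multiple. A standalone check, assuming HEADS_NUM=28 and KV_HEADS_NUM=4 (so BROADCAST_GROUP_SIZE=7, HEADS_PER_REQUEST=4, REQUESTS_PER_KV_HEAD=2, HEADS_PROCESSING_LEFTOVER=1), shows the chunks tile all 28 heads contiguously without overlap:

    // Standalone model of the head-group indexing; all constants are assumptions.
    #include <algorithm>
    #include <cstdio>

    int main() {
        const unsigned BROADCAST_GROUP_SIZE = 7, HEADS_PER_REQUEST = 4;
        const unsigned REQUESTS_PER_KV_HEAD = 2, HEADS_PROCESSING_LEFTOVER = 1;
        const unsigned KV_HEADS_NUM = 4;
        for (unsigned g = 0; g < KV_HEADS_NUM * REQUESTS_PER_KV_HEAD; ++g) {
            const unsigned head_num_idx =
                g * HEADS_PER_REQUEST - (g / REQUESTS_PER_KV_HEAD) * HEADS_PROCESSING_LEFTOVER;
            const unsigned iter_heads_num = std::min(
                BROADCAST_GROUP_SIZE - (g % REQUESTS_PER_KV_HEAD) * HEADS_PER_REQUEST,
                HEADS_PER_REQUEST);
            std::printf("group %u -> heads [%u..%u)\n",
                        g, head_num_idx, head_num_idx + iter_heads_num);
        }
        // Prints [0..4) [4..7) [7..11) [11..14) [14..18) [18..21) [21..25) [25..28)
    }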
@@ -110,36 +126,45 @@ KERNEL(pa_sdpa_opt)(
 
 #ifdef STORE_QUERY_TO_SLM
     // SLM buffer for query inputs
-    __local INPUT0_TYPE slm_query[HEAD_SIZE];
+    __local INPUT0_TYPE slm_query[HEAD_SIZE * HEADS_PER_REQUEST];
 #endif
 
     // SLM for intermediate QK results
-    __local SOFTMAX_ACCUMULATOR_TYPE slm_qk_vals[SEQ_LEN_PARTITION_SIZE];
+    __local SOFTMAX_ACCUMULATOR_TYPE slm_qk_vals[SEQ_LEN_PARTITION_SIZE * HEADS_PER_REQUEST];
 
     // SLM buffers for SoftMax calculation and qk_max/qk_sums results aggregation across all WGs
-    __local SOFTMAX_ACCUMULATOR_TYPE slm_qk_max_vals[SUBGROUPS_PER_WG];
-    __local SOFTMAX_ACCUMULATOR_TYPE slm_exp_sum_vals[SUBGROUPS_PER_WG];
+    __local SOFTMAX_ACCUMULATOR_TYPE slm_qk_max_vals[SUBGROUPS_PER_WG * HEADS_PER_REQUEST];
+    __local SOFTMAX_ACCUMULATOR_TYPE slm_exp_sum_vals[SUBGROUPS_PER_WG * HEADS_PER_REQUEST];
 
-    SOFTMAX_ACCUMULATOR_TYPE qk_max = SOFTMAX_ACCUMULATOR_VAL_MIN;
+    MAKE_VECTOR_TYPE(SOFTMAX_ACCUMULATOR_TYPE, HEADS_PER_REQUEST) qk_max = SOFTMAX_ACCUMULATOR_VAL_MIN;
 
     {
 #if STORE_QUERY_TO_SLM
-        const uint query_idx_local = sgid * SUBGROUP_SIZE + sglid;
-        const uint query_idx = INPUT0_OFFSET +
-                               seq_idx * (HEAD_SIZE * HEADS_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM) +
-                               head_num_idx * HEAD_SIZE +
-                               query_idx_local;
-
-        INPUT0_TYPE q_val = BLOCK_READN(INPUT0_TYPE, 1, query, query_idx);
+        for (uint idx = sgid * SUBGROUP_SIZE; idx < HEADS_PER_REQUEST * HEAD_SIZE; idx += SUBGROUP_SIZE) {
+            const uint query_idx_local = idx % HEAD_SIZE + sglid;
+            const uint head_idx = idx / HEAD_SIZE;
 
-        // Apply scale value directly to the query input to improve accuracy in case of a high range of input data
-#ifdef SCALE_VAL
-        q_val = TO_INPUT0_TYPE(SCALE_VAL) * q_val;
-#else
-        q_val = *scale * q_val;
+#if HEADS_PROCESSING_LEFTOVER > 0
+            // Do not load more than needed
+            if (head_idx >= iter_heads_num)
+                break;
 #endif
 
-        slm_query[query_idx_local] = q_val;
+            const uint query_idx = INPUT0_OFFSET +
+                                   seq_idx * (HEAD_SIZE * HEADS_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM) +
+                                   (head_num_idx + head_idx) * HEAD_SIZE +
+                                   query_idx_local;
+
+            INPUT0_TYPE q_val = BLOCK_READN(INPUT0_TYPE, 1, query, query_idx);
+            // Apply scale value directly to the query input to improve accuracy in case of a high range of input data
+#ifdef SCALE_VAL
+            q_val = TO_INPUT0_TYPE(SCALE_VAL) * q_val;
+#else
+            q_val = *scale * q_val;
+#endif
+
+            slm_query[head_idx * HEAD_SIZE + query_idx_local] = q_val;
+        }
 
         barrier(CLK_LOCAL_MEM_FENCE);
 #else
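All four SLM buffers now scale with HEADS_PER_REQUEST, so the local-memory footprint grows linearly with the number of grouped heads. A back-of-the-envelope budget, assuming HEAD_SIZE=128, SEQ_LEN_PARTITION_SIZE=256, SUBGROUPS_PER_WG=8, half-precision queries, and float accumulators with HEADS_PER_REQUEST=4, stays far below a typical 64 KB per-work-group limit:

    // Rough SLM budget for the resized buffers; every constant here is an assumption.
    #include <cstddef>
    #include <cstdio>

    int main() {
        const std::size_t HEAD_SIZE = 128, SEQ_LEN_PARTITION_SIZE = 256;
        const std::size_t SUBGROUPS_PER_WG = 8, HEADS_PER_REQUEST = 4;
        const std::size_t bytes =
            HEAD_SIZE * HEADS_PER_REQUEST * 2 +              // slm_query (half)
            SEQ_LEN_PARTITION_SIZE * HEADS_PER_REQUEST * 4 + // slm_qk_vals (float)
            2 * SUBGROUPS_PER_WG * HEADS_PER_REQUEST * 4;    // slm_qk_max_vals + slm_exp_sum_vals
        std::printf("%zu bytes of SLM per work group\n", bytes); // 5376
    }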
@@ -175,7 +200,7 @@ KERNEL(pa_sdpa_opt)(
 #endif
             const uint block_offset = block_indices[start_block_idx + block_num * SUBGROUPS_PER_WG] * ADJUSTED_HEAD_SIZE * KV_HEADS_NUM * SUBGROUP_SIZE + head_idx * ADJUSTED_HEAD_SIZE * SUBGROUP_SIZE;
 
-            SOFTMAX_ACCUMULATOR_TYPE qk_acc = SOFTMAX_ACCUMULATOR_VAL_ZERO;
+            MAKE_VECTOR_TYPE(SOFTMAX_ACCUMULATOR_TYPE, HEADS_PER_REQUEST) qk_acc = SOFTMAX_ACCUMULATOR_VAL_ZERO;
 
 #define KEY_VEC_SIZE SUBGROUP_SIZE
 #define KEY_BLOCK MAKE_VECTOR_TYPE(INPUT1_TYPE, KEY_VEC_SIZE)
@@ -202,12 +227,15 @@ KERNEL(pa_sdpa_opt)(
 #endif
 
 #if STORE_QUERY_TO_SLM
-                INPUT0_TYPE q_val = slm_query[qk_idx * KEY_VEC_SIZE + sglid];
+                MAKE_VECTOR_TYPE(INPUT0_TYPE, HEADS_PER_REQUEST) q_val;
+                unroll_for (uint i = 0; i < HEADS_PER_REQUEST; i++) {
+                    q_val[i] = slm_query[i * HEAD_SIZE + qk_idx * KEY_VEC_SIZE + sglid];
+                }
 #endif
 
                 unroll_for (uint i = 0; i < KEY_VEC_SIZE; i++) {
 #if STORE_QUERY_TO_SLM
-                    qk_acc = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(q_val, i)), TO_SOFTMAX_ACCUMULATOR_TYPE(k_vals[i]), qk_acc);
+                    qk_acc = mad(TO_SOFTMAX_ACCUMULATOR_TYPE_VEC(sub_group_broadcast(q_val, i)), TO_SOFTMAX_ACCUMULATOR_TYPE(k_vals[i]), qk_acc);
 #else
                     qk_acc = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(q_val[qk_idx], i)), TO_SOFTMAX_ACCUMULATOR_TYPE(k_vals[i]), qk_acc);
 #endif
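This inner loop is where the grouped-heads layout pays off: each key element fetched from the key cache is multiplied against every query head in the group, instead of the cache being re-read once per head as in the scalar version. A plain C++ model of the access pattern (a sketch of the idea, not the kernel itself; sizes are assumptions):

    // Scalar model of the vectorized QK step: one key load is amortized over
    // all grouped query heads. q holds HEADS_PER_REQUEST query rows; k is one
    // token's key vector.
    #include <array>

    constexpr int HEADS_PER_REQUEST = 4;
    constexpr int HEAD_SIZE = 128;

    std::array<float, HEADS_PER_REQUEST> qk_one_token(
            const std::array<std::array<float, HEAD_SIZE>, HEADS_PER_REQUEST>& q,
            const std::array<float, HEAD_SIZE>& k) {
        std::array<float, HEADS_PER_REQUEST> acc{}; // qk_acc: one lane per head
        for (int d = 0; d < HEAD_SIZE; ++d)         // one k[d] read ...
            for (int h = 0; h < HEADS_PER_REQUEST; ++h)
                acc[h] += q[h][d] * k[d];           // ... reused by every head
        return acc;
    }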
@@ -218,7 +246,7 @@ KERNEL(pa_sdpa_opt)(
 
 #ifdef HAS_ALIBI
                 const int alibi_val = (1 - seq_len) + token_idx;
-                qk_acc += alibi_slopes[head_num_idx] * alibi_val;
+                qk_acc += alibi_slopes[head_num_idx] * alibi_val; // TODO: UPDATE THIS
 #endif
 
 #if SLIDING_WINDOW_SIZE != 0
@@ -228,30 +256,41 @@ KERNEL(pa_sdpa_opt)(
 #endif
                     qk_acc = SOFTMAX_ACCUMULATOR_VAL_MIN;
 
-                qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max, TO_SOFTMAX_ACCUMULATOR_TYPE(qk_acc));
+                qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max, TO_SOFTMAX_ACCUMULATOR_TYPE_VEC(qk_acc));
 
-                slm_qk_vals[block_num * SUBGROUPS_PER_WG * SUBGROUP_SIZE + sgid * SUBGROUP_SIZE + sglid] = qk_acc;
+                unroll_for (uint i = 0; i < HEADS_PER_REQUEST; i++) {
+                    slm_qk_vals[i * SEQ_LEN_PARTITION_SIZE + block_num * SUBGROUPS_PER_WG * SUBGROUP_SIZE + sgid * SUBGROUP_SIZE + sglid] = qk_acc[i];
+                }
             }
 
-        qk_max = sub_group_reduce_max(qk_max);
+        unroll_for (uint i = 0; i < HEADS_PER_REQUEST; i++) {
+            qk_max[i] = sub_group_reduce_max(qk_max[i]);
+        }
     }
 
     {
         // SoftMax calculation
         if (sglid == 0) {
-            slm_qk_max_vals[sgid] = qk_max;
+            unroll_for (uint i = 0; i < HEADS_PER_REQUEST; i++) {
+                slm_qk_max_vals[i * SUBGROUPS_PER_WG + sgid] = qk_max[i];
+            }
        }
 
        barrier(CLK_LOCAL_MEM_FENCE);
 
        qk_max = SOFTMAX_ACCUMULATOR_VAL_MIN;
-        if (sglid < SUBGROUPS_PER_WG)
-            qk_max = slm_qk_max_vals[sglid];
+        if (sglid < SUBGROUPS_PER_WG) {
+            unroll_for (uint i = 0; i < HEADS_PER_REQUEST; i++) {
+                qk_max[i] = slm_qk_max_vals[i * SUBGROUPS_PER_WG + sglid];
+            }
+        }
 
        // Final max value after reduction across of all SG and WI
-        qk_max = sub_group_reduce_max(qk_max);
+        unroll_for (uint i = 0; i < HEADS_PER_REQUEST; i++) {
+            qk_max[i] = sub_group_reduce_max(qk_max[i]);
+        }
 
-        SOFTMAX_ACCUMULATOR_TYPE exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
+        MAKE_VECTOR_TYPE(SOFTMAX_ACCUMULATOR_TYPE, HEADS_PER_REQUEST) exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
 
        const uint qk_iters_num = CEIL_DIV(SEQ_LEN_PARTITION_SIZE, SUBGROUPS_PER_WG * SUBGROUP_SIZE);
        for (uint qk_idx = 0; qk_idx < qk_iters_num; qk_idx++) {
@@ -264,27 +303,38 @@ KERNEL(pa_sdpa_opt)(
 #else
             if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE) {
 #endif
-                SOFTMAX_ACCUMULATOR_TYPE qk_new = native_exp(TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[local_data_idx]) - qk_max);
-                slm_qk_vals[local_data_idx] = qk_new;
-
-                exp_sum += qk_new;
+                unroll_for (uint i = 0; i < HEADS_PER_REQUEST; i++) {
+                    SOFTMAX_ACCUMULATOR_TYPE qk_new = native_exp(TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[i * SEQ_LEN_PARTITION_SIZE + local_data_idx]) - qk_max[i]);
+                    slm_qk_vals[i * SEQ_LEN_PARTITION_SIZE + local_data_idx] = qk_new;
+                    exp_sum[i] += qk_new;
+                }
             }
         }
 
-        exp_sum = sub_group_reduce_add(exp_sum);
+        unroll_for (uint i = 0; i < HEADS_PER_REQUEST; i++) {
+            exp_sum[i] = sub_group_reduce_add(exp_sum[i]);
+        }
 
-        if (sglid == 0)
-            slm_exp_sum_vals[sgid] = exp_sum;
+        if (sglid == 0) {
+            unroll_for (uint i = 0; i < HEADS_PER_REQUEST; i++) {
+                slm_exp_sum_vals[i * SUBGROUPS_PER_WG + sgid] = exp_sum[i];
+            }
+        }
 
        barrier(CLK_LOCAL_MEM_FENCE);
 
        exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
 
-        if (sglid < SUBGROUPS_PER_WG)
-            exp_sum = slm_exp_sum_vals[sglid];
+        if (sglid < SUBGROUPS_PER_WG) {
+            unroll_for (uint i = 0; i < HEADS_PER_REQUEST; i++) {
+                exp_sum[i] = slm_exp_sum_vals[i * SUBGROUPS_PER_WG + sglid];
+            }
+        }
 
        // Final sum of all exp_sum values
-        exp_sum = sub_group_reduce_add(exp_sum);
+        unroll_for (uint i = 0; i < HEADS_PER_REQUEST; i++) {
+            exp_sum[i] = sub_group_reduce_add(exp_sum[i]);
+        }
 
        for (uint qk_idx = 0; qk_idx < qk_iters_num; qk_idx++) {
            const uint local_data_idx = qk_idx * (SUBGROUPS_PER_WG * SUBGROUP_SIZE) + sgid * SUBGROUP_SIZE + sglid;
@@ -295,8 +345,10 @@ KERNEL(pa_sdpa_opt)(
 #else
             if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE) {
 #endif
-                SOFTMAX_ACCUMULATOR_TYPE qk_new = TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[local_data_idx]) / exp_sum;
-                slm_qk_vals[local_data_idx] = qk_new;
+                unroll_for (uint i = 0; i < HEADS_PER_REQUEST; i++) {
+                    SOFTMAX_ACCUMULATOR_TYPE qk_new = TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[i * SEQ_LEN_PARTITION_SIZE + local_data_idx]) / exp_sum[i];
+                    slm_qk_vals[i * SEQ_LEN_PARTITION_SIZE + local_data_idx] = qk_new;
+                }
             }
         }
 
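The three passes above (max reduction, exponentiation with summation, final division) form an ordinary numerically stable softmax, now carried out independently in each vector lane, one lane per grouped head. A scalar C++ model of what one work group computes per partition (a stand-in for the subgroup/SLM reductions, not the kernel itself):

    // Per-head softmax over one partition's QK row, mirroring the vector lanes.
    #include <algorithm>
    #include <cmath>
    #include <vector>

    void softmax_per_head(std::vector<std::vector<float>>& qk /* [heads][tokens] */) {
        for (auto& row : qk) {                  // each row = one head's lane
            float qk_max = -INFINITY;
            for (float v : row) qk_max = std::max(qk_max, v);
            float exp_sum = 0.f;
            for (float& v : row) { v = std::exp(v - qk_max); exp_sum += v; }
            for (float& v : row) v /= exp_sum;  // matches the final division pass
        }
    }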
@@ -305,13 +357,19 @@ KERNEL(pa_sdpa_opt)(
     {
         // Save temporary exm_sums and max_logits values for each partition_num
         if (seq_len > SEQ_LEN_PARTITION_SIZE && sgid == 0) {
-            const uint exp_sums_offset = seq_idx * HEADS_NUM * total_partitions_num +
-                                         head_num_idx * total_partitions_num +
-                                         partition_idx;
-            exp_sums[exp_sums_offset] = exp_sum;
-
-            const uint max_logits_offset = exp_sums_offset;
-            max_logits[max_logits_offset] = qk_max;
+            unroll_for (uint i = 0; i < HEADS_PER_REQUEST; i++) {
+#if HEADS_PROCESSING_LEFTOVER > 0
+                if (i >= iter_heads_num)
+                    break;
+#endif
+                const uint exp_sums_offset = seq_idx * HEADS_NUM * total_partitions_num +
+                                             (head_num_idx + i) * total_partitions_num +
+                                             partition_idx;
+                exp_sums[exp_sums_offset] = exp_sum[i];
+
+                const uint max_logits_offset = exp_sums_offset;
+                max_logits[max_logits_offset] = qk_max[i];
+            }
         }
 
 #if PAGED_ATTENTION_SCORES_OUTPUT
@@ -327,6 +385,7 @@ KERNEL(pa_sdpa_opt)(
         // PagedAttention is supposed to save only last "row" of the QK matrix multiplication,
         // so save SEQ_LEN_PARTITION_SIZE elements for each partition
         if (save_softmax_results) {
+            // TODO: UPDATE THIS
             const uint output_offset = subsequence_idx * HEADS_NUM * total_partitions_num * SEQ_LEN_PARTITION_SIZE +
                                        head_num_idx * total_partitions_num * SEQ_LEN_PARTITION_SIZE +
                                        partition_idx * SEQ_LEN_PARTITION_SIZE;
@@ -340,7 +399,7 @@ KERNEL(pa_sdpa_opt)(
 
     {
         // QK*V calculation
-        OUTPUT_TYPE acc = OUTPUT_VAL_ZERO;
+        MAKE_VECTOR_TYPE(OUTPUT_TYPE, HEADS_PER_REQUEST) acc = OUTPUT_VAL_ZERO;
 
         const uint partition_seq_len = min(seq_len - partition_idx * SEQ_LEN_PARTITION_SIZE, (uint)SEQ_LEN_PARTITION_SIZE);
 
@@ -398,10 +457,12 @@ KERNEL(pa_sdpa_opt)(
             VALUE_BLOCK value_vals = v_vals_packed;
 #endif
 
-            OUTPUT_TYPE qk_val = slm_qk_vals[block_num * PAGED_ATTENTION_BLOCK_SIZE + sglid];
+            unroll_for (uint iq = 0; iq < HEADS_PER_REQUEST; iq++) {
+                OUTPUT_TYPE qk_val = slm_qk_vals[iq * SEQ_LEN_PARTITION_SIZE + block_num * PAGED_ATTENTION_BLOCK_SIZE + sglid];
 
-            unroll_for (uint i = 0; i < VALUE_VEC_SIZE; i++) {
-                acc = mad(sub_group_broadcast(qk_val, i), value_vals[i], acc);
+                unroll_for (uint i = 0; i < VALUE_VEC_SIZE; i++) {
+                    acc[iq] = mad(sub_group_broadcast(qk_val, i), value_vals[i], acc[iq]);
+                }
             }
         }
 
@@ -426,7 +487,11 @@ KERNEL(pa_sdpa_opt)(
             INPUT0_TYPE comp_zp = value_comp_ptr[PAGED_ATTENTION_BLOCK_SIZE + sglid];
 #endif
 
-            OUTPUT_TYPE qk_val = slm_qk_vals[blocks_num_per_partition * PAGED_ATTENTION_BLOCK_SIZE + sglid];
+            MAKE_VECTOR_TYPE(OUTPUT_TYPE, HEADS_PER_REQUEST) qk_val;
+            unroll_for (uint iq = 0; iq < HEADS_PER_REQUEST; iq++) {
+                qk_val[iq] = slm_qk_vals[iq * SEQ_LEN_PARTITION_SIZE + blocks_num_per_partition * PAGED_ATTENTION_BLOCK_SIZE + sglid];
+            }
+            // OUTPUT_TYPE qk_val =
             for (uint i = 0; i < leftovers; i++) {
                 INPUT2_TYPE value_packed = BLOCK_READN(INPUT2_TYPE, 1, value_cache, value_offset + i * HEAD_SIZE);
 #if IS_KV_COMPRESSED
#if IS_KV_COMPRESSED
@@ -437,7 +502,9 @@ KERNEL(pa_sdpa_opt)(
437502
VALUE_UNCOMPRESSED value_val = value_packed;
438503
#endif
439504

440-
acc = mad(sub_group_broadcast(qk_val, i), value_val, acc);
505+
unroll_for (uint iq = 0; iq < HEADS_PER_REQUEST; iq++) {
506+
acc[iq] = mad(sub_group_broadcast(qk_val[iq], i), value_val, acc[iq]);
507+
}
441508
}
442509
}
443510

@@ -469,20 +536,33 @@ KERNEL(pa_sdpa_opt)(
 #endif
 
     if (seq_len > SEQ_LEN_PARTITION_SIZE) {
-        const uint tmp_out_offset = seq_idx * (HEADS_NUM * HEAD_SIZE * total_partitions_num) +
-                                    head_num_idx * (HEAD_SIZE * total_partitions_num) +
-                                    partition_idx * HEAD_SIZE +
-                                    sgid * SUBGROUP_SIZE +
-                                    sglid;
+        unroll_for (uint iq = 0; iq < HEADS_PER_REQUEST; iq++) {
+#if HEADS_PROCESSING_LEFTOVER > 0
+            if (iq >= iter_heads_num)
+                break;
+#endif
+
+            const uint tmp_out_offset = seq_idx * (HEADS_NUM * HEAD_SIZE * total_partitions_num) +
+                                        (head_num_idx + iq) * (HEAD_SIZE * total_partitions_num) +
+                                        partition_idx * HEAD_SIZE +
+                                        sgid * SUBGROUP_SIZE +
+                                        sglid;
 
-        tmp_out[tmp_out_offset] = acc;
+            tmp_out[tmp_out_offset] = acc[iq];
+        }
     } else {
-        const uint output_offset = seq_idx * (HEADS_NUM * HEAD_SIZE) +
-                                   head_num_idx * HEAD_SIZE +
-                                   sgid * SUBGROUP_SIZE +
-                                   sglid;
+        unroll_for (uint iq = 0; iq < HEADS_PER_REQUEST; iq++) {
+#if HEADS_PROCESSING_LEFTOVER > 0
+            if (iq >= iter_heads_num)
+                break;
+#endif
+            const uint output_offset = seq_idx * (HEADS_NUM * HEAD_SIZE) +
+                                       (head_num_idx + iq) * HEAD_SIZE +
+                                       sgid * SUBGROUP_SIZE +
+                                       sglid;
 
-        output[output_offset] = acc;
+            output[output_offset] = acc[iq];
+        }
     }
 
 #if SG_SCALE_FACTOR > 1
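The stores above keep the output layout identical to the single-head version: iteration iq writes head head_num_idx + iq, so the grouped heads land in adjacent HEAD_SIZE-sized slots of the same token. A quick check of the else-branch offsets, assuming HEADS_NUM=28, HEAD_SIZE=128, SUBGROUP_SIZE=16:

    // Offset check for the non-partitioned output path; constants are assumptions.
    #include <cstdio>

    int main() {
        const unsigned HEADS_NUM = 28, HEAD_SIZE = 128, SUBGROUP_SIZE = 16;
        const unsigned seq_idx = 0, head_num_idx = 4, sgid = 0, sglid = 0;
        for (unsigned iq = 0; iq < 3; ++iq) {
            const unsigned output_offset = seq_idx * (HEADS_NUM * HEAD_SIZE) +
                                           (head_num_idx + iq) * HEAD_SIZE +
                                           sgid * SUBGROUP_SIZE + sglid;
            std::printf("iq=%u -> offset %u\n", iq, output_offset); // 512, 640, 768
        }
    }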
