Skip to content

Commit 6129476

Browse files
committed
[GPU] Add reduce kernel
1 parent 15d02d0 commit 6129476

File tree

6 files changed

+264
-71
lines changed

6 files changed

+264
-71
lines changed

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp

+27-4
Original file line numberDiff line numberDiff line change
@@ -291,10 +291,6 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
291291
for (auto& ev : res_events)
292292
all_events.push_back(ev);
293293

294-
auto impl_param = *instance.get_impl_params();
295-
auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, impl_param.is_dynamic());
296-
(_kernels_data[Stage::SDPA].update_dispatch_data_func)(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
297-
298294
execute_stage(all_events, instance, res_events, Stage::SDPA);
299295

300296
return aggregate_events(res_events, stream, res_events.size() > 1);
@@ -331,6 +327,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
331327
config.kv_heads_num = kv_heads_num;
332328
config.block_size = block_size;
333329
config.x_size = x_size;
330+
config.max_context_len = 1;
334331
}
335332

336333
return config;
@@ -397,6 +394,29 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
397394
params.inputs[6] = convert_data_tensor(scale_layout);
398395

399396
params.configuration = get_sdpa_configuration(impl_param);
397+
GPU_DEBUG_TRACE_DETAIL << "Number of constant_mem " << impl_param.memory_deps.size() << ", dynamic=" << is_dynamic << "\n";
398+
if (!is_dynamic) {
399+
auto& constant_mem = impl_param.memory_deps;
400+
401+
402+
const auto max_context_len_mem = constant_mem.at(7);
403+
mem_lock<int32_t, mem_lock_type::read> max_context_len_mem_lock(max_context_len_mem, impl_param.get_stream());
404+
GPU_DEBUG_TRACE_DETAIL << "max_context_len_mem_lock=" << max_context_len_mem_lock[0] << "\n";
405+
406+
const auto is_prompt_stage_mem = constant_mem.at(5);
407+
mem_lock<uint8_t, mem_lock_type::read> is_prompt_stage_mem_lock(is_prompt_stage_mem, impl_param.get_stream());
408+
bool is_prompt_stage = is_prompt_stage_mem_lock[0];
409+
410+
if (is_prompt_stage) {
411+
// Use number of slots for KV cache as a maximum context length for the first iteration
412+
auto slot_mapping = impl_param.get_input_layout(6);
413+
params.configuration.max_context_len = slot_mapping.get_shape()[1];
414+
} else {
415+
const auto max_context_len_mem = constant_mem.at(7);
416+
mem_lock<int32_t, mem_lock_type::read> max_context_len_mem_lock(max_context_len_mem, impl_param.get_stream());
417+
params.configuration.max_context_len = max_context_len_mem_lock[0];
418+
}
419+
}
400420

401421
const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset;
402422
const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset;
@@ -434,6 +454,9 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
434454
void update_dispatch_data(const kernel_impl_params& impl_param) override {
435455
auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, impl_param.is_dynamic());
436456
(_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func)(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]);
457+
458+
auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, impl_param.is_dynamic());
459+
(_kernels_data[Stage::SDPA].update_dispatch_data_func)(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
437460
}
438461
};
439462

src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ struct typed_program_node<paged_attention> : public typed_program_node_base<page
2323
program_node& key_cache() const { return get_dependency(3); }
2424
program_node& value_cache() const { return get_dependency(4); }
2525

26-
std::vector<size_t> get_shape_infer_dependencies() const override { return {}; }
26+
std::vector<size_t> get_shape_infer_dependencies() const override { return { 5 /* is_prompt */, 7 /* max_context_len */ }; }
2727
};
2828

2929
using paged_attention_node = typed_program_node<paged_attention>;

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_ref.cl

+152-29
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838

3939
ulong __attribute__((overloadable)) intel_get_cycle_counter( void );
4040

41+
#ifdef SDPA_STAGE_0
42+
4143
REQD_SUB_GROUP_SIZE(SUB_GROUP_SIZE)
4244
__attribute__((reqd_work_group_size(1, 1, HEAD_SIZE)))
4345
KERNEL(pa_sdpa_ref)(
@@ -49,7 +51,11 @@ KERNEL(pa_sdpa_ref)(
4951
__global const INPUT4_TYPE* context_lens,
5052
__global const INPUT5_TYPE* block_tables,
5153
__global const INPUT6_TYPE* scale,
52-
__global OUTPUT_TYPE* output)
54+
__global OUTPUT_TYPE* output,
55+
__global OUTPUT_TYPE* exp_sums,
56+
__global OUTPUT_TYPE* max_logits,
57+
__global OUTPUT_TYPE* tmp_out,
58+
uint num_of_portions)
5359
{
5460
const uint seq_idx = get_global_id(0);
5561
const uint head_num_idx = get_global_id(1);
@@ -64,6 +70,11 @@ KERNEL(pa_sdpa_ref)(
6470

6571
const uint blocks_num = INPUT5_FEATURE_NUM;
6672

73+
const uint portion_id = get_group_id(2);
74+
const uint block_start_idx = portion_id * SEQ_LEN_PORTION_SIZE / BLOCK_SIZE;
75+
const uint block_end_idx = min(block_start_idx + (SEQ_LEN_PORTION_SIZE / BLOCK_SIZE), blocks_num);
76+
77+
6778
// if (seq_idx < 2 && head_num_idx < 2 && sgid < 2 && sglid < 2) {
6879
// if (INPUT5_BATCH_NUM == 2) {
6980
// if (INPUT5_FEATURE_NUM == 0) {
@@ -135,16 +146,17 @@ KERNEL(pa_sdpa_ref)(
135146
q[i] = QUERY_BLOCK_READ(query, query_idx);
136147
}
137148

138-
for (uint block = 0; block < blocks_num; block++) {
139-
const uint block_idx = batch_idx * blocks_num + block;
149+
// JIT: Compile time restriction: SEQ_LEN_PORTION_SIZE must be divisible by BLOCK_SIZE
150+
for (uint block = 0; block < SEQ_LEN_PORTION_SIZE / BLOCK_SIZE; block++) {
151+
const uint block_idx = batch_idx * blocks_num + block + block_start_idx;
140152
const uint block_offset = block_tables[block_idx] * KV_CACHE_BLOCK_STRIDE;
141153

142154
OUTPUT_TYPE qk[QK_VALS_PER_SG_PER_ITER] = {0};
143155

144156
ulong timer2 = intel_get_cycle_counter();
145157
for (uint hs = 0; hs < Q_LOAD_ITERS; hs++) {
146158
for (uint qk_idx = 0; qk_idx < QK_VALS_PER_SG_PER_ITER; qk_idx++) {
147-
uint current_token = block * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + qk_idx;
159+
uint current_token = (block + block_start_idx) * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + qk_idx;
148160
if (current_token >= context_len)
149161
continue;
150162

@@ -185,7 +197,7 @@ KERNEL(pa_sdpa_ref)(
185197

186198
// Summarize qk calculation across all WIs and apply scale
187199
for (uint qk_idx = 0; qk_idx < QK_VALS_PER_SG_PER_ITER; qk_idx++) {
188-
const uint current_token = block * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + qk_idx;
200+
const uint current_token = (block + block_start_idx) * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + qk_idx;
189201
if (current_token < context_len) {
190202
OUTPUT_TYPE tmp_print = qk[qk_idx];
191203
qk[qk_idx] = sub_group_reduce_add(qk[qk_idx]);
@@ -194,7 +206,7 @@ KERNEL(pa_sdpa_ref)(
194206
// seq_idx, head_num_idx, sgid, sglid, qk_idx, tmp_print, qk[qk_idx]);
195207
qk[qk_idx] = scale[0] * qk[qk_idx];
196208

197-
// Apply attention mask during prefill stage
209+
// Apply attention mask at prefill stage
198210
if (INPUT0_FEATURE_NUM > 1 && current_token > token_idx) {
199211
qk[qk_idx] = qk[qk_idx] + OUTPUT_VAL_MIN;
200212
}
@@ -206,12 +218,13 @@ KERNEL(pa_sdpa_ref)(
206218
// Save QK results to local memory
207219
if (sglid < QK_VALS_PER_SG_PER_ITER) {
208220
const uint current_token = block * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + sglid;
221+
const uint current_token_global_idx = (block + block_start_idx) * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + sglid;
209222
// Fixed -> // const uint qk_local_idx = block * BLOCK_SIZE * sgid * QK_VALS_PER_SG_PER_ITER + sglid;
210223
// OUTPUT_TYPE tmp_print = (current_token >= context_len ? 0 : qk[sglid]);
211224
// if (head_num_idx < 4 || head_num_idx == 31)
212225
// printf("slm save: seq_idx=%d, head_num_idx=%d, sgid=%d, sglid=%d: qk_vals[%d]=%f. Max=%f\n",
213226
// seq_idx, head_num_idx, sgid, sglid, current_token, tmp_print, qk_max);
214-
qk_vals[current_token] = current_token >= context_len ? 0 : qk[sglid];
227+
qk_vals[current_token] = current_token_global_idx >= context_len ? 0 : qk[sglid];
215228
}
216229
ulong timer5 = intel_get_cycle_counter();
217230

@@ -266,12 +279,13 @@ KERNEL(pa_sdpa_ref)(
266279
// }
267280

268281
OUTPUT_TYPE exp_sum = OUTPUT_VAL_ZERO;
269-
for (uint qk_idx = 0; qk_idx < CEIL_DIV(context_len, SUBGROUPS_PER_WG * SUB_GROUP_SIZE); qk_idx++) {
270-
const uint data_idx = qk_idx * (SUBGROUPS_PER_WG * SUB_GROUP_SIZE) + sgid * SUB_GROUP_SIZE + sglid;
271-
if (data_idx < context_len) {
272-
OUTPUT_TYPE val = native_exp(qk_vals[data_idx] - qk_max);
282+
for (uint qk_idx = 0; qk_idx < CEIL_DIV(SEQ_LEN_PORTION_SIZE, SUBGROUPS_PER_WG * SUB_GROUP_SIZE); qk_idx++) {
283+
const uint local_data_idx = qk_idx * (SUBGROUPS_PER_WG * SUB_GROUP_SIZE) + sgid * SUB_GROUP_SIZE + sglid;
284+
const uint global_data_idx = block_start_idx * BLOCK_SIZE + qk_idx * (SUBGROUPS_PER_WG * SUB_GROUP_SIZE) + sgid * SUB_GROUP_SIZE + sglid;
285+
if (global_data_idx < context_len) {
286+
OUTPUT_TYPE val = native_exp(qk_vals[local_data_idx] - qk_max);
273287
exp_sum += val;
274-
qk_vals[data_idx] = val;
288+
qk_vals[local_data_idx] = val;
275289
// if (head_num_idx < 4 || head_num_idx == 31)
276290
// printf("head_num %d, sgid = %d, sglid = %d, exp_sum = %f\n", head_num_idx, sgid, sglid, exp_sum);
277291
}
@@ -290,6 +304,8 @@ KERNEL(pa_sdpa_ref)(
290304

291305
exp_sum = OUTPUT_VAL_ZERO;
292306

307+
308+
// JIT: Compile time restriction: SUBGROUPS_PER_WG <= SG_SIZE
293309
if (sglid < SUBGROUPS_PER_WG)
294310
exp_sum = qk_sum_vals[sglid];
295311

@@ -300,20 +316,34 @@ KERNEL(pa_sdpa_ref)(
300316

301317

302318
// TODO: replace CEIL_DIV with ALIGN and use += SUBGROUPS_PER_WG * SUB_GROUP_SIZE increment
303-
for (uint qk_idx = 0; qk_idx < CEIL_DIV(context_len, SUBGROUPS_PER_WG * SUB_GROUP_SIZE); qk_idx++) {
304-
const uint data_idx = qk_idx * (SUBGROUPS_PER_WG * SUB_GROUP_SIZE) + sgid * SUB_GROUP_SIZE + sglid;
305-
if (data_idx < context_len) {
306-
OUTPUT_TYPE val = qk_vals[data_idx] * inv_sum;
307-
qk_vals[data_idx] = val;
319+
for (uint qk_idx = 0; qk_idx < CEIL_DIV(SEQ_LEN_PORTION_SIZE, SUBGROUPS_PER_WG * SUB_GROUP_SIZE); qk_idx++) {
320+
const uint local_data_idx = qk_idx * (SUBGROUPS_PER_WG * SUB_GROUP_SIZE) + sgid * SUB_GROUP_SIZE + sglid;
321+
const uint global_data_idx = block_start_idx * BLOCK_SIZE + qk_idx * (SUBGROUPS_PER_WG * SUB_GROUP_SIZE) + sgid * SUB_GROUP_SIZE + sglid;
322+
if (global_data_idx < context_len) {
323+
OUTPUT_TYPE val = qk_vals[local_data_idx] * inv_sum;
324+
qk_vals[local_data_idx] = val;
308325
}
309326
}
310327

311328
barrier(CLK_LOCAL_MEM_FENCE);
312329

313-
314330
ulong timer_end = intel_get_cycle_counter();
315331
ulong total_time = timer_end - timer_start;
316332

333+
{
334+
// Save temporary exp_sums and max_logits values for each portion
335+
if (sgid == 0) {
336+
const uint num_of_portions = get_num_groups(2);
337+
const uint exp_sums_offset = seq_idx * HEADS_NUM * num_of_portions +
338+
head_num_idx * num_of_portions +
339+
portion_id;
340+
exp_sums[exp_sums_offset] = exp_sum;
341+
342+
const uint max_logits_offset = exp_sums_offset;
343+
max_logits[max_logits_offset] = qk_max;
344+
}
345+
}
346+
317347
// if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0)
318348
// printf("SDPA kernel Softmax: %d\n", (uint)total_time);
319349
}
@@ -328,12 +358,14 @@ KERNEL(pa_sdpa_ref)(
328358
ulong timer_start = intel_get_cycle_counter();
329359
OUTPUT_TYPE acc = OUTPUT_VAL_ZERO;
330360

331-
for (uint qk_idx = 0; qk_idx < ALIGN(context_len, SUB_GROUP_SIZE); qk_idx += SUB_GROUP_SIZE) {
332-
const uint qk_offset = qk_idx + sglid;
333361

334-
OUTPUT_TYPE qk = qk_offset < context_len ? qk_vals[qk_offset] : OUTPUT_VAL_ZERO;
362+
for (uint qk_idx = 0; qk_idx < SEQ_LEN_PORTION_SIZE / BLOCK_SIZE * SUB_GROUP_SIZE; qk_idx += SUB_GROUP_SIZE) {
363+
const uint qk_offset_local = qk_idx + sglid;
364+
const uint qk_offset_global = block_start_idx * BLOCK_SIZE + qk_offset_local;
335365

336-
const uint block_idx = block_tables[batch_idx * blocks_num + (qk_idx / BLOCK_SIZE)];
366+
OUTPUT_TYPE qk = qk_offset_global < context_len ? qk_vals[qk_offset_local] : OUTPUT_VAL_ZERO;
367+
368+
const uint block_idx = block_tables[batch_idx * blocks_num + block_start_idx + (qk_idx / BLOCK_SIZE)];
337369
// if (block_idx == 0)
338370
// continue;
339371

@@ -356,33 +388,49 @@ KERNEL(pa_sdpa_ref)(
356388
// seq_idx, head_num_idx, sgid, sglid, block_idx, qk_idx, qk_offset, value_cache_offset - (block_idx * KV_CACHE_BLOCK_STRIDE), block_idx * KV_CACHE_BLOCK_STRIDE, *tmp_print);
357389
// }
358390

359-
if (qk_idx + SUB_GROUP_SIZE <= context_len) {
391+
// FINAL: rename token -> value_idx
392+
if (block_start_idx * BLOCK_SIZE + qk_idx + SUB_GROUP_SIZE <= context_len) {
360393
unroll_for (uint token = 0; token < SUB_GROUP_SIZE; token++) {
361394
OUTPUT_TYPE qk_tmp = sub_group_broadcast(qk, token);
362395
acc = mad(qk_tmp, v[token], acc);
363396
}
364397
} else {
365398
for (uint token = 0; token < SUB_GROUP_SIZE; token++) {
366399
OUTPUT_TYPE qk_tmp = sub_group_broadcast(qk, token);
367-
if (qk_idx + token < context_len) {
400+
if (block_start_idx * BLOCK_SIZE + qk_idx + token < context_len) {
368401
acc = mad(qk_tmp, v[token], acc);
369402
}
370403
}
371404
}
372405
}
373406

374407

375-
const uint output_offset = seq_idx * (HEADS_NUM * HEAD_SIZE) +
376-
head_num_idx * HEAD_SIZE +
377-
sgid * SUB_GROUP_SIZE +
378-
sglid;
408+
// const uint output_offset = seq_idx * (HEADS_NUM * HEAD_SIZE) +
409+
// head_num_idx * HEAD_SIZE +
410+
// sgid * SUB_GROUP_SIZE +
411+
// sglid;
379412

380413
// if (seq_idx == 0 && head_num_idx < 2 || head_num_idx == 31) {
381414
// printf("output res: seq_idx=%d, head_num_idx=%d, sgid=%d, sglid=%d: output[%d] = %f\n",
382415
// seq_idx, head_num_idx, sgid, sglid, output_offset, acc);
383416
// }
384417

385-
output[output_offset] = acc;
418+
// output[output_offset] = acc;
419+
420+
{
421+
// [num_seqs, num_heads, max_num_partitions, head_size]
422+
const uint num_of_portions = get_num_groups(2);
423+
const uint tmp_out_offset = seq_idx * (HEADS_NUM * HEAD_SIZE * num_of_portions) +
424+
head_num_idx * (HEAD_SIZE * num_of_portions) +
425+
portion_id * HEAD_SIZE +
426+
sgid * SUB_GROUP_SIZE +
427+
sglid;
428+
429+
// if (output_offset != tmp_out_offset)
430+
// printf("Different tmp_out_offset index!! %d vs %d, for portion_id %d\n", output_offset, tmp_out_offset, portion_id);
431+
432+
tmp_out[tmp_out_offset] = acc;
433+
}
386434

387435
ulong timer_end = intel_get_cycle_counter();
388436
ulong total_time = timer_end - timer_start;
@@ -391,3 +439,78 @@ KERNEL(pa_sdpa_ref)(
391439
// printf("SDPA kernel GEMM2: %d\n", (uint)total_time);
392440
}
393441
}
442+
443+
#endif
444+
445+
#ifdef SDPA_STAGE_1
446+
447+
// exp_sums, // [num_seqs, num_heads, max_num_partitions]
448+
// max_logits, // [num_seqs, num_heads, max_num_partitions]
449+
// tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
450+
451+
REQD_SUB_GROUP_SIZE(SUB_GROUP_SIZE)
452+
KERNEL(pa_sdpa_ref)(
453+
OPTIONAL_SHAPE_INFO_ARG
454+
__global const INPUT0_TYPE* query,
455+
__global const INPUT1_TYPE* key_cache,
456+
__global const INPUT2_TYPE* value_cache,
457+
__global const INPUT3_TYPE* max_context_len,
458+
__global const INPUT4_TYPE* context_lens,
459+
__global const INPUT5_TYPE* block_tables,
460+
__global const INPUT6_TYPE* scale,
461+
__global OUTPUT_TYPE* output,
462+
__global OUTPUT_TYPE* exp_sums,
463+
__global OUTPUT_TYPE* max_logits,
464+
__global OUTPUT_TYPE* tmp_out,
465+
uint num_of_portions) {
466+
if (num_of_portions <= SUB_GROUP_SIZE) {
467+
const uint seq_idx = get_global_id(0);
468+
const uint head_num_idx = get_global_id(1);
469+
const uint head_idx = get_global_id(2);
470+
const uint sglid = get_sub_group_local_id();
471+
472+
const uint exp_sums_offset = seq_idx * HEADS_NUM * num_of_portions +
473+
head_num_idx * num_of_portions;
474+
const uint max_logit_offset = exp_sums_offset;
475+
476+
OUTPUT_TYPE exp_sum = BLOCK_READN(OUTPUT_TYPE, 1, exp_sums, exp_sums_offset);
477+
OUTPUT_TYPE max_logit = BLOCK_READN(OUTPUT_TYPE, 1, max_logits, max_logit_offset);
478+
if (sglid >= num_of_portions) {
479+
exp_sum = 0;
480+
max_logit = OUTPUT_VAL_MIN;
481+
}
482+
483+
OUTPUT_TYPE global_max = sub_group_reduce_max(max_logit);
484+
485+
// Update exp_sum with respect to the global maximum
486+
OUTPUT_TYPE test_exp_sum = exp_sum;
487+
if (sglid < num_of_portions)
488+
exp_sum = exp_sum * native_exp(max_logit - global_max);
489+
490+
OUTPUT_TYPE global_sum = sub_group_reduce_add(exp_sum);
491+
492+
if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0)
493+
printf("Run second kernel for reduction: num_of_portions=%d: max_logit=%f, exp_sum = %f, global_sum = %f, global_max=%f, test = %f, %f, %f\n", num_of_portions,
494+
max_logit, exp_sum, global_sum, global_max, test_exp_sum, native_exp(max_logit - global_max), test_exp_sum * native_exp(max_logit - global_max));
495+
496+
for (uint i = 0; i < HEAD_SIZE / SUB_GROUP_SIZE; i++) {
497+
OUTPUT_TYPE acc = OUTPUT_VAL_ZERO;
498+
for (uint portion = 0; portion < num_of_portions; portion++) {
499+
const uint tmp_out_offset = seq_idx * (HEADS_NUM * HEAD_SIZE * num_of_portions) +
500+
head_num_idx * (HEAD_SIZE * num_of_portions) +
501+
portion * HEAD_SIZE;
502+
OUTPUT_TYPE out_val = BLOCK_READN(OUTPUT_TYPE, 1, tmp_out, tmp_out_offset);
503+
acc += out_val * sub_group_broadcast(exp_sum, portion) / global_sum;
504+
}
505+
const uint out_offset = seq_idx * (HEADS_NUM * HEAD_SIZE) +
506+
head_num_idx * HEAD_SIZE +
507+
i * SUB_GROUP_SIZE;
508+
output[out_offset] = acc;
509+
}
510+
} else {
511+
if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0)
512+
printf("run second kernel for portion >= 16\n");
513+
}
514+
}
515+
516+
#endif

0 commit comments

Comments
 (0)