Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit d3665ce

Browse files
committed Feb 12, 2025
Use more threads
1 parent a517444 commit d3665ce

File tree

2 files changed

+64
-6
lines changed

2 files changed

+64
-6
lines changed
 

‎src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl

+47-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include "include/batch_headers/sub_group_block_write.cl"
88
#include "include/batch_headers/sub_group_shuffle.cl"
99

10-
#define SUBGROUPS_PER_WG (HEAD_SIZE / SUBGROUP_SIZE)
10+
#define SUBGROUPS_PER_WG ((HEAD_SIZE / SUBGROUP_SIZE) * SG_SCALE_FACTOR)
1111
#define PAGED_ATTENTION_BLOCKS_PER_PARTITION (SEQ_LEN_PARTITION_SIZE / PAGED_ATTENTION_BLOCK_SIZE)
1212

1313
#if HEAD_SIZE > 128
@@ -75,7 +75,7 @@ KERNEL(pa_sdpa_opt)(
7575

7676
const uint seq_idx = get_global_id(0);
7777
const uint head_num_idx = get_global_id(1);
78-
const uint head_size_idx = get_global_id(2);
78+
const uint head_size_idx = get_local_id(2);
7979
const uint sglid = get_sub_group_local_id();
8080
const uint sgid = get_sub_group_id();
8181
const uint total_partitions_num = get_num_groups(2);
@@ -93,7 +93,6 @@ KERNEL(pa_sdpa_opt)(
9393
#endif
9494

9595
const uint partition_idx = get_group_id(2);
96-
const uint block_start_idx = partition_idx * SEQ_LEN_PARTITION_SIZE / PAGED_ATTENTION_BLOCK_SIZE;
9796

9897
if (partition_idx * SEQ_LEN_PARTITION_SIZE >= seq_len) {
9998
return;
@@ -336,6 +335,15 @@ KERNEL(pa_sdpa_opt)(
336335
OUTPUT_TYPE acc = OUTPUT_VAL_ZERO;
337336

338337
const uint partition_seq_len = min(seq_len - partition_idx * SEQ_LEN_PARTITION_SIZE, (uint)SEQ_LEN_PARTITION_SIZE);
338+
339+
#if SG_SCALE_FACTOR > 1
340+
const uint block_start_idx = (sgid / (SUBGROUPS_PER_WG / SG_SCALE_FACTOR)) * (SEQ_LEN_PARTITION_SIZE / SG_SCALE_FACTOR / SUBGROUP_SIZE);
341+
const uint block_end_idx = min(block_start_idx + (SEQ_LEN_PARTITION_SIZE / SG_SCALE_FACTOR / SUBGROUP_SIZE), partition_seq_len / SUBGROUP_SIZE);
342+
#else
343+
const uint block_start_idx = 0;
344+
const uint block_end_idx = partition_seq_len / SUBGROUP_SIZE;
345+
#endif
346+
339347
uint blocks_num_per_partition = min(total_blocks_num - partition_idx * PAGED_ATTENTION_BLOCKS_PER_PARTITION, (uint)PAGED_ATTENTION_BLOCKS_PER_PARTITION);
340348

341349
uint leftovers = blocks_num_per_partition * PAGED_ATTENTION_BLOCK_SIZE - partition_seq_len;
@@ -346,7 +354,7 @@ KERNEL(pa_sdpa_opt)(
346354

347355
const uint start_block_idx = block_indices_begins[subsequence_idx] + partition_idx * PAGED_ATTENTION_BLOCKS_PER_PARTITION;
348356

349-
for (uint block_num = 0; block_num < blocks_num_per_partition; block_num++) {
357+
for (uint block_num = block_start_idx; block_num < block_end_idx; block_num++) {
350358
#ifdef BROADCAST_GROUP_SIZE
351359
const uint head_idx = head_num_idx / BROADCAST_GROUP_SIZE;
352360
#else
@@ -389,6 +397,10 @@ KERNEL(pa_sdpa_opt)(
389397
}
390398
}
391399

400+
401+
#if SG_SCALE_FACTOR > 1
402+
if (sgid >= SUBGROUPS_PER_WG / SG_SCALE_FACTOR) {
403+
#endif
392404
if (leftovers != 0) {
393405
#ifdef BROADCAST_GROUP_SIZE
394406
const uint head_idx = head_num_idx / BROADCAST_GROUP_SIZE;
@@ -429,6 +441,32 @@ KERNEL(pa_sdpa_opt)(
429441
}
430442
}
431443

444+
445+
#if SG_SCALE_FACTOR > 1
446+
}
447+
#endif
448+
449+
#if SG_SCALE_FACTOR > 1
450+
if ((partition_seq_len > (SEQ_LEN_PARTITION_SIZE / SG_SCALE_FACTOR)) || (leftovers != 0)) {
451+
barrier(CLK_LOCAL_MEM_FENCE);
452+
453+
if (sgid >= SUBGROUPS_PER_WG / SG_SCALE_FACTOR) {
454+
// Reuse slm_qk_vals SLM to sum-up results between two groups of subgroups
455+
slm_qk_vals[head_size_idx] = acc;
456+
}
457+
458+
barrier(CLK_LOCAL_MEM_FENCE);
459+
460+
if (sgid < SUBGROUPS_PER_WG / SG_SCALE_FACTOR) {
461+
acc += slm_qk_vals[head_size_idx];
462+
}
463+
}
464+
#endif
465+
466+
#if SG_SCALE_FACTOR > 1
467+
if (sgid < SUBGROUPS_PER_WG / SG_SCALE_FACTOR) {
468+
#endif
469+
432470
if (seq_len > SEQ_LEN_PARTITION_SIZE) {
433471
const uint tmp_out_offset = seq_idx * (HEADS_NUM * HEAD_SIZE * total_partitions_num) +
434472
head_num_idx * (HEAD_SIZE * total_partitions_num) +
@@ -446,6 +484,11 @@ KERNEL(pa_sdpa_opt)(
446484
output[output_offset] = acc;
447485
}
448486

487+
#if SG_SCALE_FACTOR > 1
488+
}
489+
#endif
490+
491+
449492
}
450493
}
451494

‎src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp

+17-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,19 @@ constexpr size_t subgroup_size = 16;
2424
constexpr size_t seq_len_partition_size = 256;
2525
constexpr size_t paged_attention_block_size = 16;
2626
constexpr Datatype softmax_acc_dt = Datatype::F32;
27+
28+
size_t get_sg_number_scale_factor(const Params& params, size_t head_size, size_t kernel_type) {
29+
return 1;
30+
const size_t optimal_scale_factor = 2;
31+
if (kernel_type == KernelsTypes::SINGLE_TOKEN ||
32+
kernel_type == KernelsTypes::MULTI_TOKENS) {
33+
if (head_size * optimal_scale_factor <= params.engineInfo.maxWorkGroupSize) {
34+
return optimal_scale_factor;
35+
}
36+
}
37+
38+
return 1;
39+
}
2740
} // namespace
2841

2942
static std::string GetKernelName(std::string base_name, KernelsTypes type) {
@@ -211,6 +224,7 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
211224
jit.AddConstant(MakeJitConstant("PAGED_ATTENTION_BLOCK_SIZE", paged_attention_block_size));
212225
jit.AddConstant(MakeJitConstant("SUBGROUP_SIZE", subgroup_size));
213226
jit.AddConstant(MakeJitConstant("IS_KV_COMPRESSED", params.conf.is_kv_compressed));
227+
jit.AddConstant(MakeJitConstant("SG_SCALE_FACTOR", get_sg_number_scale_factor(params, config.head_size, kernel_idx)));
214228

215229
if (params.conf.is_kv_compressed) {
216230
auto scales_zp_size = 2 * 2; // FP16 * (scale + zp)
@@ -272,10 +286,11 @@ CommonDispatchData PagedAttentionSDPAKernelOpt::SetDefault(const pa_sdpa_params&
272286
const size_t head_size = static_cast<size_t>(params.conf.head_size);
273287

274288
if (kernel_idx == KernelsTypes::SINGLE_TOKEN || kernel_idx == KernelsTypes::MULTI_TOKENS) {
289+
auto sg_scale = get_sg_number_scale_factor(params, head_size, kernel_idx);
275290
dispatch_data.gws = { total_tokens,
276291
heads_num,
277-
head_size * num_of_partitions };
278-
dispatch_data.lws = { 1, 1, head_size };
292+
head_size * num_of_partitions * sg_scale };
293+
dispatch_data.lws = { 1, 1, head_size * sg_scale };
279294
} else if (kernel_idx == KernelsTypes::SCORES_CALCULATION) {
280295
const auto& past_lens = params.inputs[3];
281296
const auto subsequences_number = past_lens.Batch().v;

0 commit comments

Comments
 (0)
Please sign in to comment.