@@ -13,13 +13,28 @@ namespace kernel_selector {
13
13
namespace {
14
14
// Identifiers of the kernels that make up this paged-attention SDPA implementation.
// The enumerators are used directly as indices into KernelData::kernels
// (kd.kernels[KernelsTypes::...]), so their numeric values are spelled out explicitly.
// TOTAL_KERNELS_NUM is the number of kernel identifiers, not a kernel itself.
enum KernelsTypes {
    SINGLE_TOKEN = 0,
    SINGLE_TOKEN_GQA = 1,
    MULTI_TOKENS = 2,
    FINALIZATION = 3,
    FINALIZATION_MULTI_TOKENS = 4,
    SCORES_CALCULATION = 5,
    TOTAL_KERNELS_NUM = 6
};
22
23
24
+ static size_t get_heads_per_wi (const pa_sdpa_params& params) {
25
+ if (params.conf .kv_group_size > 1 ) {
26
+ std::vector<size_t > preferable_head_nums = {4 , 3 , 2 };
27
+ for (const auto & heads_num : preferable_head_nums) {
28
+ const auto leftovers = params.conf .kv_group_size % heads_num;
29
+ if (leftovers == 0 || heads_num - leftovers <= 1 ) {
30
+ return heads_num;
31
+ }
32
+ }
33
+ }
34
+
35
+ return 1 ;
36
+ }
37
+
23
38
// File-scope constants shared by the host-side dispatch logic and forwarded to the
// kernels as the SUBGROUP_SIZE, SEQ_LEN_PARTITION_SIZE and PAGED_ATTENTION_BLOCK_SIZE
// JIT constants respectively.
constexpr size_t subgroup_size{16};
constexpr size_t seq_len_partition_size{256};
constexpr size_t paged_attention_block_size{16};
@@ -29,6 +44,7 @@ size_t get_sg_number_scale_factor(const pa_sdpa_params& params, size_t head_size
29
44
if (params.conf .is_kv_compressed ) {
30
45
const size_t optimal_scale_factor = 2 ;
31
46
if (kernel_type == KernelsTypes::SINGLE_TOKEN ||
47
+ kernel_type == KernelsTypes::SINGLE_TOKEN_GQA ||
32
48
kernel_type == KernelsTypes::MULTI_TOKENS) {
33
49
if (head_size * optimal_scale_factor <= params.engineInfo .maxWorkGroupSize ) {
34
50
return optimal_scale_factor;
@@ -45,6 +61,8 @@ static std::string GetKernelName(std::string base_name, KernelsTypes type) {
45
61
46
62
if (type == KernelsTypes::SINGLE_TOKEN) {
47
63
kernel_name += " _single_token" ;
64
+ } else if (type == KernelsTypes::SINGLE_TOKEN_GQA) {
65
+ kernel_name += " _single_token_gqa" ;
48
66
} else if (type == KernelsTypes::MULTI_TOKENS) {
49
67
kernel_name += " _multi_tokens_seq" ;
50
68
} else if (type == KernelsTypes::FINALIZATION) {
@@ -65,6 +83,7 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {
65
83
66
84
const auto & params = static_cast <const pa_sdpa_params&>(p);
67
85
std::vector<KernelsTypes> kernels_type = { KernelsTypes::SINGLE_TOKEN,
86
+ KernelsTypes::SINGLE_TOKEN_GQA,
68
87
KernelsTypes::MULTI_TOKENS,
69
88
KernelsTypes::FINALIZATION,
70
89
KernelsTypes::FINALIZATION_MULTI_TOKENS };
@@ -90,7 +109,7 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {
90
109
91
110
int inputs_num = static_cast <int >(params.inputs .size ());
92
111
int outputs_num = 1 ;
93
- if (kernel_type == KernelsTypes::SINGLE_TOKEN) {
112
+ if (kernel_type == KernelsTypes::SINGLE_TOKEN || kernel_type == KernelsTypes::SINGLE_TOKEN_GQA ) {
94
113
// SINGLE_TOKEN and SINGLE_TOKEN_GQA kernels don't use the subsequence_begins input
95
114
inputs_num -= 1 ;
96
115
} else if (kernel_type == KernelsTypes::FINALIZATION) {
@@ -221,6 +240,7 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
221
240
jit.AddConstant (MakeJitConstant (" HEAD_SIZE" , config.head_size ));
222
241
jit.AddConstant (MakeJitConstant (" HEADS_NUM" , config.heads_num ));
223
242
jit.AddConstant (MakeJitConstant (" KV_HEADS_NUM" , config.kv_heads_num ));
243
+ jit.AddConstant (MakeJitConstant (" KV_HEADS_GROUP_SIZE" , config.kv_group_size ));
224
244
jit.AddConstant (MakeJitConstant (" SEQ_LEN_PARTITION_SIZE" , seq_len_partition_size));
225
245
jit.AddConstant (MakeJitConstant (" PAGED_ATTENTION_BLOCK_SIZE" , paged_attention_block_size));
226
246
jit.AddConstant (MakeJitConstant (" SUBGROUP_SIZE" , subgroup_size));
@@ -236,8 +256,13 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
236
256
jit.AddConstant (MakeJitConstant (" ADJUSTED_HEAD_SIZE" , params.conf .head_size ));
237
257
}
238
258
239
- if (config.broadcast_axis != -1 ) {
240
- jit.AddConstant (MakeJitConstant (" BROADCAST_GROUP_SIZE" , config.group_size ));
259
+ if (kernel_idx == KernelsTypes::SINGLE_TOKEN_GQA) {
260
+ auto heads_per_wi = get_heads_per_wi (params);
261
+ jit.AddConstant (MakeJitConstant (" HEADS_PER_WI" , heads_per_wi));
262
+ jit.AddConstant (MakeJitConstant (" ITERATIONS_PER_KV_HEADS_GROUP" , CeilDiv (config.kv_group_size , heads_per_wi)));
263
+ jit.AddConstant (MakeJitConstant (" HEADS_LEFTOVERS_NUM" , config.kv_group_size % heads_per_wi));
264
+ } else {
265
+ jit.AddConstant (MakeJitConstant (" HEADS_PER_WI" , 1 ));
241
266
}
242
267
243
268
auto sdpa_stage = 0 ;
@@ -293,6 +318,16 @@ CommonDispatchData PagedAttentionSDPAKernelOpt::SetDefault(const pa_sdpa_params&
293
318
heads_num,
294
319
head_size * num_of_partitions * sg_scale };
295
320
dispatch_data.lws = { 1 , 1 , head_size * sg_scale };
321
+ } else if (kernel_idx == KernelsTypes::SINGLE_TOKEN_GQA) {
322
+ auto sg_scale = get_sg_number_scale_factor (params, head_size, kernel_idx);
323
+
324
+ auto kv_groups = heads_num / params.conf .kv_group_size ;
325
+ auto gqa_heads_num = kv_groups * CeilDiv (params.conf .kv_group_size , get_heads_per_wi (params));
326
+
327
+ dispatch_data.gws = { total_tokens,
328
+ gqa_heads_num,
329
+ head_size * num_of_partitions * sg_scale };
330
+ dispatch_data.lws = { 1 , 1 , head_size * sg_scale };
296
331
} else if (kernel_idx == KernelsTypes::SCORES_CALCULATION) {
297
332
const auto & past_lens = params.inputs [3 ];
298
333
const auto subsequences_number = past_lens.Batch ().v ;
@@ -334,13 +369,30 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
334
369
const auto scores_calc_only = prim_params.stage == PagedAttentionStage::PREFILL && has_scores_output;
335
370
const auto multi_tokens_mode = prim_params.stage == PagedAttentionStage::MIXED;
336
371
337
- auto dispatch_data1 = SetDefault (prim_params, KernelsTypes::SINGLE_TOKEN);
338
- kd.kernels [KernelsTypes::SINGLE_TOKEN].params .workGroups .global = dispatch_data1.gws ;
339
- kd.kernels [KernelsTypes::SINGLE_TOKEN].params .workGroups .local = dispatch_data1.lws ;
340
- kd.kernels [KernelsTypes::SINGLE_TOKEN].skip_execution = multi_tokens_mode || scores_calc_only;
341
-
342
- kd.kernels [KernelsTypes::MULTI_TOKENS].params .workGroups .global = dispatch_data1.gws ;
343
- kd.kernels [KernelsTypes::MULTI_TOKENS].params .workGroups .local = dispatch_data1.lws ;
372
+ // Apply GQA optimization starting from a certain sequence length value
373
+ const auto min_gqa_sequence_len = 8 * seq_len_partition_size;
374
+ // Apply GQA only if there is a single subsequence in the request,
375
+ // as multiple subsequences might have significantly different lengths
376
+ const auto max_subsequences_num = 1 ;
377
+ const auto subsequences_num = prim_params.inputs [0 ].Batch ().v ;
378
+ const auto can_use_gqa_kernel = prim_params.conf .paged_attention_max_len >= static_cast <int64_t >(min_gqa_sequence_len) &&
379
+ subsequences_num <= max_subsequences_num &&
380
+ prim_params.conf .kv_group_size > 1 &&
381
+ !multi_tokens_mode &&
382
+ !scores_calc_only;
383
+
384
+ auto dispatch_data = SetDefault (prim_params, KernelsTypes::SINGLE_TOKEN_GQA);
385
+ kd.kernels [KernelsTypes::SINGLE_TOKEN_GQA].params .workGroups .global = dispatch_data.gws ;
386
+ kd.kernels [KernelsTypes::SINGLE_TOKEN_GQA].params .workGroups .local = dispatch_data.lws ;
387
+ kd.kernels [KernelsTypes::SINGLE_TOKEN_GQA].skip_execution = multi_tokens_mode || scores_calc_only || !can_use_gqa_kernel;
388
+
389
+ dispatch_data = SetDefault (prim_params, KernelsTypes::SINGLE_TOKEN);
390
+ kd.kernels [KernelsTypes::SINGLE_TOKEN].params .workGroups .global = dispatch_data.gws ;
391
+ kd.kernels [KernelsTypes::SINGLE_TOKEN].params .workGroups .local = dispatch_data.lws ;
392
+ kd.kernels [KernelsTypes::SINGLE_TOKEN].skip_execution = multi_tokens_mode || scores_calc_only || can_use_gqa_kernel;
393
+
394
+ kd.kernels [KernelsTypes::MULTI_TOKENS].params .workGroups .global = dispatch_data.gws ;
395
+ kd.kernels [KernelsTypes::MULTI_TOKENS].params .workGroups .local = dispatch_data.lws ;
344
396
kd.kernels [KernelsTypes::MULTI_TOKENS].skip_execution = !multi_tokens_mode || scores_calc_only;
345
397
346
398
size_t partition_size = 0 ;
@@ -351,13 +403,13 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
351
403
}
352
404
const size_t num_of_partitions = CeilDiv (prim_params.conf .paged_attention_max_len , partition_size);
353
405
354
- auto dispatch_data2 = SetDefault (prim_params, KernelsTypes::FINALIZATION);
355
- kd.kernels [KernelsTypes::FINALIZATION].params .workGroups .global = dispatch_data2 .gws ;
356
- kd.kernels [KernelsTypes::FINALIZATION].params .workGroups .local = dispatch_data2 .lws ;
406
+ dispatch_data = SetDefault (prim_params, KernelsTypes::FINALIZATION);
407
+ kd.kernels [KernelsTypes::FINALIZATION].params .workGroups .global = dispatch_data .gws ;
408
+ kd.kernels [KernelsTypes::FINALIZATION].params .workGroups .local = dispatch_data .lws ;
357
409
kd.kernels [KernelsTypes::FINALIZATION].skip_execution = num_of_partitions == 1 || multi_tokens_mode || scores_calc_only;
358
410
359
- kd.kernels [KernelsTypes::FINALIZATION_MULTI_TOKENS].params .workGroups .global = dispatch_data2 .gws ;
360
- kd.kernels [KernelsTypes::FINALIZATION_MULTI_TOKENS].params .workGroups .local = dispatch_data2 .lws ;
411
+ kd.kernels [KernelsTypes::FINALIZATION_MULTI_TOKENS].params .workGroups .global = dispatch_data .gws ;
412
+ kd.kernels [KernelsTypes::FINALIZATION_MULTI_TOKENS].params .workGroups .local = dispatch_data .lws ;
361
413
kd.kernels [KernelsTypes::FINALIZATION_MULTI_TOKENS].skip_execution = num_of_partitions == 1 || !multi_tokens_mode || scores_calc_only;
362
414
363
415
ScalarDescriptor num_of_partitions_scalar;
@@ -369,7 +421,7 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
369
421
kd.kernels [KernelsTypes::FINALIZATION_MULTI_TOKENS].params .scalars [0 ] = num_of_partitions_scalar;
370
422
371
423
if (has_scores_output) {
372
- auto dispatch_data = SetDefault (prim_params, KernelsTypes::SCORES_CALCULATION);
424
+ dispatch_data = SetDefault (prim_params, KernelsTypes::SCORES_CALCULATION);
373
425
kd.kernels [KernelsTypes::SCORES_CALCULATION].params .workGroups .global = dispatch_data.gws ;
374
426
kd.kernels [KernelsTypes::SCORES_CALCULATION].params .workGroups .local = dispatch_data.lws ;
375
427
kd.kernels [KernelsTypes::SCORES_CALCULATION].skip_execution = false ;
0 commit comments