@@ -13,13 +13,29 @@ namespace kernel_selector {
 namespace {
 enum KernelsTypes {
     SINGLE_TOKEN = 0,
+    SINGLE_TOKEN_GQA,
     MULTI_TOKENS,
     FINALIZATION,
     FINALIZATION_MULTI_TOKENS,
     SCORES_CALCULATION,
     TOTAL_KERNELS_NUM
 };
 
+static size_t get_heads_per_wi(const pa_sdpa_params& params) {
+    if (params.conf.kv_group_size > 1) {
+        std::vector<size_t> preferable_head_nums = {4, 3, 2};
+        for (const auto& heads_num : preferable_head_nums) {
+            const auto leftovers = params.conf.kv_group_size % heads_num;
+            if (leftovers == 0 || heads_num - leftovers <= 1) {
+                return heads_num;
+            }
+        }
+    }
+
+    return 1;
+}
+
+constexpr size_t heads_per_iteration = 4;
 constexpr size_t subgroup_size = 16;
 constexpr size_t seq_len_partition_size = 256;
 constexpr size_t paged_attention_block_size = 16;
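Note on the selection heuristic above: get_heads_per_wi scans the candidates {4, 3, 2} and returns the first count that either divides kv_group_size evenly or leaves at most one idle head slot in the final iteration, falling back to 1 when there is no grouping. A few illustrative traces, with hypothetical kv_group_size values that are not taken from the patch:

    // kv_group_size = 8: 8 % 4 == 0                          -> 4 heads per WI, no waste
    // kv_group_size = 7: 7 % 4 == 3, 4 - 3 == 1              -> 4 heads per WI, one idle slot in the last iteration
    // kv_group_size = 6: 6 % 4 == 2 (rejected), 6 % 3 == 0   -> 3 heads per WI
    // kv_group_size = 5: 5 % 4 == 1 (rejected), 5 % 3 == 2, 3 - 2 == 1 -> 3 heads per WI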
@@ -29,6 +45,7 @@ size_t get_sg_number_scale_factor(const pa_sdpa_params& params, size_t head_size
     if (params.conf.is_kv_compressed) {
         const size_t optimal_scale_factor = 2;
         if (kernel_type == KernelsTypes::SINGLE_TOKEN ||
+            kernel_type == KernelsTypes::SINGLE_TOKEN_GQA ||
             kernel_type == KernelsTypes::MULTI_TOKENS) {
             if (head_size * optimal_scale_factor <= params.engineInfo.maxWorkGroupSize) {
                 return optimal_scale_factor;
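Note: with a compressed KV cache the single-token kernels, now including the GQA variant, double the subgroup count per work-group whenever 2 * head_size still fits into maxWorkGroupSize. For example, assuming head_size = 128 (an illustrative value, not taken from the patch), the local size along the Z axis grows from 128 to 256 lanes.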
@@ -45,6 +62,8 @@ static std::string GetKernelName(std::string base_name, KernelsTypes type) {
 
     if (type == KernelsTypes::SINGLE_TOKEN) {
         kernel_name += "_single_token";
+    } else if (type == KernelsTypes::SINGLE_TOKEN_GQA) {
+        kernel_name += "_single_token_gqa";
     } else if (type == KernelsTypes::MULTI_TOKENS) {
         kernel_name += "_multi_tokens_seq";
     } else if (type == KernelsTypes::FINALIZATION) {
@@ -65,6 +84,7 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {
 
     const auto& params = static_cast<const pa_sdpa_params&>(p);
     std::vector<KernelsTypes> kernels_type = { KernelsTypes::SINGLE_TOKEN,
+                                               KernelsTypes::SINGLE_TOKEN_GQA,
                                                KernelsTypes::MULTI_TOKENS,
                                                KernelsTypes::FINALIZATION,
                                                KernelsTypes::FINALIZATION_MULTI_TOKENS };
@@ -90,7 +110,7 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {
 
         int inputs_num = static_cast<int>(params.inputs.size());
         int outputs_num = 1;
-        if (kernel_type == KernelsTypes::SINGLE_TOKEN) {
+        if (kernel_type == KernelsTypes::SINGLE_TOKEN || kernel_type == KernelsTypes::SINGLE_TOKEN_GQA) {
             // SINGLE_TOKEN kernel doesn't use the subsequence_begins input
             inputs_num -= 1;
         } else if (kernel_type == KernelsTypes::FINALIZATION) {
@@ -221,6 +241,7 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
     jit.AddConstant(MakeJitConstant("HEAD_SIZE", config.head_size));
     jit.AddConstant(MakeJitConstant("HEADS_NUM", config.heads_num));
     jit.AddConstant(MakeJitConstant("KV_HEADS_NUM", config.kv_heads_num));
+    jit.AddConstant(MakeJitConstant("KV_HEADS_GROUP_SIZE", config.kv_group_size));
     jit.AddConstant(MakeJitConstant("SEQ_LEN_PARTITION_SIZE", seq_len_partition_size));
     jit.AddConstant(MakeJitConstant("PAGED_ATTENTION_BLOCK_SIZE", paged_attention_block_size));
     jit.AddConstant(MakeJitConstant("SUBGROUP_SIZE", subgroup_size));
@@ -236,8 +257,13 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
         jit.AddConstant(MakeJitConstant("ADJUSTED_HEAD_SIZE", params.conf.head_size));
     }
 
-    if (config.broadcast_axis != -1) {
-        jit.AddConstant(MakeJitConstant("BROADCAST_GROUP_SIZE", config.group_size));
+    if (kernel_idx == KernelsTypes::SINGLE_TOKEN_GQA) {
+        auto heads_per_wi = get_heads_per_wi(params);
+        jit.AddConstant(MakeJitConstant("HEADS_PER_WI", heads_per_wi));
+        jit.AddConstant(MakeJitConstant("ITERATIONS_PER_KV_HEADS_GROUP", CeilDiv(config.kv_group_size, heads_per_wi)));
+        jit.AddConstant(MakeJitConstant("HEADS_LEFTOVERS_NUM", config.kv_group_size % heads_per_wi));
+    } else {
+        jit.AddConstant(MakeJitConstant("HEADS_PER_WI", 1));
     }
 
     auto sdpa_stage = 0;
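Note: for the GQA kernel these three derived constants describe how one work-item walks its share of a KV-head group. A worked example under an assumed kv_group_size of 7 (illustrative only, not from the patch): get_heads_per_wi returns 4, so HEADS_PER_WI = 4, ITERATIONS_PER_KV_HEADS_GROUP = CeilDiv(7, 4) = 2, and HEADS_LEFTOVERS_NUM = 7 % 4 = 3, meaning the last iteration processes only three real heads. Every other kernel keeps the one-head-per-work-item layout via HEADS_PER_WI = 1.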
@@ -293,6 +319,16 @@ CommonDispatchData PagedAttentionSDPAKernelOpt::SetDefault(const pa_sdpa_params&
                               heads_num,
                               head_size * num_of_partitions * sg_scale };
         dispatch_data.lws = { 1, 1, head_size * sg_scale };
+    } else if (kernel_idx == KernelsTypes::SINGLE_TOKEN_GQA) {
+        auto sg_scale = get_sg_number_scale_factor(params, head_size, kernel_idx);
+
+        auto kv_groups = heads_num / params.conf.kv_group_size;
+        auto gqa_heads_num = kv_groups * CeilDiv(params.conf.kv_group_size, get_heads_per_wi(params));
+
+        dispatch_data.gws = { total_tokens,
+                              gqa_heads_num,
+                              head_size * num_of_partitions * sg_scale };
+        dispatch_data.lws = { 1, 1, head_size * sg_scale };
     } else if (kernel_idx == KernelsTypes::SCORES_CALCULATION) {
         const auto& past_lens = params.inputs[3];
         const auto subsequences_number = past_lens.Batch().v;
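Note: the GQA dispatch shrinks the Y dimension of the global work size from one work-item per query head to one per heads_per_wi chunk. Under assumed values of 28 query heads and kv_group_size = 7 (illustrative, not from the patch): kv_groups = 28 / 7 = 4 and gqa_heads_num = 4 * CeilDiv(7, 4) = 8, versus 28 for the plain SINGLE_TOKEN kernel, so each work-item reuses the same K/V data for up to four query heads.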
@@ -334,13 +370,30 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
         const auto scores_calc_only = prim_params.stage == PagedAttentionStage::PREFILL && has_scores_output;
         const auto multi_tokens_mode = prim_params.stage == PagedAttentionStage::MIXED;
 
-        auto dispatch_data1 = SetDefault(prim_params, KernelsTypes::SINGLE_TOKEN);
-        kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.global = dispatch_data1.gws;
-        kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.local = dispatch_data1.lws;
-        kd.kernels[KernelsTypes::SINGLE_TOKEN].skip_execution = multi_tokens_mode || scores_calc_only;
-
-        kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.global = dispatch_data1.gws;
-        kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.local = dispatch_data1.lws;
+        // Apply GQA optimization starting from a certain sequence length value
+        const auto min_gqa_sequence_len = 8 * seq_len_partition_size;
+        // Apply GQA only if there is a single subsequence in the request,
+        // as multiple subsequences might have significantly different lengths
+        const auto max_subsequences_num = 1;
+        const auto subsequences_num = prim_params.inputs[0].Batch().v;
+        const auto can_use_gqa_kernel = prim_params.conf.paged_attention_max_len >= static_cast<int64_t>(min_gqa_sequence_len) &&
+                                        subsequences_num <= max_subsequences_num &&
+                                        prim_params.conf.kv_group_size > 1 &&
+                                        !multi_tokens_mode &&
+                                        !scores_calc_only;
+
+        auto dispatch_data = SetDefault(prim_params, KernelsTypes::SINGLE_TOKEN_GQA);
+        kd.kernels[KernelsTypes::SINGLE_TOKEN_GQA].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::SINGLE_TOKEN_GQA].params.workGroups.local = dispatch_data.lws;
+        kd.kernels[KernelsTypes::SINGLE_TOKEN_GQA].skip_execution = multi_tokens_mode || scores_calc_only || !can_use_gqa_kernel;
+
+        dispatch_data = SetDefault(prim_params, KernelsTypes::SINGLE_TOKEN);
+        kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.local = dispatch_data.lws;
+        kd.kernels[KernelsTypes::SINGLE_TOKEN].skip_execution = multi_tokens_mode || scores_calc_only || can_use_gqa_kernel;
+
+        kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.local = dispatch_data.lws;
         kd.kernels[KernelsTypes::MULTI_TOKENS].skip_execution = !multi_tokens_mode || scores_calc_only;
 
         size_t partition_size = 0;
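Note: SINGLE_TOKEN and SINGLE_TOKEN_GQA are mutually exclusive at runtime: can_use_gqa_kernel appears negated in one skip_execution flag and plain in the other, so exactly one of the two runs outside of MIXED mode. With seq_len_partition_size = 256 (defined above), the min_gqa_sequence_len threshold evaluates to 8 * 256 = 2048 tokens, so shorter sequences keep the original single-token path.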
@@ -351,13 +404,13 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
         }
         const size_t num_of_partitions = CeilDiv(prim_params.conf.paged_attention_max_len, partition_size);
 
-        auto dispatch_data2 = SetDefault(prim_params, KernelsTypes::FINALIZATION);
-        kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.global = dispatch_data2.gws;
-        kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.local = dispatch_data2.lws;
+        dispatch_data = SetDefault(prim_params, KernelsTypes::FINALIZATION);
+        kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.local = dispatch_data.lws;
         kd.kernels[KernelsTypes::FINALIZATION].skip_execution = num_of_partitions == 1 || multi_tokens_mode || scores_calc_only;
 
-        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.global = dispatch_data2.gws;
-        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.local = dispatch_data2.lws;
+        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.local = dispatch_data.lws;
         kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].skip_execution = num_of_partitions == 1 || !multi_tokens_mode || scores_calc_only;
 
         ScalarDescriptor num_of_partitions_scalar;
@@ -369,7 +422,7 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
         kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.scalars[0] = num_of_partitions_scalar;
 
         if (has_scores_output) {
-            auto dispatch_data = SetDefault(prim_params, KernelsTypes::SCORES_CALCULATION);
+            dispatch_data = SetDefault(prim_params, KernelsTypes::SCORES_CALCULATION);
             kd.kernels[KernelsTypes::SCORES_CALCULATION].params.workGroups.global = dispatch_data.gws;
             kd.kernels[KernelsTypes::SCORES_CALCULATION].params.workGroups.local = dispatch_data.lws;
             kd.kernels[KernelsTypes::SCORES_CALCULATION].skip_execution = false;