@@ -13,22 +13,67 @@ namespace kernel_selector {
13
13
namespace {
14
14
// Kernel variants used to implement paged-attention SDPA. At runtime only the
// applicable subset is executed; the rest are skipped via `skip_execution`
// (see GetUpdateDispatchDataFunc).
enum KernelsTypes {
    SINGLE_TOKEN = 0,           // generation-stage kernel, one query head per work-group slice
    SINGLE_TOKEN_GQA,           // generation-stage kernel processing several heads of one KV-heads
                                // group per iteration; selected when kv_group_size > 1 and the
                                // sequence is long enough (see can_use_gqa_kernel)
    MULTI_TOKENS,               // kernel for MIXED stage (multiple new tokens per subsequence)
    FINALIZATION,               // reduces per-partition partial results of the single-token kernels;
                                // skipped when num_of_partitions == 1
    FINALIZATION_MULTI_TOKENS,  // same reduction for the MULTI_TOKENS kernel's results
    SCORES_CALCULATION,         // optional kernel producing the attention-scores output
    TOTAL_KERNELS_NUM           // number of kernel types; must stay last
};
22
23
24
+ static size_t get_heads_per_iteration (const pa_sdpa_params& params) {
25
+ int HEADS_PER_ITER = 0 ;
26
+ if (const auto env_var = std::getenv (" HEADS_PER_ITER" )) {
27
+ std::istringstream ss (env_var);
28
+ ss >> HEADS_PER_ITER;
29
+ static bool printed = false ;
30
+ if (!printed) {
31
+ std::cout << " Set HEADS_PER_ITER=" << HEADS_PER_ITER << " \n " ;
32
+ printed = true ;
33
+ }
34
+ }
35
+ if (HEADS_PER_ITER) {
36
+ return HEADS_PER_ITER;
37
+ }
38
+
39
+ if (params.conf .kv_group_size > 1 ) {
40
+ std::vector<size_t > preferable_heads_combined = {4 , 3 , 2 };
41
+ for (const auto & heads_num : preferable_heads_combined) {
42
+ const auto leftovers = params.conf .kv_group_size % heads_num;
43
+ if (leftovers == 0 || heads_num - leftovers <= 1 ) {
44
+ return heads_num;
45
+ }
46
+ }
47
+ }
48
+
49
+ return 1 ;
50
+ }
51
+
52
// NOTE(review): the GQA path computes its per-iteration head count at runtime
// via get_heads_per_iteration(); this compile-time constant looks superseded —
// confirm it has no remaining users before removing.
constexpr size_t heads_per_iteration = 4;
constexpr size_t subgroup_size = 16;               // HW subgroup (SIMD) width, exported to JIT as SUBGROUP_SIZE
constexpr size_t seq_len_partition_size = 256;     // tokens per sequence-length partition; used to compute num_of_partitions
constexpr size_t paged_attention_block_size = 16;  // KV-cache block granularity, exported as PAGED_ATTENTION_BLOCK_SIZE
constexpr Datatype softmax_acc_dt = Datatype::F32; // softmax accumulation datatype
27
57
28
58
size_t get_sg_number_scale_factor (const pa_sdpa_params& params, size_t head_size, size_t kernel_type) {
59
+ int SG_SCALE = 0 ;
60
+ if (const auto env_var = std::getenv (" SG_SCALE" )) {
61
+ std::istringstream ss (env_var);
62
+ ss >> SG_SCALE;
63
+ static bool printed = false ;
64
+ if (!printed) {
65
+ std::cout << " Set SG_SCALE=" << SG_SCALE << " \n " ;
66
+ printed = true ;
67
+ }
68
+ }
69
+
70
+ if (SG_SCALE != 0 )
71
+ return SG_SCALE;
72
+
29
73
if (params.conf .is_kv_compressed ) {
30
74
const size_t optimal_scale_factor = 2 ;
31
75
if (kernel_type == KernelsTypes::SINGLE_TOKEN ||
76
+ kernel_type == KernelsTypes::SINGLE_TOKEN_GQA ||
32
77
kernel_type == KernelsTypes::MULTI_TOKENS) {
33
78
if (head_size * optimal_scale_factor <= params.engineInfo .maxWorkGroupSize ) {
34
79
return optimal_scale_factor;
@@ -45,6 +90,8 @@ static std::string GetKernelName(std::string base_name, KernelsTypes type) {
45
90
46
91
if (type == KernelsTypes::SINGLE_TOKEN) {
47
92
kernel_name += " _single_token" ;
93
+ } else if (type == KernelsTypes::SINGLE_TOKEN_GQA) {
94
+ kernel_name += " _single_token_gqa" ;
48
95
} else if (type == KernelsTypes::MULTI_TOKENS) {
49
96
kernel_name += " _multi_tokens_seq" ;
50
97
} else if (type == KernelsTypes::FINALIZATION) {
@@ -65,6 +112,7 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {
65
112
66
113
const auto & params = static_cast <const pa_sdpa_params&>(p);
67
114
std::vector<KernelsTypes> kernels_type = { KernelsTypes::SINGLE_TOKEN,
115
+ KernelsTypes::SINGLE_TOKEN_GQA,
68
116
KernelsTypes::MULTI_TOKENS,
69
117
KernelsTypes::FINALIZATION,
70
118
KernelsTypes::FINALIZATION_MULTI_TOKENS };
@@ -90,7 +138,7 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {
90
138
91
139
int inputs_num = static_cast <int >(params.inputs .size ());
92
140
int outputs_num = 1 ;
93
- if (kernel_type == KernelsTypes::SINGLE_TOKEN) {
141
+ if (kernel_type == KernelsTypes::SINGLE_TOKEN || kernel_type == KernelsTypes::SINGLE_TOKEN_GQA ) {
94
142
// SINGLE_TOKEN kernel doesn't use the subsequence_begins input
95
143
inputs_num -= 1 ;
96
144
} else if (kernel_type == KernelsTypes::FINALIZATION) {
@@ -221,6 +269,7 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
221
269
jit.AddConstant (MakeJitConstant (" HEAD_SIZE" , config.head_size ));
222
270
jit.AddConstant (MakeJitConstant (" HEADS_NUM" , config.heads_num ));
223
271
jit.AddConstant (MakeJitConstant (" KV_HEADS_NUM" , config.kv_heads_num ));
272
+ jit.AddConstant (MakeJitConstant (" KV_HEADS_GROUP_SIZE" , config.kv_group_size ));
224
273
jit.AddConstant (MakeJitConstant (" SEQ_LEN_PARTITION_SIZE" , seq_len_partition_size));
225
274
jit.AddConstant (MakeJitConstant (" PAGED_ATTENTION_BLOCK_SIZE" , paged_attention_block_size));
226
275
jit.AddConstant (MakeJitConstant (" SUBGROUP_SIZE" , subgroup_size));
@@ -236,8 +285,21 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
236
285
jit.AddConstant (MakeJitConstant (" ADJUSTED_HEAD_SIZE" , params.conf .head_size ));
237
286
}
238
287
239
- if (config.broadcast_axis != -1 ) {
240
- jit.AddConstant (MakeJitConstant (" BROADCAST_GROUP_SIZE" , config.group_size ));
288
+ if (kernel_idx == KernelsTypes::SINGLE_TOKEN_GQA) {
289
+ jit.AddConstant (MakeJitConstant (" HEADS_PER_WI" , get_heads_per_iteration (params)));
290
+ jit.AddConstant (MakeJitConstant (" ITERATIONS_PER_KV_HEADS_GROUP" , CeilDiv (config.kv_group_size , get_heads_per_iteration (params))));
291
+ jit.AddConstant (MakeJitConstant (" HEADS_LEFTOVERS_NUM" , config.kv_group_size % get_heads_per_iteration (params)));
292
+
293
+ static bool print_once = true ;
294
+ if (print_once) {
295
+ std::cout << " KV_HEADS_GROUP_SIZE=" << config.kv_group_size << " \n " ;
296
+ std::cout << " HEADS_PER_WI=" << get_heads_per_iteration (params) << " \n " ;
297
+ std::cout << " ITERATIONS_PER_KV_HEADS_GROUP=" << CeilDiv (config.kv_group_size , get_heads_per_iteration (params)) << " \n " ;
298
+ std::cout << " HEADS_LEFTOVERS_NUM=" << config.kv_group_size % get_heads_per_iteration (params) << " \n " ;
299
+ print_once = false ;
300
+ }
301
+ } else {
302
+ jit.AddConstant (MakeJitConstant (" HEADS_PER_WI" , 1 ));
241
303
}
242
304
243
305
auto sdpa_stage = 0 ;
@@ -293,6 +355,16 @@ CommonDispatchData PagedAttentionSDPAKernelOpt::SetDefault(const pa_sdpa_params&
293
355
heads_num,
294
356
head_size * num_of_partitions * sg_scale };
295
357
dispatch_data.lws = { 1 , 1 , head_size * sg_scale };
358
+ } else if (kernel_idx == KernelsTypes::SINGLE_TOKEN_GQA) {
359
+ auto sg_scale = get_sg_number_scale_factor (params, head_size, kernel_idx);
360
+
361
+ auto kv_groups = heads_num / params.conf .kv_group_size ;
362
+ auto gqa_heads_num = kv_groups * CeilDiv (params.conf .kv_group_size , get_heads_per_iteration (params));
363
+
364
+ dispatch_data.gws = { total_tokens,
365
+ gqa_heads_num,
366
+ head_size * num_of_partitions * sg_scale };
367
+ dispatch_data.lws = { 1 , 1 , head_size * sg_scale };
296
368
} else if (kernel_idx == KernelsTypes::SCORES_CALCULATION) {
297
369
const auto & past_lens = params.inputs [3 ];
298
370
const auto subsequences_number = past_lens.Batch ().v ;
@@ -322,6 +394,24 @@ CommonDispatchData PagedAttentionSDPAKernelOpt::SetDefault(const pa_sdpa_params&
322
394
return dispatch_data;
323
395
}
324
396
397
+ static size_t get_gqa_seq_len () {
398
+ int REQ_SEQ_LEN = 0 ;
399
+ if (const auto env_var = std::getenv (" REQ_SEQ_LEN" )) {
400
+ std::istringstream ss (env_var);
401
+ ss >> REQ_SEQ_LEN;
402
+ static bool printed = false ;
403
+ if (!printed) {
404
+ std::cout << " Set REQ_SEQ_LEN=" << REQ_SEQ_LEN << " \n " ;
405
+ printed = true ;
406
+ }
407
+ }
408
+
409
+ if (REQ_SEQ_LEN)
410
+ return REQ_SEQ_LEN;
411
+
412
+ return 8 * seq_len_partition_size;
413
+ }
414
+
325
415
void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc (KernelData& kd) const {
326
416
kd.update_dispatch_data_func = [](const Params& params, KernelData& kd) {
327
417
const auto & prim_params = static_cast <const pa_sdpa_params&>(params);
@@ -334,13 +424,37 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
334
424
const auto scores_calc_only = prim_params.stage == PagedAttentionStage::PREFILL && has_scores_output;
335
425
const auto multi_tokens_mode = prim_params.stage == PagedAttentionStage::MIXED;
336
426
337
- auto dispatch_data1 = SetDefault (prim_params, KernelsTypes::SINGLE_TOKEN);
338
- kd.kernels [KernelsTypes::SINGLE_TOKEN].params .workGroups .global = dispatch_data1.gws ;
339
- kd.kernels [KernelsTypes::SINGLE_TOKEN].params .workGroups .local = dispatch_data1.lws ;
340
- kd.kernels [KernelsTypes::SINGLE_TOKEN].skip_execution = multi_tokens_mode || scores_calc_only;
341
-
342
- kd.kernels [KernelsTypes::MULTI_TOKENS].params .workGroups .global = dispatch_data1.gws ;
343
- kd.kernels [KernelsTypes::MULTI_TOKENS].params .workGroups .local = dispatch_data1.lws ;
427
+ // Apply GQA optimization starting from a certain sequence length value
428
+ const auto min_gqa_sequence_len = get_gqa_seq_len ();
429
+ // Apply GQA only if there is a single subsequence in the request,
430
+ // as multiple subsequences might have significantly different lengths
431
+ const auto max_subsequences_num = 1 ;
432
+ const auto subsequences_num = prim_params.inputs [0 ].Batch ().v ;
433
+ const auto can_use_gqa_kernel = prim_params.conf .paged_attention_max_len >= static_cast <int64_t >(min_gqa_sequence_len) &&
434
+ subsequences_num <= max_subsequences_num &&
435
+ prim_params.conf .kv_group_size > 1 &&
436
+ !multi_tokens_mode &&
437
+ !scores_calc_only;
438
+ // std::cout << "stage=" << prim_params.stage
439
+ // << " paged_attention_max_len=" << prim_params.conf.paged_attention_max_len
440
+ // << " (" << (prim_params.conf.paged_attention_max_len >= static_cast<int64_t>(required_gqa_sequence_len))
441
+ // << ", " << (subsequences_num <= max_subsequences_num)
442
+ // << ", " << (prim_params.conf.kv_group_size > 1)
443
+ // << ", " << (!multi_tokens_mode)
444
+ // << ", " << (!scores_calc_only) << ")\n";
445
+
446
+ auto dispatch_data = SetDefault (prim_params, KernelsTypes::SINGLE_TOKEN_GQA);
447
+ kd.kernels [KernelsTypes::SINGLE_TOKEN_GQA].params .workGroups .global = dispatch_data.gws ;
448
+ kd.kernels [KernelsTypes::SINGLE_TOKEN_GQA].params .workGroups .local = dispatch_data.lws ;
449
+ kd.kernels [KernelsTypes::SINGLE_TOKEN_GQA].skip_execution = multi_tokens_mode || scores_calc_only || !can_use_gqa_kernel;
450
+
451
+ dispatch_data = SetDefault (prim_params, KernelsTypes::SINGLE_TOKEN);
452
+ kd.kernels [KernelsTypes::SINGLE_TOKEN].params .workGroups .global = dispatch_data.gws ;
453
+ kd.kernels [KernelsTypes::SINGLE_TOKEN].params .workGroups .local = dispatch_data.lws ;
454
+ kd.kernels [KernelsTypes::SINGLE_TOKEN].skip_execution = multi_tokens_mode || scores_calc_only || can_use_gqa_kernel;
455
+
456
+ kd.kernels [KernelsTypes::MULTI_TOKENS].params .workGroups .global = dispatch_data.gws ;
457
+ kd.kernels [KernelsTypes::MULTI_TOKENS].params .workGroups .local = dispatch_data.lws ;
344
458
kd.kernels [KernelsTypes::MULTI_TOKENS].skip_execution = !multi_tokens_mode || scores_calc_only;
345
459
346
460
size_t partition_size = 0 ;
@@ -351,13 +465,13 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
351
465
}
352
466
const size_t num_of_partitions = CeilDiv (prim_params.conf .paged_attention_max_len , partition_size);
353
467
354
- auto dispatch_data2 = SetDefault (prim_params, KernelsTypes::FINALIZATION);
355
- kd.kernels [KernelsTypes::FINALIZATION].params .workGroups .global = dispatch_data2 .gws ;
356
- kd.kernels [KernelsTypes::FINALIZATION].params .workGroups .local = dispatch_data2 .lws ;
468
+ dispatch_data = SetDefault (prim_params, KernelsTypes::FINALIZATION);
469
+ kd.kernels [KernelsTypes::FINALIZATION].params .workGroups .global = dispatch_data .gws ;
470
+ kd.kernels [KernelsTypes::FINALIZATION].params .workGroups .local = dispatch_data .lws ;
357
471
kd.kernels [KernelsTypes::FINALIZATION].skip_execution = num_of_partitions == 1 || multi_tokens_mode || scores_calc_only;
358
472
359
- kd.kernels [KernelsTypes::FINALIZATION_MULTI_TOKENS].params .workGroups .global = dispatch_data2 .gws ;
360
- kd.kernels [KernelsTypes::FINALIZATION_MULTI_TOKENS].params .workGroups .local = dispatch_data2 .lws ;
473
+ kd.kernels [KernelsTypes::FINALIZATION_MULTI_TOKENS].params .workGroups .global = dispatch_data .gws ;
474
+ kd.kernels [KernelsTypes::FINALIZATION_MULTI_TOKENS].params .workGroups .local = dispatch_data .lws ;
361
475
kd.kernels [KernelsTypes::FINALIZATION_MULTI_TOKENS].skip_execution = num_of_partitions == 1 || !multi_tokens_mode || scores_calc_only;
362
476
363
477
ScalarDescriptor num_of_partitions_scalar;
@@ -369,7 +483,7 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
369
483
kd.kernels [KernelsTypes::FINALIZATION_MULTI_TOKENS].params .scalars [0 ] = num_of_partitions_scalar;
370
484
371
485
if (has_scores_output) {
372
- auto dispatch_data = SetDefault (prim_params, KernelsTypes::SCORES_CALCULATION);
486
+ dispatch_data = SetDefault (prim_params, KernelsTypes::SCORES_CALCULATION);
373
487
kd.kernels [KernelsTypes::SCORES_CALCULATION].params .workGroups .global = dispatch_data.gws ;
374
488
kd.kernels [KernelsTypes::SCORES_CALCULATION].params .workGroups .local = dispatch_data.lws ;
375
489
kd.kernels [KernelsTypes::SCORES_CALCULATION].skip_execution = false ;
0 commit comments