Commit bf83d4e

[GPU] GQA optimization
1 parent 9fa105e commit bf83d4e

File tree: 8 files changed, +338 -109 lines

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp (+5, -1)

@@ -661,7 +661,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
     if (desc->heads_num != desc->kv_heads_num) {
         config.broadcast_axis = 1;
-        config.group_size = desc->heads_num / desc->kv_heads_num;
+        config.kv_group_size = desc->heads_num / desc->kv_heads_num;
     }
 
     if (desc->has_scores_output() && !is_dynamic) {
@@ -1009,6 +1009,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
            impl->use_micro_sdpa = true;
        }
 
+       std::cout << "use_micro=" << impl->use_micro_sdpa << " Q_HEADS=" << desc->heads_num << " KV_HEADS= " << desc->kv_heads_num << " KV-cache layouts=["
+                 << impl_param.get_input_layout(3).to_short_string() << ", "
+                 << impl_param.get_input_layout(4).to_short_string() << "]\n";
+
        return impl;
    }
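
Note: kv_group_size here is simply the query-to-KV head ratio. For an illustrative GQA configuration of 32 query heads over 8 KV heads (head counts assumed for the example, not taken from this commit), kv_group_size = 32 / 8 = 4, i.e. four query heads share each KV head and can reuse the same KV-cache data.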

src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp (+1, -1)

@@ -254,7 +254,7 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
     if (query_shape[num_heads_dim].is_static() && key_shape[num_heads_dim].is_static() && value_shape[num_heads_dim].is_static()) {
         if (query_shape[num_heads_dim].get_length() > key_shape[num_heads_dim].get_length()) {
             config.broadcast_axis = desc->input_k_transpose_order[num_heads_dim];
-            config.group_size = query_shape[num_heads_dim].get_length() / key_shape[num_heads_dim].get_length();
+            config.kv_group_size = query_shape[num_heads_dim].get_length() / key_shape[num_heads_dim].get_length();
         }
     }

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl (+181, -87)

Large diffs are not rendered by default.

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp (+130, -16)
@@ -13,22 +13,67 @@ namespace kernel_selector {
 namespace {
 enum KernelsTypes {
     SINGLE_TOKEN = 0,
+    SINGLE_TOKEN_GQA,
     MULTI_TOKENS,
     FINALIZATION,
     FINALIZATION_MULTI_TOKENS,
     SCORES_CALCULATION,
     TOTAL_KERNELS_NUM
 };
 
+static size_t get_heads_per_iteration(const pa_sdpa_params& params) {
+    int HEADS_PER_ITER = 0;
+    if (const auto env_var = std::getenv("HEADS_PER_ITER")) {
+        std::istringstream ss(env_var);
+        ss >> HEADS_PER_ITER;
+        static bool printed = false;
+        if (!printed) {
+            std::cout << "Set HEADS_PER_ITER=" << HEADS_PER_ITER << "\n";
+            printed = true;
+        }
+    }
+    if (HEADS_PER_ITER) {
+        return HEADS_PER_ITER;
+    }
+
+    if (params.conf.kv_group_size > 1) {
+        std::vector<size_t> preferable_heads_combined = {4, 3, 2};
+        for (const auto& heads_num : preferable_heads_combined) {
+            const auto leftovers = params.conf.kv_group_size % heads_num;
+            if (leftovers == 0 || heads_num - leftovers <= 1) {
+                return heads_num;
+            }
+        }
+    }
+
+    return 1;
+}
+
+constexpr size_t heads_per_iteration = 4;
 constexpr size_t subgroup_size = 16;
 constexpr size_t seq_len_partition_size = 256;
 constexpr size_t paged_attention_block_size = 16;
 constexpr Datatype softmax_acc_dt = Datatype::F32;
 
 size_t get_sg_number_scale_factor(const pa_sdpa_params& params, size_t head_size, size_t kernel_type) {
+    int SG_SCALE = 0;
+    if (const auto env_var = std::getenv("SG_SCALE")) {
+        std::istringstream ss(env_var);
+        ss >> SG_SCALE;
+        static bool printed = false;
+        if (!printed) {
+            std::cout << "Set SG_SCALE=" << SG_SCALE << "\n";
+            printed = true;
+        }
+    }
+
+    if (SG_SCALE != 0)
+        return SG_SCALE;
+
     if (params.conf.is_kv_compressed) {
         const size_t optimal_scale_factor = 2;
         if (kernel_type == KernelsTypes::SINGLE_TOKEN ||
+            kernel_type == KernelsTypes::SINGLE_TOKEN_GQA ||
             kernel_type == KernelsTypes::MULTI_TOKENS) {
             if (head_size * optimal_scale_factor <= params.engineInfo.maxWorkGroupSize) {
                 return optimal_scale_factor;
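
As a quick sanity check of the head-grouping heuristic above, the following standalone sketch (not part of the commit; pick_heads_per_wi is a hypothetical stand-in for get_heads_per_iteration without the pa_sdpa_params dependency and without the HEADS_PER_ITER override) reproduces the selection rule and prints the per-work-item head count, iteration count and leftovers that the JIT constants further below are derived from:

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical helper mirroring the heuristic in get_heads_per_iteration().
static size_t pick_heads_per_wi(size_t kv_group_size) {
    if (kv_group_size > 1) {
        const std::vector<size_t> preferable_heads = {4, 3, 2};
        for (const auto heads : preferable_heads) {
            const auto leftovers = kv_group_size % heads;
            // Accept a candidate if it divides the group evenly, or if the
            // final partial iteration wastes at most one head slot.
            if (leftovers == 0 || heads - leftovers <= 1) {
                return heads;
            }
        }
    }
    return 1;
}

int main() {
    // kv_group_size = query heads / KV heads (e.g. 32/8 = 4, 28/4 = 7).
    for (size_t group : {2, 4, 5, 6, 7, 8}) {
        const size_t per_wi = pick_heads_per_wi(group);
        const size_t iterations = (group + per_wi - 1) / per_wi;  // CeilDiv
        std::cout << "kv_group_size=" << group
                  << " -> HEADS_PER_WI=" << per_wi
                  << ", ITERATIONS_PER_KV_HEADS_GROUP=" << iterations
                  << ", HEADS_LEFTOVERS_NUM=" << group % per_wi << "\n";
    }
    return 0;
}

For example, kv_group_size = 6 picks 3 heads per work item (it divides evenly), while kv_group_size = 7 picks 4 (the second iteration then handles only the 3 leftover heads).
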
@@ -45,6 +90,8 @@ static std::string GetKernelName(std::string base_name, KernelsTypes type) {
 
     if (type == KernelsTypes::SINGLE_TOKEN) {
         kernel_name += "_single_token";
+    } else if (type == KernelsTypes::SINGLE_TOKEN_GQA) {
+        kernel_name += "_single_token_gqa";
     } else if (type == KernelsTypes::MULTI_TOKENS) {
         kernel_name += "_multi_tokens_seq";
     } else if (type == KernelsTypes::FINALIZATION) {
@@ -65,6 +112,7 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {
 
     const auto& params = static_cast<const pa_sdpa_params&>(p);
     std::vector<KernelsTypes> kernels_type = { KernelsTypes::SINGLE_TOKEN,
+                                               KernelsTypes::SINGLE_TOKEN_GQA,
                                                KernelsTypes::MULTI_TOKENS,
                                                KernelsTypes::FINALIZATION,
                                                KernelsTypes::FINALIZATION_MULTI_TOKENS };
@@ -90,7 +138,7 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {
 
     int inputs_num = static_cast<int>(params.inputs.size());
     int outputs_num = 1;
-    if (kernel_type == KernelsTypes::SINGLE_TOKEN) {
+    if (kernel_type == KernelsTypes::SINGLE_TOKEN || kernel_type == KernelsTypes::SINGLE_TOKEN_GQA) {
         // SINGLE_TOKEN kernel doesn't use the subsequence_begins input
         inputs_num -= 1;
     } else if (kernel_type == KernelsTypes::FINALIZATION) {
@@ -221,6 +269,7 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
     jit.AddConstant(MakeJitConstant("HEAD_SIZE", config.head_size));
     jit.AddConstant(MakeJitConstant("HEADS_NUM", config.heads_num));
     jit.AddConstant(MakeJitConstant("KV_HEADS_NUM", config.kv_heads_num));
+    jit.AddConstant(MakeJitConstant("KV_HEADS_GROUP_SIZE", config.kv_group_size));
     jit.AddConstant(MakeJitConstant("SEQ_LEN_PARTITION_SIZE", seq_len_partition_size));
     jit.AddConstant(MakeJitConstant("PAGED_ATTENTION_BLOCK_SIZE", paged_attention_block_size));
     jit.AddConstant(MakeJitConstant("SUBGROUP_SIZE", subgroup_size));
@@ -236,8 +285,21 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
         jit.AddConstant(MakeJitConstant("ADJUSTED_HEAD_SIZE", params.conf.head_size));
     }
 
-    if (config.broadcast_axis != -1) {
-        jit.AddConstant(MakeJitConstant("BROADCAST_GROUP_SIZE", config.group_size));
+    if (kernel_idx == KernelsTypes::SINGLE_TOKEN_GQA) {
+        jit.AddConstant(MakeJitConstant("HEADS_PER_WI", get_heads_per_iteration(params)));
+        jit.AddConstant(MakeJitConstant("ITERATIONS_PER_KV_HEADS_GROUP", CeilDiv(config.kv_group_size, get_heads_per_iteration(params))));
+        jit.AddConstant(MakeJitConstant("HEADS_LEFTOVERS_NUM", config.kv_group_size % get_heads_per_iteration(params)));
+
+        static bool print_once = true;
+        if (print_once) {
+            std::cout << "KV_HEADS_GROUP_SIZE=" << config.kv_group_size << "\n";
+            std::cout << "HEADS_PER_WI=" << get_heads_per_iteration(params) << "\n";
+            std::cout << "ITERATIONS_PER_KV_HEADS_GROUP=" << CeilDiv(config.kv_group_size, get_heads_per_iteration(params)) << "\n";
+            std::cout << "HEADS_LEFTOVERS_NUM=" << config.kv_group_size % get_heads_per_iteration(params) << "\n";
+            print_once = false;
+        }
+    } else {
+        jit.AddConstant(MakeJitConstant("HEADS_PER_WI", 1));
     }
 
     auto sdpa_stage = 0;
@@ -293,6 +355,16 @@ CommonDispatchData PagedAttentionSDPAKernelOpt::SetDefault(const pa_sdpa_params&
                               heads_num,
                               head_size * num_of_partitions * sg_scale };
         dispatch_data.lws = { 1, 1, head_size * sg_scale };
+    } else if (kernel_idx == KernelsTypes::SINGLE_TOKEN_GQA) {
+        auto sg_scale = get_sg_number_scale_factor(params, head_size, kernel_idx);
+
+        auto kv_groups = heads_num / params.conf.kv_group_size;
+        auto gqa_heads_num = kv_groups * CeilDiv(params.conf.kv_group_size, get_heads_per_iteration(params));
+
+        dispatch_data.gws = { total_tokens,
+                              gqa_heads_num,
+                              head_size * num_of_partitions * sg_scale };
+        dispatch_data.lws = { 1, 1, head_size * sg_scale };
     } else if (kernel_idx == KernelsTypes::SCORES_CALCULATION) {
         const auto& past_lens = params.inputs[3];
         const auto subsequences_number = past_lens.Batch().v;
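
For intuition about what the SINGLE_TOKEN_GQA dispatch above changes, here is a small standalone sketch (not from the commit; the head counts are illustrative assumptions) comparing the Y dimension of the global work size for the plain SINGLE_TOKEN kernel with the GQA variant:

#include <cstddef>
#include <iostream>

// Local stand-in for the kernel selector's CeilDiv, for a self-contained example.
static size_t ceil_div(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
    // Illustrative configuration: 32 query heads sharing 8 KV heads.
    const size_t heads_num = 32;
    const size_t kv_heads_num = 8;
    const size_t kv_group_size = heads_num / kv_heads_num;  // 4
    const size_t heads_per_wi = 4;                          // as chosen by get_heads_per_iteration()

    // SINGLE_TOKEN: one work group per query head along the Y axis.
    const size_t single_token_gws_y = heads_num;            // 32

    // SINGLE_TOKEN_GQA: one work group per (KV head group, iteration) pair.
    const size_t kv_groups = heads_num / kv_group_size;     // 8
    const size_t gqa_gws_y = kv_groups * ceil_div(kv_group_size, heads_per_wi);  // 8

    std::cout << "SINGLE_TOKEN gws.y=" << single_token_gws_y
              << ", SINGLE_TOKEN_GQA gws.y=" << gqa_gws_y << "\n";
    return 0;
}

Each GQA work group thus covers HEADS_PER_WI query heads that attend over the same KV head, so KV-cache data loaded for one group can presumably be reused across those heads; the reworked pa_sdpa_opt.cl that implements this is not rendered in this diff.
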
@@ -322,6 +394,24 @@ CommonDispatchData PagedAttentionSDPAKernelOpt::SetDefault(const pa_sdpa_params&
     return dispatch_data;
 }
 
+static size_t get_gqa_seq_len() {
+    int REQ_SEQ_LEN = 0;
+    if (const auto env_var = std::getenv("REQ_SEQ_LEN")) {
+        std::istringstream ss(env_var);
+        ss >> REQ_SEQ_LEN;
+        static bool printed = false;
+        if (!printed) {
+            std::cout << "Set REQ_SEQ_LEN=" << REQ_SEQ_LEN << "\n";
+            printed = true;
+        }
+    }
+
+    if (REQ_SEQ_LEN)
+        return REQ_SEQ_LEN;
+
+    return 8 * seq_len_partition_size;
+}
+
 void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) const {
     kd.update_dispatch_data_func = [](const Params& params, KernelData& kd) {
         const auto& prim_params = static_cast<const pa_sdpa_params&>(params);
@@ -334,13 +424,37 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
         const auto scores_calc_only = prim_params.stage == PagedAttentionStage::PREFILL && has_scores_output;
         const auto multi_tokens_mode = prim_params.stage == PagedAttentionStage::MIXED;
 
-        auto dispatch_data1 = SetDefault(prim_params, KernelsTypes::SINGLE_TOKEN);
-        kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.global = dispatch_data1.gws;
-        kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.local = dispatch_data1.lws;
-        kd.kernels[KernelsTypes::SINGLE_TOKEN].skip_execution = multi_tokens_mode || scores_calc_only;
-
-        kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.global = dispatch_data1.gws;
-        kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.local = dispatch_data1.lws;
+        // Apply GQA optimization starting from a certain sequence length value
+        const auto min_gqa_sequence_len = get_gqa_seq_len();
+        // Apply GQA only if there is a single subsequence in the request,
+        // as multiple subsequences might have significantly different lengths
+        const auto max_subsequences_num = 1;
+        const auto subsequences_num = prim_params.inputs[0].Batch().v;
+        const auto can_use_gqa_kernel = prim_params.conf.paged_attention_max_len >= static_cast<int64_t>(min_gqa_sequence_len) &&
+                                        subsequences_num <= max_subsequences_num &&
+                                        prim_params.conf.kv_group_size > 1 &&
+                                        !multi_tokens_mode &&
+                                        !scores_calc_only;
+        // std::cout << "stage=" << prim_params.stage
+        //           << " paged_attention_max_len=" << prim_params.conf.paged_attention_max_len
+        //           << " (" << (prim_params.conf.paged_attention_max_len >= static_cast<int64_t>(required_gqa_sequence_len))
+        //           << ", " << (subsequences_num <= max_subsequences_num)
+        //           << ", " << (prim_params.conf.kv_group_size > 1)
+        //           << ", " << (!multi_tokens_mode)
+        //           << ", " << (!scores_calc_only) << ")\n";
+
+        auto dispatch_data = SetDefault(prim_params, KernelsTypes::SINGLE_TOKEN_GQA);
+        kd.kernels[KernelsTypes::SINGLE_TOKEN_GQA].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::SINGLE_TOKEN_GQA].params.workGroups.local = dispatch_data.lws;
+        kd.kernels[KernelsTypes::SINGLE_TOKEN_GQA].skip_execution = multi_tokens_mode || scores_calc_only || !can_use_gqa_kernel;
+
+        dispatch_data = SetDefault(prim_params, KernelsTypes::SINGLE_TOKEN);
+        kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.local = dispatch_data.lws;
+        kd.kernels[KernelsTypes::SINGLE_TOKEN].skip_execution = multi_tokens_mode || scores_calc_only || can_use_gqa_kernel;
+
+        kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.local = dispatch_data.lws;
         kd.kernels[KernelsTypes::MULTI_TOKENS].skip_execution = !multi_tokens_mode || scores_calc_only;
 
         size_t partition_size = 0;
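
With the defaults above, get_gqa_seq_len() returns 8 * seq_len_partition_size = 8 * 256 = 2048, so the GQA kernel only takes over once paged_attention_max_len reaches 2048 tokens, and then only for generate-stage requests with a single subsequence and kv_group_size > 1; everything else keeps dispatching the existing SINGLE_TOKEN kernel. The REQ_SEQ_LEN, HEADS_PER_ITER, SG_SCALE and DISABLE_MICRO environment variables introduced in this commit appear to be debug knobs for tuning that threshold, the per-work-item head grouping, the subgroup scale factor and the micro-SDPA kernel selection during experiments.
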
@@ -351,13 +465,13 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
         }
         const size_t num_of_partitions = CeilDiv(prim_params.conf.paged_attention_max_len, partition_size);
 
-        auto dispatch_data2 = SetDefault(prim_params, KernelsTypes::FINALIZATION);
-        kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.global = dispatch_data2.gws;
-        kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.local = dispatch_data2.lws;
+        dispatch_data = SetDefault(prim_params, KernelsTypes::FINALIZATION);
+        kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.local = dispatch_data.lws;
         kd.kernels[KernelsTypes::FINALIZATION].skip_execution = num_of_partitions == 1 || multi_tokens_mode || scores_calc_only;
 
-        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.global = dispatch_data2.gws;
-        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.local = dispatch_data2.lws;
+        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.local = dispatch_data.lws;
         kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].skip_execution = num_of_partitions == 1 || !multi_tokens_mode || scores_calc_only;
 
         ScalarDescriptor num_of_partitions_scalar;
@@ -369,7 +483,7 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
         kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.scalars[0] = num_of_partitions_scalar;
 
         if (has_scores_output) {
-            auto dispatch_data = SetDefault(prim_params, KernelsTypes::SCORES_CALCULATION);
+            dispatch_data = SetDefault(prim_params, KernelsTypes::SCORES_CALCULATION);
             kd.kernels[KernelsTypes::SCORES_CALCULATION].params.workGroups.global = dispatch_data.gws;
             kd.kernels[KernelsTypes::SCORES_CALCULATION].params.workGroups.local = dispatch_data.lws;
             kd.kernels[KernelsTypes::SCORES_CALCULATION].skip_execution = false;

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.cpp (+2, -2)

@@ -70,10 +70,10 @@ JitConstants SDPAKernelBase::GetJitConstants(const sdpa_params& params) const {
     auto jit = MakeBaseParamsJitConstants(params);
 
     if (params.conf.broadcast_axis != -1) {
-        jit.AddConstant(MakeJitConstant("BROADCAST_GROUP_SIZE", params.conf.group_size));
+        jit.AddConstant(MakeJitConstant("BROADCAST_GROUP_SIZE", params.conf.kv_group_size));
         jit.AddConstant(MakeJitConstant("DO_BROADCAST_KEY_VALUE", GetBroadcastInputStr(params.inputs[0].GetDims().size(),
                                                                                        params.conf.broadcast_axis,
-                                                                                       params.conf.group_size)));
+                                                                                       params.conf.kv_group_size)));
     } else {
         jit.AddConstant(MakeJitConstant("BROADCAST_GROUP_SIZE", 1));
     }

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h (+1, -1)

@@ -83,7 +83,7 @@ struct sdpa_configuration {
     int64_t kv_heads_num = -1;
 
     // GQA configuration
-    int64_t group_size = -1;
+    int64_t kv_group_size = 1;
     int64_t broadcast_axis = -1;
 
     bool is_causal = false;

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_selector.cpp (+12, -1)

@@ -17,7 +17,18 @@ sdpa_kernel_selector::sdpa_kernel_selector() {
     Attach<SDPAKernelOpt>();
     Attach<SDPAKernelRef>();
 #ifdef ENABLE_ONEDNN_FOR_GPU
-    Attach<SDPAKernelMicro>();
+    int DISABLE_MICRO = 0;
+    if (const auto env_var = std::getenv("DISABLE_MICRO")) {
+        std::istringstream ss(env_var);
+        ss >> DISABLE_MICRO;
+        static bool printed = false;
+        if (!printed) {
+            std::cout << "Set DISABLE_MICRO=" << DISABLE_MICRO << "\n";
+            printed = true;
+        }
+    }
+    if (!DISABLE_MICRO)
+        Attach<SDPAKernelMicro>();
 #endif
 }

src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp (+6, -0)

@@ -51,6 +51,12 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared
     prim.kv_heads_num = kv_heads_num;
     prim.heads_num = heads_num;
 
+    static bool print_once = true;
+    if (print_once) {
+        std::cout << "PA config: heads=" << heads_num << " kv_heads=" << kv_heads_num << "\n";
+        print_once = false;
+    }
+
     const size_t scale_idx = 9;
     const size_t sliding_window_idx = 10;
     const size_t alibi_idx = 11;
