Commit 66cd717

[GPU] GQA optimization of PagedAttention OCL kernel for long sequences

1 parent c1e81b0

7 files changed: +255 -105 lines
src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp (+1 -1)

@@ -660,7 +660,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
 
         if (desc->heads_num != desc->kv_heads_num) {
             config.broadcast_axis = 1;
-            config.group_size = desc->heads_num / desc->kv_heads_num;
+            config.kv_group_size = desc->heads_num / desc->kv_heads_num;
         }
 
         if (desc->has_scores_output() && !is_dynamic) {
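Here kv_group_size (renamed from group_size) is the number of query heads that share one KV head. For example, a model with 32 query heads and 8 KV heads (illustrative numbers, not taken from this commit) gets kv_group_size = 32 / 8 = 4: each KV head serves four query heads.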

src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp (+1 -1)

@@ -252,7 +252,7 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
     if (query_shape[num_heads_dim].is_static() && key_shape[num_heads_dim].is_static() && value_shape[num_heads_dim].is_static()) {
         if (query_shape[num_heads_dim].get_length() > key_shape[num_heads_dim].get_length()) {
             config.broadcast_axis = desc->input_k_transpose_order[num_heads_dim];
-            config.group_size = query_shape[num_heads_dim].get_length() / key_shape[num_heads_dim].get_length();
+            config.kv_group_size = query_shape[num_heads_dim].get_length() / key_shape[num_heads_dim].get_length();
         }
     }

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl (+170 -83)

Large diffs are not rendered by default.
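The kernel changes themselves are not shown in this view. Judging from the new JIT constants wired up below (HEADS_PER_WI, ITERATIONS_PER_KV_HEADS_GROUP, HEADS_LEFTOVERS_NUM), the single-token GQA variant presumably has each work item accumulate attention for several query heads of one KV-head group, so the shared K/V cache blocks are loaded once per group instead of once per query head. This is an inference from the host-side code, since the .cl diff is omitted here.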

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp (+68 -16)

@@ -13,13 +13,28 @@ namespace kernel_selector {
 namespace {
 enum KernelsTypes {
     SINGLE_TOKEN = 0,
+    SINGLE_TOKEN_GQA,
     MULTI_TOKENS,
     FINALIZATION,
     FINALIZATION_MULTI_TOKENS,
     SCORES_CALCULATION,
     TOTAL_KERNELS_NUM
 };
 
+static size_t get_heads_per_wi(const pa_sdpa_params& params) {
+    if (params.conf.kv_group_size > 1) {
+        std::vector<size_t> preferable_head_nums = {4, 3, 2};
+        for (const auto& heads_num : preferable_head_nums) {
+            const auto leftovers = params.conf.kv_group_size % heads_num;
+            if (leftovers == 0 || heads_num - leftovers <= 1) {
+                return heads_num;
+            }
+        }
+    }
+
+    return 1;
+}
+
 constexpr size_t subgroup_size = 16;
 constexpr size_t seq_len_partition_size = 256;
 constexpr size_t paged_attention_block_size = 16;
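The heuristic tries to pack 4, then 3, then 2 query heads per work item, accepting a candidate when the group divides evenly or the tail work item wastes at most one head slot. A self-contained sketch of the same arithmetic (the pa_sdpa_params plumbing is replaced by a plain argument for illustration):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Standalone re-implementation of get_heads_per_wi() above, with the
    // pa_sdpa_params dependency replaced by a plain kv_group_size value.
    static size_t heads_per_wi_for(size_t kv_group_size) {
        if (kv_group_size > 1) {
            const std::vector<size_t> preferable_head_nums = {4, 3, 2};
            for (const auto heads_num : preferable_head_nums) {
                const auto leftovers = kv_group_size % heads_num;
                // Accept if the group divides evenly, or if the tail work item
                // is left with at most one unused head slot.
                if (leftovers == 0 || heads_num - leftovers <= 1) {
                    return heads_num;
                }
            }
        }
        return 1;
    }

    int main() {
        for (size_t g : {2, 3, 4, 5, 6, 7, 8}) {
            std::cout << "kv_group_size=" << g
                      << " -> heads_per_wi=" << heads_per_wi_for(g) << "\n";
        }
        // Prints 3, 4, 4, 3, 3, 4, 4 for these group sizes.
        return 0;
    }

Note that for kv_group_size 2 or 3 the chosen value exceeds the group size itself; the HEADS_LEFTOVERS_NUM constant introduced below keeps the tail work item correct in that case.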
@@ -29,6 +44,7 @@ size_t get_sg_number_scale_factor(const pa_sdpa_params& params, size_t head_size
     if (params.conf.is_kv_compressed) {
         const size_t optimal_scale_factor = 2;
         if (kernel_type == KernelsTypes::SINGLE_TOKEN ||
+            kernel_type == KernelsTypes::SINGLE_TOKEN_GQA ||
            kernel_type == KernelsTypes::MULTI_TOKENS) {
             if (head_size * optimal_scale_factor <= params.engineInfo.maxWorkGroupSize) {
                 return optimal_scale_factor;
@@ -45,6 +61,8 @@ static std::string GetKernelName(std::string base_name, KernelsTypes type) {
 
     if (type == KernelsTypes::SINGLE_TOKEN) {
         kernel_name += "_single_token";
+    } else if (type == KernelsTypes::SINGLE_TOKEN_GQA) {
+        kernel_name += "_single_token_gqa";
     } else if (type == KernelsTypes::MULTI_TOKENS) {
         kernel_name += "_multi_tokens_seq";
     } else if (type == KernelsTypes::FINALIZATION) {
@@ -65,6 +83,7 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {
 
     const auto& params = static_cast<const pa_sdpa_params&>(p);
     std::vector<KernelsTypes> kernels_type = { KernelsTypes::SINGLE_TOKEN,
+                                               KernelsTypes::SINGLE_TOKEN_GQA,
                                                KernelsTypes::MULTI_TOKENS,
                                                KernelsTypes::FINALIZATION,
                                                KernelsTypes::FINALIZATION_MULTI_TOKENS };
@@ -90,7 +109,7 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {
 
     int inputs_num = static_cast<int>(params.inputs.size());
     int outputs_num = 1;
-    if (kernel_type == KernelsTypes::SINGLE_TOKEN) {
+    if (kernel_type == KernelsTypes::SINGLE_TOKEN || kernel_type == KernelsTypes::SINGLE_TOKEN_GQA) {
         // SINGLE_TOKEN kernel doesn't use the subsequence_begins input
         inputs_num -= 1;
     } else if (kernel_type == KernelsTypes::FINALIZATION) {
@@ -221,6 +240,7 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
     jit.AddConstant(MakeJitConstant("HEAD_SIZE", config.head_size));
     jit.AddConstant(MakeJitConstant("HEADS_NUM", config.heads_num));
     jit.AddConstant(MakeJitConstant("KV_HEADS_NUM", config.kv_heads_num));
+    jit.AddConstant(MakeJitConstant("KV_HEADS_GROUP_SIZE", config.kv_group_size));
     jit.AddConstant(MakeJitConstant("SEQ_LEN_PARTITION_SIZE", seq_len_partition_size));
     jit.AddConstant(MakeJitConstant("PAGED_ATTENTION_BLOCK_SIZE", paged_attention_block_size));
     jit.AddConstant(MakeJitConstant("SUBGROUP_SIZE", subgroup_size));
@@ -236,8 +256,13 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
         jit.AddConstant(MakeJitConstant("ADJUSTED_HEAD_SIZE", params.conf.head_size));
     }
 
-    if (config.broadcast_axis != -1) {
-        jit.AddConstant(MakeJitConstant("BROADCAST_GROUP_SIZE", config.group_size));
+    if (kernel_idx == KernelsTypes::SINGLE_TOKEN_GQA) {
+        auto heads_per_wi = get_heads_per_wi(params);
+        jit.AddConstant(MakeJitConstant("HEADS_PER_WI", heads_per_wi));
+        jit.AddConstant(MakeJitConstant("ITERATIONS_PER_KV_HEADS_GROUP", CeilDiv(config.kv_group_size, heads_per_wi)));
+        jit.AddConstant(MakeJitConstant("HEADS_LEFTOVERS_NUM", config.kv_group_size % heads_per_wi));
+    } else {
+        jit.AddConstant(MakeJitConstant("HEADS_PER_WI", 1));
     }
 
     auto sdpa_stage = 0;
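To make the derived constants concrete, take a hypothetical kv_group_size of 7: get_heads_per_wi() returns 4 (since 4 - 7 % 4 == 1), giving HEADS_PER_WI = 4, ITERATIONS_PER_KV_HEADS_GROUP = CeilDiv(7, 4) = 2, and HEADS_LEFTOVERS_NUM = 7 % 4 = 3. Each KV-head group is then covered by two work items along the heads axis, the second of which processes only the 3 leftover heads.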
@@ -293,6 +318,16 @@ CommonDispatchData PagedAttentionSDPAKernelOpt::SetDefault(const pa_sdpa_params&
                               heads_num,
                               head_size * num_of_partitions * sg_scale };
         dispatch_data.lws = { 1, 1, head_size * sg_scale };
+    } else if (kernel_idx == KernelsTypes::SINGLE_TOKEN_GQA) {
+        auto sg_scale = get_sg_number_scale_factor(params, head_size, kernel_idx);
+
+        auto kv_groups = heads_num / params.conf.kv_group_size;
+        auto gqa_heads_num = kv_groups * CeilDiv(params.conf.kv_group_size, get_heads_per_wi(params));
+
+        dispatch_data.gws = { total_tokens,
+                              gqa_heads_num,
+                              head_size * num_of_partitions * sg_scale };
+        dispatch_data.lws = { 1, 1, head_size * sg_scale };
     } else if (kernel_idx == KernelsTypes::SCORES_CALCULATION) {
         const auto& past_lens = params.inputs[3];
         const auto subsequences_number = past_lens.Batch().v;
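The payoff shows up in the dispatch shape: with, say, 32 query heads and kv_group_size = 4 (illustrative numbers), heads_per_wi is 4 and the heads axis of gws shrinks from heads_num = 32 to (32 / 4) * CeilDiv(4, 4) = 8. Each remaining work item computes four query heads against the same KV data, trading head-level parallelism for K/V load reuse, which is the point of this optimization for long sequences.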
@@ -334,13 +369,30 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
         const auto scores_calc_only = prim_params.stage == PagedAttentionStage::PREFILL && has_scores_output;
         const auto multi_tokens_mode = prim_params.stage == PagedAttentionStage::MIXED;
 
-        auto dispatch_data1 = SetDefault(prim_params, KernelsTypes::SINGLE_TOKEN);
-        kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.global = dispatch_data1.gws;
-        kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.local = dispatch_data1.lws;
-        kd.kernels[KernelsTypes::SINGLE_TOKEN].skip_execution = multi_tokens_mode || scores_calc_only;
-
-        kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.global = dispatch_data1.gws;
-        kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.local = dispatch_data1.lws;
+        // Apply GQA optimization starting from a certain sequence length value
+        const auto min_gqa_sequence_len = 8 * seq_len_partition_size;
+        // Apply GQA only if there is a single subsequence in the request,
+        // as multiple subsequences might have significantly different lengths
+        const auto max_subsequences_num = 1;
+        const auto subsequences_num = prim_params.inputs[0].Batch().v;
+        const auto can_use_gqa_kernel = prim_params.conf.paged_attention_max_len >= static_cast<int64_t>(min_gqa_sequence_len) &&
+                                        subsequences_num <= max_subsequences_num &&
+                                        prim_params.conf.kv_group_size > 1 &&
+                                        !multi_tokens_mode &&
+                                        !scores_calc_only;
+
+        auto dispatch_data = SetDefault(prim_params, KernelsTypes::SINGLE_TOKEN_GQA);
+        kd.kernels[KernelsTypes::SINGLE_TOKEN_GQA].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::SINGLE_TOKEN_GQA].params.workGroups.local = dispatch_data.lws;
+        kd.kernels[KernelsTypes::SINGLE_TOKEN_GQA].skip_execution = multi_tokens_mode || scores_calc_only || !can_use_gqa_kernel;
+
+        dispatch_data = SetDefault(prim_params, KernelsTypes::SINGLE_TOKEN);
+        kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.local = dispatch_data.lws;
+        kd.kernels[KernelsTypes::SINGLE_TOKEN].skip_execution = multi_tokens_mode || scores_calc_only || can_use_gqa_kernel;
+
+        kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.local = dispatch_data.lws;
         kd.kernels[KernelsTypes::MULTI_TOKENS].skip_execution = !multi_tokens_mode || scores_calc_only;
 
         size_t partition_size = 0;
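With seq_len_partition_size = 256, the threshold works out to min_gqa_sequence_len = 8 * 256 = 2048, so the GQA variant takes over only for single-subsequence requests whose paged_attention_max_len is at least 2048 tokens; shorter sequences keep the original SINGLE_TOKEN kernel with its one-head-per-work-item dispatch.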
@@ -351,13 +403,13 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
         }
         const size_t num_of_partitions = CeilDiv(prim_params.conf.paged_attention_max_len, partition_size);
 
-        auto dispatch_data2 = SetDefault(prim_params, KernelsTypes::FINALIZATION);
-        kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.global = dispatch_data2.gws;
-        kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.local = dispatch_data2.lws;
+        dispatch_data = SetDefault(prim_params, KernelsTypes::FINALIZATION);
+        kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.local = dispatch_data.lws;
         kd.kernels[KernelsTypes::FINALIZATION].skip_execution = num_of_partitions == 1 || multi_tokens_mode || scores_calc_only;
 
-        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.global = dispatch_data2.gws;
-        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.local = dispatch_data2.lws;
+        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.local = dispatch_data.lws;
         kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].skip_execution = num_of_partitions == 1 || !multi_tokens_mode || scores_calc_only;
 
         ScalarDescriptor num_of_partitions_scalar;
@@ -369,7 +421,7 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
         kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.scalars[0] = num_of_partitions_scalar;
 
         if (has_scores_output) {
-            auto dispatch_data = SetDefault(prim_params, KernelsTypes::SCORES_CALCULATION);
+            dispatch_data = SetDefault(prim_params, KernelsTypes::SCORES_CALCULATION);
             kd.kernels[KernelsTypes::SCORES_CALCULATION].params.workGroups.global = dispatch_data.gws;
             kd.kernels[KernelsTypes::SCORES_CALCULATION].params.workGroups.local = dispatch_data.lws;
             kd.kernels[KernelsTypes::SCORES_CALCULATION].skip_execution = false;

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.cpp (+2 -2)

@@ -70,10 +70,10 @@ JitConstants SDPAKernelBase::GetJitConstants(const sdpa_params& params) const {
     auto jit = MakeBaseParamsJitConstants(params);
 
     if (params.conf.broadcast_axis != -1) {
-        jit.AddConstant(MakeJitConstant("BROADCAST_GROUP_SIZE", params.conf.group_size));
+        jit.AddConstant(MakeJitConstant("BROADCAST_GROUP_SIZE", params.conf.kv_group_size));
         jit.AddConstant(MakeJitConstant("DO_BROADCAST_KEY_VALUE", GetBroadcastInputStr(params.inputs[0].GetDims().size(),
                                                                                        params.conf.broadcast_axis,
-                                                                                       params.conf.group_size)));
+                                                                                       params.conf.kv_group_size)));
     } else {
         jit.AddConstant(MakeJitConstant("BROADCAST_GROUP_SIZE", 1));
     }

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h (+1 -1)

@@ -83,7 +83,7 @@ struct sdpa_configuration {
     int64_t kv_heads_num = -1;
 
     // GQA configuration
-    int64_t group_size = -1;
+    int64_t kv_group_size = 1;
     int64_t broadcast_axis = -1;
 
     bool is_causal = false;
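Note the default changes from -1 to 1 along with the rename: a group size of 1 now means no grouping, which lets call sites such as get_heads_per_wi() test kv_group_size > 1 directly, without first checking broadcast_axis, and keeps the modulo arithmetic well defined.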

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_selector.cpp (+12 -1)

@@ -17,7 +17,18 @@ sdpa_kernel_selector::sdpa_kernel_selector() {
     Attach<SDPAKernelOpt>();
     Attach<SDPAKernelRef>();
 #ifdef ENABLE_ONEDNN_FOR_GPU
-    Attach<SDPAKernelMicro>();
+    int DISABLE_MICRO = 0;
+    if (const auto env_var = std::getenv("DISABLE_MICRO")) {
+        std::istringstream ss(env_var);
+        ss >> DISABLE_MICRO;
+        static bool printed = false;
+        if (!printed) {
+            std::cout << "Set DISABLE_MICRO=" << DISABLE_MICRO << "\n";
+            printed = true;
+        }
+    }
+    if (!DISABLE_MICRO)
+        Attach<SDPAKernelMicro>();
 #endif
 }
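This last hunk adds what looks like a benchmarking escape hatch rather than part of the GQA optimization proper: setting the DISABLE_MICRO environment variable to a non-zero value (for example, running the application with DISABLE_MICRO=1) skips attaching SDPAKernelMicro, forcing selection of the OCL SDPA kernels touched above. The hunk relies on std::getenv, std::istringstream, and std::cout, so <cstdlib>, <sstream>, and <iostream> must be available in this translation unit, directly or via existing includes.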
