Commit 99d0bd4

[GPU] GQA optimization of PagedAttention OCL kernel for long sequences
1 parent c1e81b0 commit 99d0bd4

7 files changed: +256 -105 lines changed

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp

+1 -1

@@ -660,7 +660,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
     if (desc->heads_num != desc->kv_heads_num) {
         config.broadcast_axis = 1;
-        config.group_size = desc->heads_num / desc->kv_heads_num;
+        config.kv_group_size = desc->heads_num / desc->kv_heads_num;
     }

     if (desc->has_scores_output() && !is_dynamic) {

src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp

+1 -1

@@ -252,7 +252,7 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
     if (query_shape[num_heads_dim].is_static() && key_shape[num_heads_dim].is_static() && value_shape[num_heads_dim].is_static()) {
         if (query_shape[num_heads_dim].get_length() > key_shape[num_heads_dim].get_length()) {
             config.broadcast_axis = desc->input_k_transpose_order[num_heads_dim];
-            config.group_size = query_shape[num_heads_dim].get_length() / key_shape[num_heads_dim].get_length();
+            config.kv_group_size = query_shape[num_heads_dim].get_length() / key_shape[num_heads_dim].get_length();
         }
     }

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl

+170 -83
Large diffs are not rendered by default.

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp

+69 -16

@@ -13,13 +13,29 @@ namespace kernel_selector {
 namespace {
 enum KernelsTypes {
     SINGLE_TOKEN = 0,
+    SINGLE_TOKEN_GQA,
     MULTI_TOKENS,
     FINALIZATION,
     FINALIZATION_MULTI_TOKENS,
     SCORES_CALCULATION,
     TOTAL_KERNELS_NUM
 };

+static size_t get_heads_per_wi(const pa_sdpa_params& params) {
+    if (params.conf.kv_group_size > 1) {
+        std::vector<size_t> preferable_head_nums = {4, 3, 2};
+        for (const auto& heads_num : preferable_head_nums) {
+            const auto leftovers = params.conf.kv_group_size % heads_num;
+            if (leftovers == 0 || heads_num - leftovers <= 1) {
+                return heads_num;
+            }
+        }
+    }
+
+    return 1;
+}
+
 constexpr size_t heads_per_iteration = 4;
 constexpr size_t subgroup_size = 16;
 constexpr size_t seq_len_partition_size = 256;
 constexpr size_t paged_attention_block_size = 16;
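The get_heads_per_wi helper added above drives the GQA packing: it prefers 4, then 3, then 2 query heads per work-item, and accepts a candidate only when the KV group size divides evenly or the final iteration leaves at most one head slot idle. A minimal standalone sketch of the same selection logic (the pick_heads_per_wi name and the sample group sizes are illustrative, not part of the commit):

#include <cstddef>
#include <cstdio>
#include <vector>

// Mirrors the selection logic of get_heads_per_wi(): pick how many query
// heads of a single KV-head group one work-item should accumulate.
static size_t pick_heads_per_wi(size_t kv_group_size) {
    if (kv_group_size > 1) {
        const std::vector<size_t> preferable_head_nums = {4, 3, 2};
        for (const auto heads_num : preferable_head_nums) {
            const auto leftovers = kv_group_size % heads_num;
            // Accept an even split, or one whose last iteration leaves
            // at most one head slot unused.
            if (leftovers == 0 || heads_num - leftovers <= 1) {
                return heads_num;
            }
        }
    }
    return 1;
}

int main() {
    for (size_t group : {2, 4, 6, 7, 8}) {
        std::printf("kv_group_size=%zu -> heads_per_wi=%zu\n",
                    group, pick_heads_per_wi(group));
    }
    return 0;
}

For a group size of 7, for example, the heuristic settles on 4 heads per work-item (the second iteration handles the remaining 3), while a group size of 6 maps to an even 3 + 3 split.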
@@ -29,6 +45,7 @@ size_t get_sg_number_scale_factor(const pa_sdpa_params& params, size_t head_size
     if (params.conf.is_kv_compressed) {
         const size_t optimal_scale_factor = 2;
         if (kernel_type == KernelsTypes::SINGLE_TOKEN ||
+            kernel_type == KernelsTypes::SINGLE_TOKEN_GQA ||
             kernel_type == KernelsTypes::MULTI_TOKENS) {
             if (head_size * optimal_scale_factor <= params.engineInfo.maxWorkGroupSize) {
                 return optimal_scale_factor;

@@ -45,6 +62,8 @@ static std::string GetKernelName(std::string base_name, KernelsTypes type) {

     if (type == KernelsTypes::SINGLE_TOKEN) {
         kernel_name += "_single_token";
+    } else if (type == KernelsTypes::SINGLE_TOKEN_GQA) {
+        kernel_name += "_single_token_gqa";
     } else if (type == KernelsTypes::MULTI_TOKENS) {
         kernel_name += "_multi_tokens_seq";
     } else if (type == KernelsTypes::FINALIZATION) {

@@ -65,6 +84,7 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {

     const auto& params = static_cast<const pa_sdpa_params&>(p);
     std::vector<KernelsTypes> kernels_type = { KernelsTypes::SINGLE_TOKEN,
+                                               KernelsTypes::SINGLE_TOKEN_GQA,
                                               KernelsTypes::MULTI_TOKENS,
                                               KernelsTypes::FINALIZATION,
                                               KernelsTypes::FINALIZATION_MULTI_TOKENS };

@@ -90,7 +110,7 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {

     int inputs_num = static_cast<int>(params.inputs.size());
     int outputs_num = 1;
-    if (kernel_type == KernelsTypes::SINGLE_TOKEN) {
+    if (kernel_type == KernelsTypes::SINGLE_TOKEN || kernel_type == KernelsTypes::SINGLE_TOKEN_GQA) {
         // SINGLE_TOKEN kernel doesn't use the subsequence_begins input
         inputs_num -= 1;
     } else if (kernel_type == KernelsTypes::FINALIZATION) {

@@ -221,6 +241,7 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
     jit.AddConstant(MakeJitConstant("HEAD_SIZE", config.head_size));
     jit.AddConstant(MakeJitConstant("HEADS_NUM", config.heads_num));
     jit.AddConstant(MakeJitConstant("KV_HEADS_NUM", config.kv_heads_num));
+    jit.AddConstant(MakeJitConstant("KV_HEADS_GROUP_SIZE", config.kv_group_size));
     jit.AddConstant(MakeJitConstant("SEQ_LEN_PARTITION_SIZE", seq_len_partition_size));
     jit.AddConstant(MakeJitConstant("PAGED_ATTENTION_BLOCK_SIZE", paged_attention_block_size));
     jit.AddConstant(MakeJitConstant("SUBGROUP_SIZE", subgroup_size));

@@ -236,8 +257,13 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
         jit.AddConstant(MakeJitConstant("ADJUSTED_HEAD_SIZE", params.conf.head_size));
     }

-    if (config.broadcast_axis != -1) {
-        jit.AddConstant(MakeJitConstant("BROADCAST_GROUP_SIZE", config.group_size));
+    if (kernel_idx == KernelsTypes::SINGLE_TOKEN_GQA) {
+        auto heads_per_wi = get_heads_per_wi(params);
+        jit.AddConstant(MakeJitConstant("HEADS_PER_WI", heads_per_wi));
+        jit.AddConstant(MakeJitConstant("ITERATIONS_PER_KV_HEADS_GROUP", CeilDiv(config.kv_group_size, heads_per_wi)));
+        jit.AddConstant(MakeJitConstant("HEADS_LEFTOVERS_NUM", config.kv_group_size % heads_per_wi));
+    } else {
+        jit.AddConstant(MakeJitConstant("HEADS_PER_WI", 1));
     }

     auto sdpa_stage = 0;
@@ -293,6 +319,16 @@ CommonDispatchData PagedAttentionSDPAKernelOpt::SetDefault(const pa_sdpa_params&
                               heads_num,
                               head_size * num_of_partitions * sg_scale };
         dispatch_data.lws = { 1, 1, head_size * sg_scale };
+    } else if (kernel_idx == KernelsTypes::SINGLE_TOKEN_GQA) {
+        auto sg_scale = get_sg_number_scale_factor(params, head_size, kernel_idx);
+
+        auto kv_groups = heads_num / params.conf.kv_group_size;
+        auto gqa_heads_num = kv_groups * CeilDiv(params.conf.kv_group_size, get_heads_per_wi(params));
+
+        dispatch_data.gws = { total_tokens,
+                              gqa_heads_num,
+                              head_size * num_of_partitions * sg_scale };
+        dispatch_data.lws = { 1, 1, head_size * sg_scale };
     } else if (kernel_idx == KernelsTypes::SCORES_CALCULATION) {
         const auto& past_lens = params.inputs[3];
         const auto subsequences_number = past_lens.Batch().v;
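The effect on the dispatch is that the Y dimension of the global work size shrinks from one work-group per query head to one per packed head bundle. A rough sketch of the arithmetic, using an assumed Llama-style decode configuration (32 query heads, 8 KV heads; none of these numbers come from the commit):

#include <cstddef>
#include <cstdio>

static size_t ceil_div(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
    // Hypothetical GQA model configuration (illustrative only).
    const size_t heads_num = 32, kv_heads_num = 8;
    const size_t kv_group_size = heads_num / kv_heads_num;  // 4 query heads share one KV head
    const size_t heads_per_wi = 4;                          // get_heads_per_wi() result for a group of 4

    const size_t kv_groups = heads_num / kv_group_size;     // 8 KV groups
    const size_t gqa_heads_num = kv_groups * ceil_div(kv_group_size, heads_per_wi);  // 8

    // GWS dim 1: 32 work-groups per token for SINGLE_TOKEN vs. 8 for SINGLE_TOKEN_GQA.
    std::printf("plain=%zu gqa=%zu\n", heads_num, gqa_heads_num);
    return 0;
}

Each remaining work-item reads the shared K/V data for its group once and accumulates up to four query heads against it, which presumably is what amortizes memory traffic over long sequences.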
@@ -334,13 +370,30 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
         const auto scores_calc_only = prim_params.stage == PagedAttentionStage::PREFILL && has_scores_output;
         const auto multi_tokens_mode = prim_params.stage == PagedAttentionStage::MIXED;

-        auto dispatch_data1 = SetDefault(prim_params, KernelsTypes::SINGLE_TOKEN);
-        kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.global = dispatch_data1.gws;
-        kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.local = dispatch_data1.lws;
-        kd.kernels[KernelsTypes::SINGLE_TOKEN].skip_execution = multi_tokens_mode || scores_calc_only;
-
-        kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.global = dispatch_data1.gws;
-        kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.local = dispatch_data1.lws;
+        // Apply GQA optimization starting from a certain sequence length value
+        const auto min_gqa_sequence_len = 8 * seq_len_partition_size;
+        // Apply GQA only if there is a single subsequence in the request,
+        // as multiple subsequences might have significantly different lengths
+        const auto max_subsequences_num = 1;
+        const auto subsequences_num = prim_params.inputs[0].Batch().v;
+        const auto can_use_gqa_kernel = prim_params.conf.paged_attention_max_len >= static_cast<int64_t>(min_gqa_sequence_len) &&
+                                        subsequences_num <= max_subsequences_num &&
+                                        prim_params.conf.kv_group_size > 1 &&
+                                        !multi_tokens_mode &&
+                                        !scores_calc_only;
+
+        auto dispatch_data = SetDefault(prim_params, KernelsTypes::SINGLE_TOKEN_GQA);
+        kd.kernels[KernelsTypes::SINGLE_TOKEN_GQA].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::SINGLE_TOKEN_GQA].params.workGroups.local = dispatch_data.lws;
+        kd.kernels[KernelsTypes::SINGLE_TOKEN_GQA].skip_execution = multi_tokens_mode || scores_calc_only || !can_use_gqa_kernel;
+
+        dispatch_data = SetDefault(prim_params, KernelsTypes::SINGLE_TOKEN);
+        kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.local = dispatch_data.lws;
+        kd.kernels[KernelsTypes::SINGLE_TOKEN].skip_execution = multi_tokens_mode || scores_calc_only || can_use_gqa_kernel;
+
+        kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.local = dispatch_data.lws;
         kd.kernels[KernelsTypes::MULTI_TOKENS].skip_execution = !multi_tokens_mode || scores_calc_only;

         size_t partition_size = 0;
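With seq_len_partition_size = 256, the threshold above works out to min_gqa_sequence_len = 8 * 256 = 2048 tokens. In other words, a single-subsequence decode request is routed to SINGLE_TOKEN_GQA only once the attended sequence reaches 2048 tokens and the model actually groups its heads (kv_group_size > 1); the complementary skip_execution flags keep the two single-token kernels mutually exclusive, so exactly one of them runs per dispatch.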
@@ -351,13 +404,13 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
         }
         const size_t num_of_partitions = CeilDiv(prim_params.conf.paged_attention_max_len, partition_size);

-        auto dispatch_data2 = SetDefault(prim_params, KernelsTypes::FINALIZATION);
-        kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.global = dispatch_data2.gws;
-        kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.local = dispatch_data2.lws;
+        dispatch_data = SetDefault(prim_params, KernelsTypes::FINALIZATION);
+        kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.local = dispatch_data.lws;
         kd.kernels[KernelsTypes::FINALIZATION].skip_execution = num_of_partitions == 1 || multi_tokens_mode || scores_calc_only;

-        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.global = dispatch_data2.gws;
-        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.local = dispatch_data2.lws;
+        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.global = dispatch_data.gws;
+        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.local = dispatch_data.lws;
         kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].skip_execution = num_of_partitions == 1 || !multi_tokens_mode || scores_calc_only;

@@ -369,7 +422,7 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
         kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.scalars[0] = num_of_partitions_scalar;

         if (has_scores_output) {
-            auto dispatch_data = SetDefault(prim_params, KernelsTypes::SCORES_CALCULATION);
+            dispatch_data = SetDefault(prim_params, KernelsTypes::SCORES_CALCULATION);
             kd.kernels[KernelsTypes::SCORES_CALCULATION].params.workGroups.global = dispatch_data.gws;
             kd.kernels[KernelsTypes::SCORES_CALCULATION].params.workGroups.local = dispatch_data.lws;
             kd.kernels[KernelsTypes::SCORES_CALCULATION].skip_execution = false;

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.cpp

+2 -2

@@ -70,10 +70,10 @@ JitConstants SDPAKernelBase::GetJitConstants(const sdpa_params& params) const {
     auto jit = MakeBaseParamsJitConstants(params);

     if (params.conf.broadcast_axis != -1) {
-        jit.AddConstant(MakeJitConstant("BROADCAST_GROUP_SIZE", params.conf.group_size));
+        jit.AddConstant(MakeJitConstant("BROADCAST_GROUP_SIZE", params.conf.kv_group_size));
         jit.AddConstant(MakeJitConstant("DO_BROADCAST_KEY_VALUE", GetBroadcastInputStr(params.inputs[0].GetDims().size(),
                                                                                        params.conf.broadcast_axis,
-                                                                                       params.conf.group_size)));
+                                                                                       params.conf.kv_group_size)));
     } else {
         jit.AddConstant(MakeJitConstant("BROADCAST_GROUP_SIZE", 1));
     }

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h

+1 -1

@@ -83,7 +83,7 @@ struct sdpa_configuration {
     int64_t kv_heads_num = -1;

     // GQA configuration
-    int64_t group_size = -1;
+    int64_t kv_group_size = 1;
     int64_t broadcast_axis = -1;

     bool is_causal = false;

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_selector.cpp

+12 -1

@@ -17,7 +17,18 @@ sdpa_kernel_selector::sdpa_kernel_selector() {
     Attach<SDPAKernelOpt>();
     Attach<SDPAKernelRef>();
 #ifdef ENABLE_ONEDNN_FOR_GPU
-    Attach<SDPAKernelMicro>();
+    int DISABLE_MICRO = 0;
+    if (const auto env_var = std::getenv("DISABLE_MICRO")) {
+        std::istringstream ss(env_var);
+        ss >> DISABLE_MICRO;
+        static bool printed = false;
+        if (!printed) {
+            std::cout << "Set DISABLE_MICRO=" << DISABLE_MICRO << "\n";
+            printed = true;
+        }
+    }
+    if (!DISABLE_MICRO)
+        Attach<SDPAKernelMicro>();
 #endif
 }
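The new DISABLE_MICRO environment variable acts as a debug switch: any non-zero value skips registration of SDPAKernelMicro, so the selector falls back to SDPAKernelOpt or SDPAKernelRef (presumably to ease A/B comparison against the oneDNN micro-kernel path). The confirmation message is printed once per process.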
