
Commit 38107ba
[GPU] GQA optimization of PagedAttention OCL kernel for long sequences
1 parent: c1e81b0

6 files changed: +243 -104 lines

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp (+1 -1)

@@ -660,7 +660,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
 
         if (desc->heads_num != desc->kv_heads_num) {
            config.broadcast_axis = 1;
-           config.group_size = desc->heads_num / desc->kv_heads_num;
+           config.kv_group_size = desc->heads_num / desc->kv_heads_num;
         }
 
         if (desc->has_scores_output() && !is_dynamic) {
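
For context, a minimal sketch (not part of the patch) of what the renamed kv_group_size field encodes: under grouped-query attention, heads_num query heads share kv_heads_num key/value heads, and each query head maps to the KV head at index q_head / kv_group_size. The head counts below are example values, not taken from the commit.

// Illustrative only: how kv_group_size relates query heads to KV heads under GQA.
// Assumes heads_num is an integer multiple of kv_heads_num, as the config above implies.
#include <cstdint>
#include <iostream>

int main() {
    const int64_t heads_num = 32;     // query heads (example value)
    const int64_t kv_heads_num = 8;   // shared key/value heads (example value)
    const int64_t kv_group_size = heads_num / kv_heads_num;  // 4 query heads per KV head

    for (int64_t q_head = 0; q_head < heads_num; ++q_head) {
        const int64_t kv_head = q_head / kv_group_size;  // KV head reused by this query head
        std::cout << "query head " << q_head << " -> kv head " << kv_head << '\n';
    }
    return 0;
}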

src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp (+1 -1)

@@ -252,7 +252,7 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
         if (query_shape[num_heads_dim].is_static() && key_shape[num_heads_dim].is_static() && value_shape[num_heads_dim].is_static()) {
             if (query_shape[num_heads_dim].get_length() > key_shape[num_heads_dim].get_length()) {
                 config.broadcast_axis = desc->input_k_transpose_order[num_heads_dim];
-                config.group_size = query_shape[num_heads_dim].get_length() / key_shape[num_heads_dim].get_length();
+                config.kv_group_size = query_shape[num_heads_dim].get_length() / key_shape[num_heads_dim].get_length();
             }
         }
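
The commit title advertises a GQA optimization of the PagedAttention kernel for long sequences, but the hunks above only show the config rename. The following is a rough host-side sketch of the general grouping idea, under the assumption that the kernel processes the kv_group_size query heads sharing one KV head together, so the cached K/V data is fetched once per group rather than once per query head; it is not the actual OCL kernel.

// Hypothetical illustration of the grouping idea (not the kernel from this commit):
// iterate KV heads in the outer loop and reuse each KV head's cached data for
// every query head in its group.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    const int64_t heads_num = 16;     // example value
    const int64_t kv_heads_num = 4;   // example value
    const int64_t kv_group_size = heads_num / kv_heads_num;

    // One dummy value per KV head stands in for a paged KV-cache block.
    std::vector<float> kv_block(kv_heads_num, 1.0f);

    for (int64_t kv_head = 0; kv_head < kv_heads_num; ++kv_head) {
        const float block = kv_block[kv_head];          // fetched once per KV head
        for (int64_t g = 0; g < kv_group_size; ++g) {
            const int64_t q_head = kv_head * kv_group_size + g;
            std::cout << "q_head " << q_head << " shares kv_head " << kv_head
                      << " (block " << block << ")\n";
        }
    }
    return 0;
}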
