Commit 9102e88 (1 parent: ed11461)

[GPU] Update PagedAttention creation logic: use head_size and heads_num parameters from rt_info if available

File tree: 1 file changed (+9 −3 lines)

src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp

@@ -25,10 +25,16 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared
     auto inputs = p.GetInputInfo(op);
     auto prim = cldnn::paged_attention(layer_type_name_ID(op), inputs);
 
-    auto key_cache_ps = op->get_input_partial_shape(3);
+    const auto& rt_info = op->get_rt_info();
+    const auto k_head_size_id = "k_head_size";
+    const auto num_k_heads_id = "num_k_heads";
+    const auto has_rt_params = rt_info.find(k_head_size_id) != rt_info.end() &&
+                               rt_info.find(num_k_heads_id) != rt_info.end();
+
     auto query_ps = op->get_input_partial_shape(0);
-    auto head_size = key_cache_ps[2].get_length();
-    auto kv_heads_num = key_cache_ps[1].get_length();
+    auto key_cache_ps = op->get_input_partial_shape(3);
+    auto head_size = has_rt_params ? rt_info.at(k_head_size_id).as<int64_t>() : key_cache_ps[2].get_length();
+    auto kv_heads_num = has_rt_params ? rt_info.at(num_k_heads_id).as<int64_t>() : key_cache_ps[1].get_length();
 
     // WA: in some cases, the query input may have a bounded dimension
     // Use input shape of the input node in such cases
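For context, here is a minimal sketch (not part of the commit) of the producer side this change assumes: some earlier transformation attaches the "k_head_size" and "num_k_heads" hints to the PagedAttention node's rt_info, and the plugin now prefers those hints over reading the dimensions from the key cache partial shape. The helper name below is hypothetical; the rt_info keys and the as<int64_t>() read-back match the diff.

    #include <cstdint>
    #include <memory>
    #include "openvino/core/node.hpp"

    // Hypothetical annotation helper; the keys mirror the ones the GPU
    // plugin looks up in CreatePagedAttentionExtensionOp.
    void annotate_paged_attention(const std::shared_ptr<ov::Node>& pa_node,
                                  int64_t k_head_size,
                                  int64_t num_k_heads) {
        // get_rt_info() returns ov::RTMap (a std::map<std::string, ov::Any>).
        auto& rt_info = pa_node->get_rt_info();
        rt_info["k_head_size"] = k_head_size;   // read back via rt_info.at("k_head_size").as<int64_t>()
        rt_info["num_k_heads"] = num_k_heads;   // read back via rt_info.at("num_k_heads").as<int64_t>()
    }

With the hints in place, primitive creation no longer depends on key_cache_ps having static head-size and heads-num dimensions; the partial-shape path remains as a fallback whenever either key is absent from rt_info.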

0 commit comments
