
Commit e25e2c7

[GPU] Fix KV-cache shape infer for QKV order {1,2,0,3}
1 parent 1d038e7 commit e25e2c7

6 files changed: +18 -15 lines changed

src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp (+1 -1)

@@ -943,7 +943,7 @@ void prepare_buffer_fusing::run(program& p) {
        auto update_scale_zp = [&](size_t kv_cache_output_idx, size_t read_value_output_idx) {
            auto scales_out_layout = node.get_output_layout(false, kv_cache_output_idx);

-           const auto scales_zp_concat_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+           const auto scales_zp_concat_axis = kv_cache_inst::get_scale_zp_sequence_axis();
            padding::DynamicDimsMask info_dynamic_pad_scales;
            info_dynamic_pad_scales[scales_zp_concat_axis] = 1;
            scales_out_layout.data_padding._dynamic_dims_mask = info_dynamic_pad_scales;

src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h (+2 -2)

@@ -62,8 +62,8 @@ class typed_primitive_inst<kv_cache> : public typed_primitive_inst_base<kv_cache
        return sequence_axis >= 0 ? sequence_axis : past_layout_rank + sequence_axis;
    }

-   static int64_t get_scale_zp_sequence_axis(int64_t sequence_axis, const kv_cache::QuantizationAttributes& quantization_attrs) {
-       const auto scale_zp_concat_axis = quantization_attrs.scales_zp_output_order[sequence_axis];
+   static int64_t get_scale_zp_sequence_axis() {
+       const auto scale_zp_concat_axis = 2;
        return scale_zp_concat_axis;
    }

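For context, a minimal standalone sketch (not part of the commit; names are illustrative) of what the two axis helpers above now compute: the data sequence axis still normalizes a possibly negative concat axis against the layout rank, while the scales/zp sequence axis is a fixed 2, matching the [batch, num_heads, seq_len, head_size] scales layout noted in the primitive_inst.cpp comment below.

// Standalone mirror of the two helpers above (illustrative only).
#include <cassert>
#include <cstdint>

int64_t sequence_axis(int64_t concat_axis, int64_t past_layout_rank) {
    // Same normalization as kv_cache_inst::get_sequence_axis(): negative axes count from the back.
    return concat_axis >= 0 ? concat_axis : past_layout_rank + concat_axis;
}

int64_t scale_zp_sequence_axis() {
    // Scales/zp keep the [batch, num_heads, seq_len, head_size] layout, so seq_len is index 2.
    return 2;
}

int main() {
    assert(sequence_axis(-2, 4) == 2);
    assert(scale_zp_sequence_axis() == 2);
    return 0;
}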

src/plugins/intel_gpu/src/graph/primitive_inst.cpp (+6 -6)

@@ -851,7 +851,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
                auto prealloc_shape = updated_layouts[i].get_shape();
                const auto shape_rank = prealloc_shape.size();
                const auto seq_axis = i == 0 ? kv_cache_inst::get_sequence_axis(desc->concat_axis, shape_rank)
-                                            : kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+                                            : kv_cache_inst::get_scale_zp_sequence_axis();

                prealloc_shape[seq_axis] += tmp_prealloc_count;
                required_buffer_size = std::accumulate(prealloc_shape.begin(), prealloc_shape.end(), size_t(1), std::multiplies<size_t>());
@@ -883,7 +883,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
                const auto& desc = _node->as<kv_cache>().get_primitive();
                const auto shape_rank = updated_layouts[i].get_shape().size();
                const auto seq_axis = i == 0 ? kv_cache_inst::get_sequence_axis(desc->concat_axis, shape_rank)
-                                            : kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+                                            : kv_cache_inst::get_scale_zp_sequence_axis();

                prealloc_info = sp.predict_preallocation_shape(id(), updated_layouts[i], false, i, tmp_prealloc_count, seq_axis);
            } else {
@@ -907,7 +907,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
            auto& present_layout = _impl_params->output_layouts[i];
            const auto present_layout_rank = present_layout.get_partial_shape().size();
            const auto sequence_axis = i == 0 ? kv_cache_inst::get_sequence_axis(desc->concat_axis, present_layout_rank)
-                                             : kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+                                             : kv_cache_inst::get_scale_zp_sequence_axis();

            auto max_pad = kv_cache_inst::get_max_pad(present_layout,
                                                      _max_output_layout_count[i],
@@ -978,7 +978,7 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
            if (max_pad > 0) {
                if (auto compressed_cache_variable = dynamic_cast<ov::intel_gpu::VariableStateIndirectKVCacheCompressed*>(&variable)) {
                    auto present_scales_layout = _impl_params->output_layouts[2];
-                   const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+                   const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis();

                    // In case of compressed KV-cache, calling update_impl for each iteration
                    // because of scales layout [batch, num_heads, seq_len, head_size], which requires proper
@@ -1374,7 +1374,7 @@ void primitive_inst::do_runtime_in_place_kv_cache() {
        if (desc->compressed) {
            auto compressed_cache_variable = dynamic_cast<ov::intel_gpu::VariableStateIndirectKVCacheCompressed*>(&variable);
            auto& present_scales_layout = _impl_params->output_layouts[2];
-           const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+           const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis();
            kv_cache_inst::update_pad(present_scales_layout, max_pad - new_seq_len, sequence_axis);
            GPU_DEBUG_TRACE_DETAIL << "[do runtime_in_place_kv_cache] " << id()
                                   << " Updated present_scale_layout's pad : " << present_scales_layout.to_string() << std::endl;
@@ -1398,7 +1398,7 @@ void primitive_inst::do_runtime_in_place_kv_cache() {

        if (desc->compressed) {
            auto& past_scale_layout = _impl_params->input_layouts[3];
-           const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis(desc->concat_axis, desc->quantization_attributes);
+           const auto sequence_axis = kv_cache_inst::get_scale_zp_sequence_axis();
            kv_cache_inst::update_pad(past_scale_layout, max_pad, sequence_axis);

            if (desc->get_compression_zp_inputs_num() > 0) {
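A standalone sketch of the preallocation arithmetic in the first two hunks above (shapes and prealloc count are hypothetical): for the compressed-cache scales the sequence axis is the fixed axis 2, that axis is grown by the preallocation count, and the required buffer size is the product of the grown shape.

#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
    // Hypothetical scales shape [batch, num_heads, seq_len, 1] and prealloc count.
    std::vector<size_t> prealloc_shape = {1, 4, 10, 1};
    const size_t tmp_prealloc_count = 128;
    const size_t seq_axis = 2;  // mirrors kv_cache_inst::get_scale_zp_sequence_axis()

    // Same accumulation as realloc_if_needed(): grow the sequence axis, then multiply dims.
    prealloc_shape[seq_axis] += tmp_prealloc_count;
    const size_t required_buffer_size = std::accumulate(prealloc_shape.begin(),
                                                        prealloc_shape.end(),
                                                        size_t(1),
                                                        std::multiplies<size_t>());
    std::cout << "required buffer elements: " << required_buffer_size << std::endl;  // 1 * 4 * 138 * 1
    return 0;
}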

src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp (+1 -2)

@@ -191,8 +191,7 @@ std::vector<ov::PartialShape> shape_infer(const KVCacheCompressed* op,
    auto quantized_data_shapes =
        ov::op::internal::DynamicQuantize::shape_infer(&dq_op, { input_shapes[1] });

-   const auto concat_axis = ov::util::normalize(op->get_concat_axis(), input_shapes[0].size());
-   const auto scales_concat_axis = op->get_quantization_attrs().scales_zp_output_order[concat_axis];
+   const auto scales_concat_axis = 2;
    ov::PartialShape compression_scale_shape = input_shapes[3];
    compression_scale_shape[scales_concat_axis] += quantized_data_shapes[1][scales_concat_axis];
    out_shapes[2] = compression_scale_shape;
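The same constant drives the shape arithmetic above. A standalone sketch with hypothetical shapes: the past scales [batch, num_heads, seq_len, 1] grow along axis 2 by the seq_len of the newly quantized scales, regardless of the QKV order of the KV data tensor.

#include <array>
#include <cstdint>
#include <iostream>

int main() {
    const int64_t scales_concat_axis = 2;  // mirrors the hard-coded axis above

    std::array<int64_t, 4> compression_scale_shape = {1, 4, 10, 1};  // hypothetical past scales
    std::array<int64_t, 4> new_token_scales_shape  = {1, 4, 1, 1};   // scales for one new token

    // Same update as the KVCacheCompressed shape_infer hunk:
    // compression_scale_shape[axis] += quantized_data_shapes[1][axis];
    compression_scale_shape[scales_concat_axis] += new_token_scales_shape[scales_concat_axis];

    std::cout << "scales seq_len after concat: "
              << compression_scale_shape[scales_concat_axis] << std::endl;  // 11
    return 0;
}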

src/plugins/intel_gpu/tests/common/subgraphs_builders.hpp (+4 -4)

@@ -120,8 +120,8 @@ inline std::shared_ptr<ov::Node> make_qkv_transpose(ov::Output<ov::Node> qkv, st
    return std::make_shared<ov::op::v1::Transpose>(qkv, transpose_const);
}

-inline std::shared_ptr<ov::Node> make_kv_rearrange(ov::Output<ov::Node> kv_past, ov::Output<ov::Node> beam_idx) {
-    auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 0);
+inline std::shared_ptr<ov::Node> make_kv_rearrange(ov::Output<ov::Node> kv_past, ov::Output<ov::Node> beam_idx, int axis_val = 0) {
+    auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, axis_val);
    return std::make_shared<ov::op::v8::Gather>(kv_past, beam_idx, axis, 0);
}

@@ -242,8 +242,8 @@ inline std::shared_ptr<ov::Model> make_llm_kv_cache_sdpa_pattern(ov::Dimension b
        in_beam_idx->set_friendly_name("beam_idx");
        params.push_back(in_beam_idx);

-       concat_k_input = make_kv_rearrange(past_k, in_beam_idx);
-       concat_v_input = make_kv_rearrange(past_v, in_beam_idx);
+       concat_k_input = make_kv_rearrange(past_k, in_beam_idx, qkv_order[0]);
+       concat_v_input = make_kv_rearrange(past_v, in_beam_idx, qkv_order[0]);
    }

    auto concat_k = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{concat_k_input, in_k_token}, concat_axis);
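A usage sketch for the extended helper (hypothetical names and shapes; assumes the declarations from subgraphs_builders.hpp shown above and the OpenVINO public API are in scope): with a non-default QKV order, the beam-index Gather axis follows qkv_order[0] instead of the previously hard-coded 0.

#include <cstdint>
#include <memory>
#include <vector>
#include <openvino/op/parameter.hpp>

// Hypothetical builder fragment: rearrange the past K state for beam search.
std::shared_ptr<ov::Node> build_rearranged_past_k(const std::vector<int64_t>& qkv_order) {
    auto past_k = std::make_shared<ov::op::v0::Parameter>(ov::element::f16,
                                                          ov::PartialShape::dynamic(4));
    auto beam_idx = std::make_shared<ov::op::v0::Parameter>(ov::element::i32,
                                                            ov::PartialShape::dynamic(1));
    // Gather along the axis selected by the tested QKV order, e.g. qkv_order = {1, 2, 0, 3}.
    return make_kv_rearrange(past_k, beam_idx, static_cast<int>(qkv_order[0]));
}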

src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/kv_cache_sdpa.cpp (+4 -0)

@@ -342,6 +342,7 @@ std::vector<Params> get_test_params() {
    p.push_back({with_rearrange, with_mask, !with_scale, !causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
    p.push_back({with_rearrange, with_mask, !with_scale, !causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
    p.push_back({!with_rearrange, with_mask, !with_scale, !causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
+   p.push_back({!with_rearrange, with_mask, !with_scale, !causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {1, 2, 0, 3}});

    // Beam search
    p.push_back({with_rearrange, !with_mask, !with_scale, !causal, !compressed, 2, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
@@ -351,6 +352,7 @@ std::vector<Params> get_test_params() {
    p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 1, 1, {0, 1, 2, 3}});
    p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
    p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
+   p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {1, 2, 0, 3}});

    /* -- causal mask -- */

@@ -367,6 +369,8 @@ std::vector<Params> get_test_params() {
    p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 1, 1, {0, 1, 2, 3}});
    p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
    p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
+   p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {1, 2, 0, 3}});
+
    return p;
}
