
Commit 95b2a70

[GPU] Enable KV-cache compression in PagedAttention OCL kernel
1 parent b60f5c5 commit 95b2a70

15 files changed: +591 -94 lines


src/common/transformations/include/transformations/common_optimizations/convert_pagedattn_inputs.hpp

+3 -1
@@ -19,6 +19,7 @@ class TRANSFORMATIONS_API ConvertPagedAttnInputs;
 
 class ConvertPagedAttnInputs : public ov::pass::MatcherPass {
 public:
+    using UpdateShapeFunc = std::function<void(ov::element::Type, bool, size_t, int64_t&, int64_t&)>;
     struct KVCacheConfig {
         ov::element::Type keyCachePrecision;
         ov::element::Type valueCachePrecision;
@@ -34,14 +35,15 @@ class ConvertPagedAttnInputs : public ov::pass::MatcherPass {
     };
 
     OPENVINO_MATCHER_PASS_RTTI("ConvertPagedAttnInputs");
-    ConvertPagedAttnInputs(const KVCacheConfig& config);
+    ConvertPagedAttnInputs(const KVCacheConfig& config, UpdateShapeFunc update_shape_func);
 
     void setKVCacheConfig(const KVCacheConfig& config);
 
     const KVCacheConfig& getKVCacheConfig() const;
 
 private:
     KVCacheConfig m_config;
+    UpdateShapeFunc m_update_shape_func;
 };
 
 } // namespace pass
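
The new constructor argument moves the "how much extra room does a compressed cache need" decision out of the common transformation: each plugin supplies an UpdateShapeFunc(precision, bychannel, group_num, head_size, block_size) that may enlarge the head or block dimension to hold quantization metadata. A minimal sketch of registering the pass with a no-op callback, assuming a plain ov::pass::Manager is used (the CPU plugin goes through its CPU_REGISTER_PASS_COMMON macro instead, shown further down); the helper name register_pagedattn_inputs_pass is hypothetical:

#include <cstdint>
#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/convert_pagedattn_inputs.hpp"

// Hypothetical helper: registers the pass with a callback that leaves the
// cache dimensions untouched, i.e. an uncompressed f16/f32 KV-cache.
void register_pagedattn_inputs_pass(ov::pass::Manager& manager,
                                    const ov::pass::ConvertPagedAttnInputs::KVCacheConfig& cfg) {
    manager.register_pass<ov::pass::ConvertPagedAttnInputs>(
        cfg,
        [](ov::element::Type precision, bool bychannel, size_t group_num,
           int64_t& head_size, int64_t& block_size) {
            // Compressed precisions (u8/u4) would grow head_size or block_size here
            // to make room for per-group scales and zero-points.
        });
}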

src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp

+6 -12
@@ -17,7 +17,9 @@
 #include "transformations/utils/utils.hpp"
 using namespace ov::gen_pattern;
 
-ov::pass::ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& config) : m_config(config) {
+ov::pass::ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& config, UpdateShapeFunc f)
+    : m_config(config),
+      m_update_shape_func(std::move(f)) {
     MATCHER_SCOPE(ConvertPagedAttnInputs);
 
     auto Q = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());
@@ -83,7 +85,7 @@ ov::pass::ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& co
                                  const size_t group_size,
                                  const bool bychannel,
                                  const std::vector<size_t>& orders) {
-        size_t _block_size = block_size;
+        ov::Dimension::value_type _block_size = block_size;
         ov::Dimension::value_type _head_nums = head_nums;
         ov::Dimension::value_type _head_size = head_size;
         ov::Dimension::value_type _group_size = group_size;
@@ -94,17 +96,9 @@ ov::pass::ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& co
             }
         }
         size_t group_num = _head_size / _group_size;
-        if (precision == ov::element::u8) {
-            if (bychannel) {
-                _block_size += 2 * sizeof(float);
-            } else {
-                _head_size += sizeof(float) * 2 * group_num;
-            }
-        } else if (precision == ov::element::u4) {
-            _head_size += sizeof(float) * 2 * group_num * 2;
-        }
-        auto block_shape = ov::PartialShape::dynamic(4);
+        m_update_shape_func(precision, bychannel, group_num, _head_size, _block_size);
 
+        auto block_shape = ov::PartialShape::dynamic(4);
         block_shape[orders[0]] = -1;
         block_shape[orders[1]] = _head_nums;
         block_shape[orders[2]] = _block_size;
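
With the adjustment delegated to the callback, the pass itself only assembles the 4-D per-block cache shape in the plugin's dimension order, keeping the number of blocks dynamic. An illustrative sketch of that assembly with hypothetical dimensions (make_block_shape is not part of the pass, and the last assignment is assumed from the surrounding code, which the hunk above cuts off):

#include <cstdint>
#include <vector>
#include "openvino/core/partial_shape.hpp"

// Illustrative only: mirrors the shape assembly above, after m_update_shape_func
// has (possibly) enlarged head_size / block_size.
ov::PartialShape make_block_shape(int64_t head_nums, int64_t head_size,
                                  int64_t block_size, const std::vector<size_t>& orders) {
    auto block_shape = ov::PartialShape::dynamic(4);
    block_shape[orders[0]] = -1;         // number of cache blocks stays dynamic
    block_shape[orders[1]] = head_nums;
    block_shape[orders[2]] = block_size;
    block_shape[orders[3]] = head_size;  // assumed: remaining dim holds the (adjusted) head size
    return block_shape;
}

// e.g. make_block_shape(8, 136, 32, {0, 1, 2, 3}) -> [?, 8, 32, 136],
// where 136 = 128 + 2 * sizeof(float) for a u8 cache with one group per head.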

src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp

+19 -1
@@ -470,7 +470,25 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
     cacheConfig.valueCacheQuantBychannel = false;
     cacheConfig.keyCacheDimOrder = {0, 1, 2, 3};
     cacheConfig.valueCacheDimOrder = {0, 1, 2, 3};
-    CPU_REGISTER_PASS_COMMON(manager, ov::pass::ConvertPagedAttnInputs, cacheConfig);
+    CPU_REGISTER_PASS_COMMON(
+        manager,
+        ov::pass::ConvertPagedAttnInputs,
+        cacheConfig,
+        [](const ov::element::Type& precision,
+           const bool bychannel,
+           const size_t group_num,
+           int64_t& head_size,
+           int64_t& block_size) {
+            if (precision == ov::element::u8) {
+                if (bychannel) {
+                    block_size += 2 * sizeof(float);
+                } else {
+                    head_size += sizeof(float) * 2 * group_num;
+                }
+            } else if (precision == ov::element::u4) {
+                head_size += sizeof(float) * 2 * group_num * 2;
+            }
+        });
     CPU_REGISTER_PASS_COMMON(manager, ov::pass::CommonOptimizations);
     CPU_REGISTER_PASS_X64(manager, ov::pass::KeepConstPrecision, decompression_precisions, false, true);
     CPU_SET_CALLBACK_X64(
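
The lambda reproduces the adjustment removed from the common pass: per-token u8 quantization appends one float scale and one float zero-point per group to the head dimension, per-channel u8 appends them along the block dimension, and u4 needs twice as many 4-bit elements for the same metadata. A standalone worked example with hypothetical sizes (not plugin code):

#include <cassert>
#include <cstdint>

int main() {
    // Hypothetical dims: head_size 128, block_size 32, group_size 32 -> 4 groups per head.
    int64_t head_size = 128, block_size = 32;
    const size_t group_num = 4;

    // u8, per-token: scale + zp stored as floats after each head's data.
    int64_t u8_head = head_size + sizeof(float) * 2 * group_num;      // 128 + 32 = 160
    // u8, per-channel: scale + zp appended along the block dimension.
    int64_t u8_block = block_size + 2 * sizeof(float);                // 32 + 8 = 40
    // u4: the same 8 bytes of metadata per group occupy twice as many 4-bit elements.
    int64_t u4_head = head_size + sizeof(float) * 2 * group_num * 2;  // 128 + 64 = 192

    assert(u8_head == 160 && u8_block == 40 && u4_head == 192);
    return 0;
}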

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp

+16
@@ -687,6 +687,11 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
             config.paged_attention_max_len = max_context_len_mem_lock[0];
         }
 
+        if (data_type_traits::is_i8_u8(impl_param.get_input_layout(3).data_type)) {
+            config.is_kv_compressed = true;
+            config.use_asymmetric_quantization = true;
+        }
+
         return config;
     }
 
@@ -709,6 +714,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         params.inputs[2] = rotation_trig_lut_tensor;
         params.outputs[0] = key_cache_tensor;
 
+        params.original_cache_dt = to_data_type(impl_param.get_input_layout(1).data_type);
         params.conf = get_sdpa_configuration(impl_param, is_dynamic);
 
         const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset;
@@ -826,6 +832,11 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
 
         params.conf = get_sdpa_configuration(impl_param, is_dynamic);
 
+        // Currently, for the processing of the 1st token, plain SDPA kernels are used, which expect
+        // uncompressed plain QKV inputs. Therefore, set is_kv_compressed=false
+        params.conf.is_kv_compressed = false;
+        params.conf.use_asymmetric_quantization = false;
+
         const std::vector<int64_t> default_order = {0, 1, 2, 3};
         params.input0_order = default_order;
         params.input1_order = default_order;
@@ -991,6 +1002,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
         auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance();
         kernels_data.push_back(kv_cache_update_kernel_selector.get_best_kernel(kv_cache_update_kernel_params));
+
         auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, 0, impl_param.is_dynamic());
         auto& sdpa_kernel_selector = sdpa_kernel_selector_t::Instance();
         kernels_data.push_back(sdpa_kernel_selector.get_best_kernel(sdpa_kernel_params));
@@ -1013,6 +1025,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
             impl->use_micro_sdpa = true;
         }
 
+        std::cout << "use_micro=" << impl->use_micro_sdpa << " KV-cache layouts=["
+                  << impl_param.get_input_layout(3).to_short_string() << ", "
+                  << impl_param.get_input_layout(4).to_short_string() << "]\n";
+
         return impl;
     }
 
src/plugins/intel_gpu/src/graph/paged_attention.cpp

+2
@@ -73,6 +73,8 @@ std::string paged_attention_inst::to_string(const paged_attention_node& node) {
     paged_attention_info.add("scale", desc->scale_val.value_or(1.0f));
     paged_attention_info.add("has_alibi", desc->has_alibi);
     paged_attention_info.add("has_rotated_blocks", desc->has_rotated_blocks);
+    paged_attention_info.add("key_cache_dt", node.get_input_layout(3).data_type);
+    paged_attention_info.add("value_cache_dt", node.get_input_layout(4).data_type);
     node_info->add("paged_attention primitive info", paged_attention_info);
     node_info->dump(primitive_description);

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_rotate_ref.cl

+67 -8
@@ -4,10 +4,15 @@
 
 #include "include/batch_headers/common.cl"
 
+#if IS_KV_COMPRESSED
+#define SUBGROUPS_PER_WG 1
+#else
 #define SUBGROUPS_PER_WG KV_HEADS_NUM
+#endif
+#define ACCUMULATOR_TYPE float
 
 REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
-__attribute__((reqd_work_group_size(SUBGROUP_SIZE, KV_HEADS_NUM, 1)))
+__attribute__((reqd_work_group_size(SUBGROUP_SIZE, SUBGROUPS_PER_WG, 1)))
 KERNEL(pa_kv_cache_rotate)(
     OPTIONAL_SHAPE_INFO_ARG
     __global const INPUT0_TYPE* rotated_block_indices,
@@ -62,22 +67,76 @@ KERNEL(pa_kv_cache_rotate)(
     barrier(CLK_LOCAL_MEM_FENCE);
 
     const uint token_coefficient_idx = per_token_rotation ? sglid : 0;
-    const uint block_offset = rotated_block_indices[block_idx] * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
-                              head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + sglid;
+    const uint block_base_offset = rotated_block_indices[block_idx] * KV_HEADS_NUM * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
+                                   head_idx * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
+    const uint token_offset = block_base_offset + sglid;
+
+#if IS_KV_COMPRESSED
+    const uint comp_offset = block_base_offset + HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
+    UNCOMPRESSED_TYPE* comp_ptr = key_cache + comp_offset;
+    UNCOMPRESSED_TYPE comp_scale = comp_ptr[0 + sglid];
+    UNCOMPRESSED_TYPE comp_zp = comp_ptr[PAGED_ATTENTION_BLOCK_SIZE + sglid];
+
+    UNCOMPRESSED_TYPE max_value = UNCOMPRESSED_VAL_MIN;
+    UNCOMPRESSED_TYPE min_value = UNCOMPRESSED_VAL_MAX;
+
+    // Reuse SLM to store dequantized rotated values
+    __local UNCOMPRESSED_TYPE* rotated_data = (__local UNCOMPRESSED_TYPE*)(&rotation_coefficients[0][0]);
+#endif
+
+    // Apply cache rotation
     for (uint i = 0; i < HEAD_SIZE / 2; i++) {
-        const uint cache_offset = block_offset + i * PAGED_ATTENTION_BLOCK_SIZE;
-        OUTPUT_TYPE cache_value_first = key_cache[cache_offset];
-        OUTPUT_TYPE cache_value_second = key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE];
+        const uint cache_offset = token_offset + i * PAGED_ATTENTION_BLOCK_SIZE;
+
+#if IS_KV_COMPRESSED
+        UNCOMPRESSED_TYPE cache_value_first = TO_UNCOMPRESSED_TYPE(key_cache[cache_offset] - comp_zp) * comp_scale;
+        UNCOMPRESSED_TYPE cache_value_second = TO_UNCOMPRESSED_TYPE(key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE] - comp_zp) * comp_scale;
+#else
+        UNCOMPRESSED_TYPE cache_value_first = key_cache[cache_offset];
+        UNCOMPRESSED_TYPE cache_value_second = key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE];
+#endif
 
         INPUT2_TYPE rotation_value_cos = rotation_coefficients[i][token_coefficient_idx];
         INPUT2_TYPE rotation_value_sin = rotation_coefficients[i + (HEAD_SIZE / 2)][token_coefficient_idx];
 
-        OUTPUT_TYPE new_cache_value_first = cache_value_first * rotation_value_cos - cache_value_second * rotation_value_sin;
-        OUTPUT_TYPE new_cache_value_second = cache_value_first * rotation_value_sin + cache_value_second * rotation_value_cos;
+        UNCOMPRESSED_TYPE new_cache_value_first = cache_value_first * rotation_value_cos - cache_value_second * rotation_value_sin;
+        UNCOMPRESSED_TYPE new_cache_value_second = cache_value_first * rotation_value_sin + cache_value_second * rotation_value_cos;
 
+#if IS_KV_COMPRESSED
+        max_value = fmax(fmax(max_value, new_cache_value_first), new_cache_value_second);
+        min_value = fmin(fmin(min_value, new_cache_value_first), new_cache_value_second);
+
+        rotated_data[(i + 0) * PAGED_ATTENTION_BLOCK_SIZE + sglid] = new_cache_value_first;
+        rotated_data[(i + (HEAD_SIZE / 2)) * PAGED_ATTENTION_BLOCK_SIZE + sglid] = new_cache_value_second;
+#else
         key_cache[cache_offset] = new_cache_value_first;
         key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE] = new_cache_value_second;
+#endif
+    }
+
+#if IS_KV_COMPRESSED
+    // Re-quantize cache data
+    ACCUMULATOR_TYPE grp_max = 0.001;
+    ACCUMULATOR_TYPE diff_value = max_value == min_value ? (grp_max) : (max_value - min_value);
+    ACCUMULATOR_TYPE scale_tmp = (ACCUMULATOR_TYPE)((CHAR_MAX - CHAR_MIN) / diff_value);
+    ACCUMULATOR_TYPE zp_tmp = (ACCUMULATOR_TYPE)(-min_value * scale_tmp) + CHAR_MIN;
+    UNCOMPRESSED_TYPE scale = (UNCOMPRESSED_TYPE)(scale_tmp);
+    UNCOMPRESSED_TYPE zp = (UNCOMPRESSED_TYPE)(zp_tmp);
+
+    // Note: absence of this explicit unrolling directive leads to automatic
+    // unrolling and causes registers spill. Set unrolling to a reasonable value manually
+    __attribute__((opencl_unroll_hint(8)))
+    for (uint i = 0; i < HEAD_SIZE; i++) {
+        OUTPUT_TYPE quantized_res = convert_char_rte(rotated_data[i * PAGED_ATTENTION_BLOCK_SIZE + sglid] * scale + zp);
+
+        const uint cache_offset = token_offset + i * PAGED_ATTENTION_BLOCK_SIZE;
+        key_cache[cache_offset] = quantized_res;
     }
+
+    comp_ptr[0 + sglid] = 1.0 / scale;
+    comp_ptr[PAGED_ATTENTION_BLOCK_SIZE + sglid] = zp;
+#endif
 }
 
+#undef ACCUMULATOR_TYPE
 #undef SUBGROUPS_PER_WG
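
The re-quantization step derives a fresh per-token scale and zero-point from the rotated values: the observed [min, max] range is mapped onto the full signed 8-bit range, the zero-point shifts the minimum onto CHAR_MIN, and the reciprocal of the encoding scale is stored so that decoding stays a single multiply. A host-side C++ sketch of the same math (quantize_token and QuantizedToken are hypothetical names; values are examples):

#include <algorithm>
#include <climits>
#include <cmath>
#include <cstdint>
#include <vector>

struct QuantizedToken {
    std::vector<int8_t> data;
    float scale;  // stored as 1 / encode_scale, so decode is (q - zp) * scale
    float zp;
};

// Quantizes one token's head values; 'values' is assumed non-empty.
QuantizedToken quantize_token(const std::vector<float>& values) {
    float max_value = *std::max_element(values.begin(), values.end());
    float min_value = *std::min_element(values.begin(), values.end());
    float diff = (max_value == min_value) ? 0.001f : (max_value - min_value);
    float encode_scale = (CHAR_MAX - CHAR_MIN) / diff;  // 255 / range
    float zp = -min_value * encode_scale + CHAR_MIN;    // maps min_value onto -128

    QuantizedToken out{{}, 1.0f / encode_scale, zp};
    out.data.reserve(values.size());
    for (float v : values)
        out.data.push_back(static_cast<int8_t>(std::lrint(v * encode_scale + zp)));
    return out;
}
// Decoding element i as (out.data[i] - out.zp) * out.scale recovers values[i]
// up to quantization error.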
