Skip to content

Commit c02f776

Browse files
committed
[GPU] PA KV-cache compression tests
1 parent b60f5c5 commit c02f776

11 files changed

+521
-76
lines changed

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,11 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
687687
config.paged_attention_max_len = max_context_len_mem_lock[0];
688688
}
689689

690+
if (data_type_traits::is_i8_u8(impl_param.get_input_layout(3).data_type)) {
691+
config.is_kv_compressed = true;
692+
config.use_asymmetric_quantization = true;
693+
}
694+
690695
return config;
691696
}
692697

@@ -709,6 +714,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
709714
params.inputs[2] = rotation_trig_lut_tensor;
710715
params.outputs[0] = key_cache_tensor;
711716

717+
params.original_cache_dt = to_data_type(impl_param.get_input_layout(1).data_type);
712718
params.conf = get_sdpa_configuration(impl_param, is_dynamic);
713719

714720
const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset;
@@ -826,6 +832,11 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
826832

827833
params.conf = get_sdpa_configuration(impl_param, is_dynamic);
828834

835+
// Currently, for the processing of the 1st token, plain SDPA kernels are used, which expect
836+
// uncompressed plain QKV inputs. Therefore, set is_kv_compressed=false
837+
params.conf.is_kv_compressed = false;
838+
params.conf.use_asymmetric_quantization = false;
839+
829840
const std::vector<int64_t> default_order = {0, 1, 2, 3};
830841
params.input0_order = default_order;
831842
params.input1_order = default_order;
@@ -991,6 +1002,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
9911002
auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
9921003
auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance();
9931004
kernels_data.push_back(kv_cache_update_kernel_selector.get_best_kernel(kv_cache_update_kernel_params));
1005+
9941006
auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, 0, impl_param.is_dynamic());
9951007
auto& sdpa_kernel_selector = sdpa_kernel_selector_t::Instance();
9961008
kernels_data.push_back(sdpa_kernel_selector.get_best_kernel(sdpa_kernel_params));
@@ -1010,6 +1022,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
10101022
impl->has_rotated_blocks = desc->has_rotated_blocks;
10111023

10121024
if (!kernels_data[Stage::SDPA].kernels[0].micro_kernels.empty()) {
1025+
std::cout << "Micro kernel is used\n";
10131026
impl->use_micro_sdpa = true;
10141027
}
10151028

src/plugins/intel_gpu/src/graph/paged_attention.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ std::string paged_attention_inst::to_string(const paged_attention_node& node) {
7373
paged_attention_info.add("scale", desc->scale_val.value_or(1.0f));
7474
paged_attention_info.add("has_alibi", desc->has_alibi);
7575
paged_attention_info.add("has_rotated_blocks", desc->has_rotated_blocks);
76+
paged_attention_info.add("key_cache_dt", node.get_input_layout(3).data_type);
77+
paged_attention_info.add("value_cache_dt", node.get_input_layout(4).data_type);
7678
node_info->add("paged_attention primitive info", paged_attention_info);
7779
node_info->dump(primitive_description);
7880

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_rotate_ref.cl

+63-6
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44

55
#include "include/batch_headers/common.cl"
66

7+
#if IS_KV_COMPRESSED
8+
#define SUBGROUPS_PER_WG 1
9+
#else
710
#define SUBGROUPS_PER_WG KV_HEADS_NUM
11+
#endif
812

913
REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
1014
__attribute__((reqd_work_group_size(SUBGROUP_SIZE, KV_HEADS_NUM, 1)))
@@ -62,22 +66,75 @@ KERNEL(pa_kv_cache_rotate)(
6266
barrier(CLK_LOCAL_MEM_FENCE);
6367

6468
const uint token_coefficient_idx = per_token_rotation ? sglid : 0;
65-
const uint block_offset = rotated_block_indices[block_idx] * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
66-
head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + sglid;
69+
const uint block_base_offset = rotated_block_indices[block_idx] * KV_HEADS_NUM * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
70+
head_idx * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
71+
const uint block_offset = block_base_offset + sglid;
72+
73+
#if IS_KV_COMPRESSED
74+
const uint comp_offset = block_offset + HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
75+
UNCOMPRESSED_TYPE* comp_ptr = key_cache + comp_offset;
76+
UNCOMPRESSED_TYPE comp_scale = comp_ptr[0 + sglid];
77+
UNCOMPRESSED_TYPE comp_zp = comp_ptr[PAGED_ATTENTION_BLOCK_SIZE + sglid];
78+
79+
UNCOMPRESSED_TYPE max_value = UNCOMPRESSED_VAL_MIN;
80+
UNCOMPRESSED_TYPE min_value = UNCOMPRESSED_VAL_MAX;
81+
82+
__local UNCOMPRESSED_TYPE* rotated_data = (__local UNCOMPRESSED_TYPE*)(&rotation_coefficients[0][0]);
83+
#endif
84+
6785
for (uint i = 0; i < HEAD_SIZE / 2; i++) {
6886
const uint cache_offset = block_offset + i * PAGED_ATTENTION_BLOCK_SIZE;
69-
OUTPUT_TYPE cache_value_first = key_cache[cache_offset];
70-
OUTPUT_TYPE cache_value_second = key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE];
87+
88+
#if IS_KV_COMPRESSED
89+
UNCOMPRESSED_TYPE cache_value_first = TO_UNCOMPRESSED_TYPE(key_cache[cache_offset] - comp_zp) * comp_scale;
90+
UNCOMPRESSED_TYPE cache_value_second = TO_UNCOMPRESSED_TYPE(key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE] - comp_zp) * comp_scale;
91+
#else
92+
UNCOMPRESSED_TYPE cache_value_first = key_cache[cache_offset];
93+
UNCOMPRESSED_TYPE cache_value_second = key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE];
94+
#endif
7195

7296
INPUT2_TYPE rotation_value_cos = rotation_coefficients[i][token_coefficient_idx];
7397
INPUT2_TYPE rotation_value_sin = rotation_coefficients[i + (HEAD_SIZE / 2)][token_coefficient_idx];
7498

75-
OUTPUT_TYPE new_cache_value_first = cache_value_first * rotation_value_cos - cache_value_second * rotation_value_sin;
76-
OUTPUT_TYPE new_cache_value_second = cache_value_first * rotation_value_sin + cache_value_second * rotation_value_cos;
99+
UNCOMPRESSED_TYPE new_cache_value_first = cache_value_first * rotation_value_cos - cache_value_second * rotation_value_sin;
100+
UNCOMPRESSED_TYPE new_cache_value_second = cache_value_first * rotation_value_sin + cache_value_second * rotation_value_cos;
77101

102+
#if IS_KV_COMPRESSED
103+
max_value = fmax(max_value, new_cache_value_first);
104+
max_value = fmax(max_value, new_cache_value_second);
105+
min_value = fmin(min_value, new_cache_value_first);
106+
min_value = fmin(min_value, new_cache_value_second);
107+
108+
rotated_data[(i + 0) * PAGED_ATTENTION_BLOCK_SIZE + sglid] = new_cache_value_first;
109+
rotated_data[(i + (HEAD_SIZE / 2)) * PAGED_ATTENTION_BLOCK_SIZE + sglid] = new_cache_value_second;
110+
#else
78111
key_cache[cache_offset] = new_cache_value_first;
79112
key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE] = new_cache_value_second;
113+
#endif
114+
}
115+
116+
#if IS_KV_COMPRESSED
117+
{
118+
#define ACCUMULATOR_TYPE float
119+
ACCUMULATOR_TYPE grp_max = 0.001;
120+
ACCUMULATOR_TYPE diff_value = max_value == min_value ? (grp_max) : (max_value - min_value);
121+
ACCUMULATOR_TYPE scale_tmp = (ACCUMULATOR_TYPE)((CHAR_MAX - CHAR_MIN) / diff_value);
122+
ACCUMULATOR_TYPE zp_tmp = (ACCUMULATOR_TYPE)(-min_value * scale_tmp) + CHAR_MIN;
123+
UNCOMPRESSED_TYPE scale = (INPUT1_TYPE)(scale_tmp);
124+
UNCOMPRESSED_TYPE zp = (INPUT1_TYPE)(zp_tmp);
125+
#undef ACCUMULATOR_TYPE
126+
127+
unroll_for (uint i = 0; i < HEAD_SIZE; i++) {
128+
OUTPUT_TYPE quantized_res = convert_char_rte(rotated_data[i * PAGED_ATTENTION_BLOCK_SIZE + sglid] * scale + zp);
129+
130+
const uint cache_offset = block_offset + i * PAGED_ATTENTION_BLOCK_SIZE;
131+
key_cache[cache_offset] = quantized_res;
132+
}
133+
134+
comp_ptr[0 + sglid] = 1.0 / scale;
135+
comp_ptr[PAGED_ATTENTION_BLOCK_SIZE + sglid] = zp;
80136
}
137+
#endif
81138
}
82139

83140
#undef SUBGROUPS_PER_WG

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_update_ref.cl

+85-7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,52 @@
44

55
#include "include/batch_headers/common.cl"
66

7+
inline void quantize_and_save(__global const INPUT0_TYPE* in_data,
8+
const uint in_data_offset,
9+
__global OUTPUT_TYPE* out_data,
10+
const uint out_data_offset,
11+
const uint out_data_pitch,
12+
const uint comp_offset,
13+
const uint token_pos_in_block,
14+
const uint sglid) {
15+
INPUT0_TYPE input_data[HEAD_SIZE / SUBGROUP_SIZE];
16+
INPUT0_TYPE grp_max = 0.001;
17+
INPUT0_TYPE max_value = INPUT0_VAL_MIN;
18+
INPUT0_TYPE min_value = INPUT0_VAL_MAX;
19+
20+
unroll_for (uint i = 0; i < HEAD_SIZE / SUBGROUP_SIZE; i++) {
21+
input_data[i] = BLOCK_READN(INPUT0_TYPE, 1, in_data, in_data_offset + i * SUBGROUP_SIZE);
22+
max_value = fmax(max_value, input_data[i]);
23+
min_value = fmin(min_value, input_data[i]);
24+
}
25+
26+
min_value = sub_group_reduce_min(min_value);
27+
max_value = sub_group_reduce_max(max_value);
28+
29+
// If the range of input data is zero, it is adjusted to the minimum value(0.001).
30+
#define ACCUMULATOR_TYPE float
31+
ACCUMULATOR_TYPE diff_value = max_value == min_value ? (grp_max) : (max_value - min_value);
32+
ACCUMULATOR_TYPE scale_tmp = (ACCUMULATOR_TYPE)((CHAR_MAX - CHAR_MIN) / diff_value);
33+
ACCUMULATOR_TYPE zp_tmp = (ACCUMULATOR_TYPE)(-min_value * scale_tmp) + CHAR_MIN;
34+
INPUT0_TYPE scale = (INPUT1_TYPE)(scale_tmp);
35+
INPUT0_TYPE zp = (INPUT1_TYPE)(zp_tmp);
36+
#undef ACCUMULATOR_TYPE
37+
38+
unroll_for (uint i = 0; i < HEAD_SIZE / SUBGROUP_SIZE; i++) {
39+
OUTPUT_TYPE res = convert_char_rte(input_data[i] * scale + zp);
40+
41+
uint offset = out_data_offset + (i * SUBGROUP_SIZE + sglid) * out_data_pitch;
42+
out_data[offset] = res;
43+
}
44+
45+
INPUT0_TYPE* comp_ptr = out_data + comp_offset;
46+
47+
if (sglid == 0) {
48+
comp_ptr[token_pos_in_block] = 1.0 / scale;
49+
comp_ptr[PAGED_ATTENTION_BLOCK_SIZE + token_pos_in_block] = zp;
50+
}
51+
}
52+
753
REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
854
__attribute__((reqd_work_group_size(1, 1, SUBGROUP_SIZE)))
955
KERNEL(pa_kv_cache_update)(
@@ -41,8 +87,12 @@ KERNEL(pa_kv_cache_update)(
4187
seq_idx * (KV_HEADS_NUM * HEAD_SIZE + INPUT1_PAD_BEFORE_FEATURE_NUM + INPUT1_PAD_AFTER_FEATURE_NUM) +
4288
head_idx * HEAD_SIZE;
4389

44-
uint key_out_offset = block_idx * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + current_token_pos_in_block;
45-
uint value_out_offset = block_idx * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + current_token_pos_in_block * HEAD_SIZE;
90+
uint block_base_offset = block_idx * KV_HEADS_NUM * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + head_idx * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
91+
uint key_out_offset = block_base_offset + current_token_pos_in_block;
92+
uint value_out_offset = block_base_offset + current_token_pos_in_block * HEAD_SIZE;
93+
const uint comp_offset = block_base_offset + HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
94+
95+
#if !IS_KV_COMPRESSED
4696

4797
#define READ_BLOCK_SIZE GENERATE_STAGE_BLOCK_SIZE
4898
for (uint head_idx_index = 0; head_idx_index < HEAD_SIZE; head_idx_index += SUBGROUP_SIZE * READ_BLOCK_SIZE) {
@@ -71,6 +121,14 @@ KERNEL(pa_kv_cache_update)(
71121
#endif
72122
}
73123
}
124+
125+
#else // !IS_KV_COMPRESSED
126+
// key processing
127+
quantize_and_save(key_data, key_in_offset, key_cache_data, key_out_offset, PAGED_ATTENTION_BLOCK_SIZE, comp_offset, current_token_pos_in_block, sglid);
128+
129+
// value processing
130+
quantize_and_save(value_data, value_in_offset, value_cache_data, value_out_offset, 1, comp_offset, current_token_pos_in_block, sglid);
131+
#endif // !IS_KV_COMPRESSED
74132
} else {
75133
// 1st token
76134
const uint block_idx = get_global_id(0);
@@ -99,17 +157,20 @@ KERNEL(pa_kv_cache_update)(
99157

100158
const uint block_offset = block_indices_begins[subsequence_idx] + current_block_idx;
101159

102-
uint key_out_offset = block_indices[block_offset] * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
103-
head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
104-
105-
uint value_out_offset = key_out_offset;
160+
uint block_base_offset = block_indices[block_offset] * KV_HEADS_NUM * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
161+
head_idx * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
162+
uint key_out_offset = block_base_offset;
163+
uint value_out_offset = block_base_offset;
164+
const uint comp_offset = block_base_offset + HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
106165

107166
key_out_offset += token_start_pos;
108167
value_out_offset += token_start_pos * HEAD_SIZE;
109168

110169
if (tokens_num == PAGED_ATTENTION_BLOCK_SIZE) {
111170
unroll_for (uint token_num = 0; token_num < PAGED_ATTENTION_BLOCK_SIZE; token_num++) {
112171
uint head_idx_index = 0;
172+
173+
#if !IS_KV_COMPRESSED
113174
#define READ_BLOCK_SIZE 8
114175
for (; head_idx_index + (READ_BLOCK_SIZE * SUBGROUP_SIZE) <= HEAD_SIZE; head_idx_index += SUBGROUP_SIZE * READ_BLOCK_SIZE) {
115176
#define BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, READ_BLOCK_SIZE, ptr, offset);
@@ -190,15 +251,24 @@ KERNEL(pa_kv_cache_update)(
190251
}
191252
}
192253

254+
#else // !IS_KV_COMPRESSED
255+
// key processing
256+
quantize_and_save(key_data, key_in_offset, key_cache_data, key_out_offset, PAGED_ATTENTION_BLOCK_SIZE, comp_offset, token_num, sglid);
257+
258+
// value processing
259+
quantize_and_save(value_data, value_in_offset, value_cache_data, value_out_offset, 1, comp_offset, token_num, sglid);
260+
#endif // !IS_KV_COMPRESSED
261+
193262
key_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT0_PAD_AFTER_FEATURE_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM);
194263
value_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT1_PAD_AFTER_FEATURE_NUM + INPUT1_PAD_BEFORE_FEATURE_NUM);
195264
key_out_offset += 1;
196265
value_out_offset += HEAD_SIZE;
197266
}
198267
} else {
199-
for (uint i = 0; i < tokens_num; i++) {
268+
for (uint token_num = 0; token_num < tokens_num; token_num++) {
200269
uint head_idx_index = 0;
201270

271+
#if !IS_KV_COMPRESSED
202272
#define READ_BLOCK_SIZE 1
203273
for (; head_idx_index + (READ_BLOCK_SIZE * SUBGROUP_SIZE) <= HEAD_SIZE; head_idx_index += SUBGROUP_SIZE * READ_BLOCK_SIZE) {
204274
#define BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, READ_BLOCK_SIZE, ptr, offset);
@@ -219,6 +289,14 @@ KERNEL(pa_kv_cache_update)(
219289
}
220290
}
221291

292+
#else // IS_KV_COMPRESSED
293+
// key processing
294+
quantize_and_save(key_data, key_in_offset, key_cache_data, key_out_offset, PAGED_ATTENTION_BLOCK_SIZE, comp_offset, token_start_pos + token_num, sglid);
295+
296+
// value processing
297+
quantize_and_save(value_data, value_in_offset, value_cache_data, value_out_offset, 1, comp_offset, token_start_pos + token_num, sglid);
298+
299+
#endif // IS_KV_COMPRESSED
222300
key_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT0_PAD_AFTER_FEATURE_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM);
223301
value_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT1_PAD_AFTER_FEATURE_NUM + INPUT1_PAD_BEFORE_FEATURE_NUM);
224302
key_out_offset += 1;

0 commit comments

Comments
 (0)