Skip to content

Commit 9fa105e

Browse files
committed
[GPU] Enable KV-cache compression in PagedAttention OCL kernel
1 parent b88dcc6 commit 9fa105e

File tree

15 files changed

+572
-94
lines changed

15 files changed

+572
-94
lines changed

src/common/transformations/include/transformations/common_optimizations/convert_pagedattn_inputs.hpp

+3-1
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ class TRANSFORMATIONS_API ConvertPagedAttnInputs;
1919

2020
class ConvertPagedAttnInputs : public ov::pass::MatcherPass {
2121
public:
22+
using UpdateShapeFunc = std::function<void(ov::element::Type, bool, size_t, int64_t&, int64_t&)>;
2223
struct KVCacheConfig {
2324
ov::element::Type keyCachePrecision;
2425
ov::element::Type valueCachePrecision;
@@ -34,14 +35,15 @@ class ConvertPagedAttnInputs : public ov::pass::MatcherPass {
3435
};
3536

3637
OPENVINO_MATCHER_PASS_RTTI("ConvertPagedAttnInputs");
37-
ConvertPagedAttnInputs(const KVCacheConfig& config);
38+
ConvertPagedAttnInputs(const KVCacheConfig& config, UpdateShapeFunc update_shape_func);
3839

3940
void setKVCacheConfig(const KVCacheConfig& config);
4041

4142
const KVCacheConfig& getKVCacheConfig() const;
4243

4344
private:
4445
KVCacheConfig m_config;
46+
UpdateShapeFunc m_update_shape_func;
4547
};
4648

4749
} // namespace pass

src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp

+6-12
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,9 @@
1717
#include "transformations/utils/utils.hpp"
1818
using namespace ov::gen_pattern;
1919

20-
ov::pass::ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& config) : m_config(config) {
20+
ov::pass::ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& config, UpdateShapeFunc f)
21+
: m_config(config),
22+
m_update_shape_func(std::move(f)) {
2123
MATCHER_SCOPE(ConvertPagedAttnInputs);
2224

2325
auto Q = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());
@@ -83,7 +85,7 @@ ov::pass::ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& co
8385
const size_t group_size,
8486
const bool bychannel,
8587
const std::vector<size_t>& orders) {
86-
size_t _block_size = block_size;
88+
ov::Dimension::value_type _block_size = block_size;
8789
ov::Dimension::value_type _head_nums = head_nums;
8890
ov::Dimension::value_type _head_size = head_size;
8991
ov::Dimension::value_type _group_size = group_size;
@@ -94,17 +96,9 @@ ov::pass::ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& co
9496
}
9597
}
9698
size_t group_num = _head_size / _group_size;
97-
if (precision == ov::element::u8) {
98-
if (bychannel) {
99-
_block_size += 2 * sizeof(float);
100-
} else {
101-
_head_size += sizeof(float) * 2 * group_num;
102-
}
103-
} else if (precision == ov::element::u4) {
104-
_head_size += sizeof(float) * 2 * group_num * 2;
105-
}
106-
auto block_shape = ov::PartialShape::dynamic(4);
99+
m_update_shape_func(precision, bychannel, group_num, _head_size, _block_size);
107100

101+
auto block_shape = ov::PartialShape::dynamic(4);
108102
block_shape[orders[0]] = -1;
109103
block_shape[orders[1]] = _head_nums;
110104
block_shape[orders[2]] = _block_size;

src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp

+19-1
Original file line number | Diff line number | Diff line change
@@ -470,7 +470,25 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
470470
cacheConfig.valueCacheQuantBychannel = false;
471471
cacheConfig.keyCacheDimOrder = {0, 1, 2, 3};
472472
cacheConfig.valueCacheDimOrder = {0, 1, 2, 3};
473-
CPU_REGISTER_PASS_COMMON(manager, ov::pass::ConvertPagedAttnInputs, cacheConfig);
473+
CPU_REGISTER_PASS_COMMON(
474+
manager,
475+
ov::pass::ConvertPagedAttnInputs,
476+
cacheConfig,
477+
[](const ov::element::Type& precision,
478+
const bool bychannel,
479+
const size_t group_num,
480+
int64_t& head_size,
481+
int64_t& block_size) {
482+
if (precision == ov::element::u8) {
483+
if (bychannel) {
484+
block_size += 2 * sizeof(float);
485+
} else {
486+
head_size += sizeof(float) * 2 * group_num;
487+
}
488+
} else if (precision == ov::element::u4) {
489+
head_size += sizeof(float) * 2 * group_num * 2;
490+
}
491+
});
474492
CPU_REGISTER_PASS_COMMON(manager, ov::pass::CommonOptimizations);
475493
CPU_REGISTER_PASS_X64(manager, ov::pass::KeepConstPrecision, decompression_precisions, false, true);
476494
CPU_SET_CALLBACK_X64(

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp

+12
Original file line number | Diff line number | Diff line change
@@ -671,6 +671,11 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
671671
config.paged_attention_max_len = max_context_len_mem_lock[0];
672672
}
673673

674+
if (data_type_traits::is_i8_u8(impl_param.get_input_layout(3).data_type)) {
675+
config.is_kv_compressed = true;
676+
config.use_asymmetric_quantization = true;
677+
}
678+
674679
return config;
675680
}
676681

@@ -693,6 +698,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
693698
params.inputs[2] = rotation_trig_lut_tensor;
694699
params.outputs[0] = key_cache_tensor;
695700

701+
params.original_cache_dt = to_data_type(impl_param.get_input_layout(1).data_type);
696702
params.conf = get_sdpa_configuration(impl_param, is_dynamic);
697703

698704
const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset;
@@ -810,6 +816,11 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
810816

811817
params.conf = get_sdpa_configuration(impl_param, is_dynamic);
812818

819+
// Currently, for the processing of the 1st token, plain SDPA kernels are used, which expect
820+
// uncompressed plain QKV inputs. Therefore, set is_kv_compressed=false
821+
params.conf.is_kv_compressed = false;
822+
params.conf.use_asymmetric_quantization = false;
823+
813824
const std::vector<int64_t> default_order = {0, 1, 2, 3};
814825
params.input0_order = default_order;
815826
params.input1_order = default_order;
@@ -975,6 +986,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
975986
auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
976987
auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance();
977988
kernels_data.push_back(kv_cache_update_kernel_selector.get_best_kernel(kv_cache_update_kernel_params));
989+
978990
auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, 0, impl_param.is_dynamic());
979991
auto& sdpa_kernel_selector = sdpa_kernel_selector_t::Instance();
980992
kernels_data.push_back(sdpa_kernel_selector.get_best_kernel(sdpa_kernel_params));

src/plugins/intel_gpu/src/graph/paged_attention.cpp

+2
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,8 @@ std::string paged_attention_inst::to_string(const paged_attention_node& node) {
7373
paged_attention_info.add("scale", desc->scale_val.value_or(1.0f));
7474
paged_attention_info.add("has_alibi", desc->has_alibi);
7575
paged_attention_info.add("has_rotated_blocks", desc->has_rotated_blocks);
76+
paged_attention_info.add("key_cache_dt", node.get_input_layout(3).data_type);
77+
paged_attention_info.add("value_cache_dt", node.get_input_layout(4).data_type);
7678
node_info->add("paged_attention primitive info", paged_attention_info);
7779
node_info->dump(primitive_description);
7880

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_rotate_ref.cl

+67-8
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,15 @@
44

55
#include "include/batch_headers/common.cl"
66

7+
#if IS_KV_COMPRESSED
8+
#define SUBGROUPS_PER_WG 1
9+
#else
710
#define SUBGROUPS_PER_WG KV_HEADS_NUM
11+
#endif
12+
#define ACCUMULATOR_TYPE float
813

914
REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
10-
__attribute__((reqd_work_group_size(SUBGROUP_SIZE, KV_HEADS_NUM, 1)))
15+
__attribute__((reqd_work_group_size(SUBGROUP_SIZE, SUBGROUPS_PER_WG, 1)))
1116
KERNEL(pa_kv_cache_rotate)(
1217
OPTIONAL_SHAPE_INFO_ARG
1318
__global const INPUT0_TYPE* rotated_block_indices,
@@ -62,22 +67,76 @@ KERNEL(pa_kv_cache_rotate)(
6267
barrier(CLK_LOCAL_MEM_FENCE);
6368

6469
const uint token_coefficient_idx = per_token_rotation ? sglid : 0;
65-
const uint block_offset = rotated_block_indices[block_idx] * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
66-
head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + sglid;
70+
const uint block_base_offset = rotated_block_indices[block_idx] * KV_HEADS_NUM * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
71+
head_idx * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
72+
const uint token_offset = block_base_offset + sglid;
73+
74+
#if IS_KV_COMPRESSED
75+
const uint comp_offset = block_base_offset + HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
76+
UNCOMPRESSED_TYPE* comp_ptr = key_cache + comp_offset;
77+
UNCOMPRESSED_TYPE comp_scale = comp_ptr[0 + sglid];
78+
UNCOMPRESSED_TYPE comp_zp = comp_ptr[PAGED_ATTENTION_BLOCK_SIZE + sglid];
79+
80+
UNCOMPRESSED_TYPE max_value = UNCOMPRESSED_VAL_MIN;
81+
UNCOMPRESSED_TYPE min_value = UNCOMPRESSED_VAL_MAX;
82+
83+
// Reuse SLM to store dequantized rotated values
84+
__local UNCOMPRESSED_TYPE* rotated_data = (__local UNCOMPRESSED_TYPE*)(&rotation_coefficients[0][0]);
85+
#endif
86+
87+
// Apply cache rotation
6788
for (uint i = 0; i < HEAD_SIZE / 2; i++) {
68-
const uint cache_offset = block_offset + i * PAGED_ATTENTION_BLOCK_SIZE;
69-
OUTPUT_TYPE cache_value_first = key_cache[cache_offset];
70-
OUTPUT_TYPE cache_value_second = key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE];
89+
const uint cache_offset = token_offset + i * PAGED_ATTENTION_BLOCK_SIZE;
90+
91+
#if IS_KV_COMPRESSED
92+
UNCOMPRESSED_TYPE cache_value_first = TO_UNCOMPRESSED_TYPE(key_cache[cache_offset] - comp_zp) * comp_scale;
93+
UNCOMPRESSED_TYPE cache_value_second = TO_UNCOMPRESSED_TYPE(key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE] - comp_zp) * comp_scale;
94+
#else
95+
UNCOMPRESSED_TYPE cache_value_first = key_cache[cache_offset];
96+
UNCOMPRESSED_TYPE cache_value_second = key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE];
97+
#endif
7198

7299
INPUT2_TYPE rotation_value_cos = rotation_coefficients[i][token_coefficient_idx];
73100
INPUT2_TYPE rotation_value_sin = rotation_coefficients[i + (HEAD_SIZE / 2)][token_coefficient_idx];
74101

75-
OUTPUT_TYPE new_cache_value_first = cache_value_first * rotation_value_cos - cache_value_second * rotation_value_sin;
76-
OUTPUT_TYPE new_cache_value_second = cache_value_first * rotation_value_sin + cache_value_second * rotation_value_cos;
102+
UNCOMPRESSED_TYPE new_cache_value_first = cache_value_first * rotation_value_cos - cache_value_second * rotation_value_sin;
103+
UNCOMPRESSED_TYPE new_cache_value_second = cache_value_first * rotation_value_sin + cache_value_second * rotation_value_cos;
77104

105+
#if IS_KV_COMPRESSED
106+
max_value = fmax(fmax(max_value, new_cache_value_first), new_cache_value_second);
107+
min_value = fmin(fmin(min_value, new_cache_value_first), new_cache_value_second);
108+
109+
rotated_data[(i + 0) * PAGED_ATTENTION_BLOCK_SIZE + sglid] = new_cache_value_first;
110+
rotated_data[(i + (HEAD_SIZE / 2)) * PAGED_ATTENTION_BLOCK_SIZE + sglid] = new_cache_value_second;
111+
#else
78112
key_cache[cache_offset] = new_cache_value_first;
79113
key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE] = new_cache_value_second;
114+
#endif
115+
}
116+
117+
#if IS_KV_COMPRESSED
118+
// Re-quantize cache data
119+
ACCUMULATOR_TYPE grp_max = 0.001;
120+
ACCUMULATOR_TYPE diff_value = max_value == min_value ? (grp_max) : (max_value - min_value);
121+
ACCUMULATOR_TYPE scale_tmp = (ACCUMULATOR_TYPE)((CHAR_MAX - CHAR_MIN) / diff_value);
122+
ACCUMULATOR_TYPE zp_tmp = (ACCUMULATOR_TYPE)(-min_value * scale_tmp) + CHAR_MIN;
123+
UNCOMPRESSED_TYPE scale = (UNCOMPRESSED_TYPE)(scale_tmp);
124+
UNCOMPRESSED_TYPE zp = (UNCOMPRESSED_TYPE)(zp_tmp);
125+
126+
// Note: absence of this explicit unrolling directive leads to automatic
127+
// unrolling and causes register spilling. Set unrolling to a reasonable value manually
128+
__attribute__((opencl_unroll_hint(8)))
129+
for (uint i = 0; i < HEAD_SIZE; i++) {
130+
OUTPUT_TYPE quantized_res = convert_char_rte(rotated_data[i * PAGED_ATTENTION_BLOCK_SIZE + sglid] * scale + zp);
131+
132+
const uint cache_offset = token_offset + i * PAGED_ATTENTION_BLOCK_SIZE;
133+
key_cache[cache_offset] = quantized_res;
80134
}
135+
136+
comp_ptr[0 + sglid] = 1.0 / scale;
137+
comp_ptr[PAGED_ATTENTION_BLOCK_SIZE + sglid] = zp;
138+
#endif
81139
}
82140

141+
#undef ACCUMULATOR_TYPE
83142
#undef SUBGROUPS_PER_WG

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_update_ref.cl

+84-7
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,52 @@
44

55
#include "include/batch_headers/common.cl"
66

7+
// Quantizes HEAD_SIZE input elements to OUTPUT_TYPE with a single scale / zero-point pair
// and stores both the quantized values and the pair into out_data.
//
// in_data            - source buffer; elements are loaded via BLOCK_READN starting at in_data_offset
// in_data_offset     - offset of the first element to load from in_data
// out_data           - destination buffer for the quantized values and the scale / zero-point
// out_data_offset    - offset in out_data of the first quantized element
// out_data_pitch     - distance in out_data between two consecutive quantized elements
// comp_offset        - offset in out_data where the scale / zero-point area starts
// token_pos_in_block - index used to address the scale / zero-point entry inside that area
// sglid              - sub-group local id of the calling work-item
inline void FUNC(quantize_and_save)(__global const INPUT0_TYPE* in_data,
8+
const uint in_data_offset,
9+
__global OUTPUT_TYPE* out_data,
10+
const uint out_data_offset,
11+
const uint out_data_pitch,
12+
const uint comp_offset,
13+
const uint token_pos_in_block,
14+
const uint sglid) {
15+
// Per-work-item slice of the input: HEAD_SIZE / SUBGROUP_SIZE elements
INPUT0_TYPE input_data[HEAD_SIZE / SUBGROUP_SIZE];
16+
// Fallback range used when all input values are equal (see diff_value below)
INPUT0_TYPE grp_max = 0.001;
17+
INPUT0_TYPE max_value = INPUT0_VAL_MIN;
18+
INPUT0_TYPE min_value = INPUT0_VAL_MAX;
19+
20+
// Load the input and track the local min / max of the loaded values
unroll_for (uint i = 0; i < HEAD_SIZE / SUBGROUP_SIZE; i++) {
21+
input_data[i] = BLOCK_READN(INPUT0_TYPE, 1, in_data, in_data_offset + i * SUBGROUP_SIZE);
22+
max_value = fmax(max_value, input_data[i]);
23+
min_value = fmin(min_value, input_data[i]);
24+
}
25+
26+
// Combine the per-work-item min / max across the whole sub-group
min_value = sub_group_reduce_min(min_value);
27+
max_value = sub_group_reduce_max(max_value);
28+
29+
// If the range of the input data is zero, fall back to a minimum range of 0.001 (grp_max)
// so that the division below does not divide by zero.
30+
#define ACCUMULATOR_TYPE float
31+
ACCUMULATOR_TYPE diff_value = max_value == min_value ? (grp_max) : (max_value - min_value);
32+
// scale maps the [min_value, max_value] range onto (CHAR_MAX - CHAR_MIN) quantization steps
ACCUMULATOR_TYPE scale_tmp = (ACCUMULATOR_TYPE)((CHAR_MAX - CHAR_MIN) / diff_value);
33+
// zero-point is chosen so that min_value * scale + zp equals CHAR_MIN
ACCUMULATOR_TYPE zp_tmp = (ACCUMULATOR_TYPE)(-min_value * scale_tmp) + CHAR_MIN;
34+
// NOTE(review): the casts below use INPUT1_TYPE while the variables are INPUT0_TYPE --
// presumably both macros expand to the same type; confirm.
INPUT0_TYPE scale = (INPUT1_TYPE)(scale_tmp);
35+
INPUT0_TYPE zp = (INPUT1_TYPE)(zp_tmp);
36+
#undef ACCUMULATOR_TYPE
37+
38+
// Quantize each loaded value (round to nearest even) and store it with the requested pitch
unroll_for (uint i = 0; i < HEAD_SIZE / SUBGROUP_SIZE; i++) {
39+
OUTPUT_TYPE res = convert_char_rte(input_data[i] * scale + zp);
40+
41+
uint offset = out_data_offset + (i * SUBGROUP_SIZE + sglid) * out_data_pitch;
42+
out_data[offset] = res;
43+
}
44+
45+
// The scale / zero-point area lives in the same buffer as the quantized data but is accessed
// as INPUT0_TYPE elements.
// NOTE(review): comp_ptr is declared without an address-space qualifier and without an explicit
// cast from the __global OUTPUT_TYPE* expression -- confirm this is intended.
INPUT0_TYPE* comp_ptr = out_data + comp_offset;
46+
47+
// Only one work-item of the sub-group stores the pair; the inverse scale is stored,
// followed PAGED_ATTENTION_BLOCK_SIZE entries later by the zero-point.
if (sglid == 0) {
48+
comp_ptr[token_pos_in_block] = 1.0 / scale;
49+
comp_ptr[PAGED_ATTENTION_BLOCK_SIZE + token_pos_in_block] = zp;
50+
}
51+
}
52+
753
REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
854
__attribute__((reqd_work_group_size(1, 1, SUBGROUP_SIZE)))
955
KERNEL(pa_kv_cache_update)(
@@ -41,8 +87,12 @@ KERNEL(pa_kv_cache_update)(
4187
seq_idx * (KV_HEADS_NUM * HEAD_SIZE + INPUT1_PAD_BEFORE_FEATURE_NUM + INPUT1_PAD_AFTER_FEATURE_NUM) +
4288
head_idx * HEAD_SIZE;
4389

44-
uint key_out_offset = block_idx * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + current_token_pos_in_block;
45-
uint value_out_offset = block_idx * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + current_token_pos_in_block * HEAD_SIZE;
90+
uint block_base_offset = block_idx * KV_HEADS_NUM * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + head_idx * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
91+
uint key_out_offset = block_base_offset + current_token_pos_in_block;
92+
uint value_out_offset = block_base_offset + current_token_pos_in_block * HEAD_SIZE;
93+
const uint comp_offset = block_base_offset + HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
94+
95+
#if !IS_KV_COMPRESSED
4696

4797
#define READ_BLOCK_SIZE GENERATE_STAGE_BLOCK_SIZE
4898
for (uint head_idx_index = 0; head_idx_index < HEAD_SIZE; head_idx_index += SUBGROUP_SIZE * READ_BLOCK_SIZE) {
@@ -71,6 +121,14 @@ KERNEL(pa_kv_cache_update)(
71121
#endif
72122
}
73123
}
124+
125+
#else // IS_KV_COMPRESSED
126+
// key processing
127+
FUNC_CALL(quantize_and_save)(key_data, key_in_offset, key_cache_data, key_out_offset, PAGED_ATTENTION_BLOCK_SIZE, comp_offset, current_token_pos_in_block, sglid);
128+
129+
// value processing
130+
FUNC_CALL(quantize_and_save)(value_data, value_in_offset, value_cache_data, value_out_offset, 1, comp_offset, current_token_pos_in_block, sglid);
131+
#endif // IS_KV_COMPRESSED
74132
} else {
75133
// 1st token
76134
const uint block_idx = get_global_id(0);
@@ -99,17 +157,20 @@ KERNEL(pa_kv_cache_update)(
99157

100158
const uint block_offset = block_indices_begins[subsequence_idx] + current_block_idx;
101159

102-
uint key_out_offset = block_indices[block_offset] * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
103-
head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
104-
105-
uint value_out_offset = key_out_offset;
160+
uint block_base_offset = block_indices[block_offset] * KV_HEADS_NUM * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
161+
head_idx * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
162+
uint key_out_offset = block_base_offset;
163+
uint value_out_offset = block_base_offset;
164+
const uint comp_offset = block_base_offset + HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
106165

107166
key_out_offset += token_start_pos;
108167
value_out_offset += token_start_pos * HEAD_SIZE;
109168

110169
if (tokens_num == PAGED_ATTENTION_BLOCK_SIZE) {
111170
unroll_for (uint token_num = 0; token_num < PAGED_ATTENTION_BLOCK_SIZE; token_num++) {
112171
uint head_idx_index = 0;
172+
173+
#if !IS_KV_COMPRESSED
113174
#define READ_BLOCK_SIZE 8
114175
for (; head_idx_index + (READ_BLOCK_SIZE * SUBGROUP_SIZE) <= HEAD_SIZE; head_idx_index += SUBGROUP_SIZE * READ_BLOCK_SIZE) {
115176
#define BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, READ_BLOCK_SIZE, ptr, offset);
@@ -190,15 +251,24 @@ KERNEL(pa_kv_cache_update)(
190251
}
191252
}
192253

254+
#else // IS_KV_COMPRESSED
255+
// key processing
256+
FUNC_CALL(quantize_and_save)(key_data, key_in_offset, key_cache_data, key_out_offset, PAGED_ATTENTION_BLOCK_SIZE, comp_offset, token_num, sglid);
257+
258+
// value processing
259+
FUNC_CALL(quantize_and_save)(value_data, value_in_offset, value_cache_data, value_out_offset, 1, comp_offset, token_num, sglid);
260+
#endif // IS_KV_COMPRESSED
261+
193262
key_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT0_PAD_AFTER_FEATURE_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM);
194263
value_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT1_PAD_AFTER_FEATURE_NUM + INPUT1_PAD_BEFORE_FEATURE_NUM);
195264
key_out_offset += 1;
196265
value_out_offset += HEAD_SIZE;
197266
}
198267
} else {
199-
for (uint i = 0; i < tokens_num; i++) {
268+
for (uint token_num = 0; token_num < tokens_num; token_num++) {
200269
uint head_idx_index = 0;
201270

271+
#if !IS_KV_COMPRESSED
202272
#define READ_BLOCK_SIZE 1
203273
for (; head_idx_index + (READ_BLOCK_SIZE * SUBGROUP_SIZE) <= HEAD_SIZE; head_idx_index += SUBGROUP_SIZE * READ_BLOCK_SIZE) {
204274
#define BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, READ_BLOCK_SIZE, ptr, offset);
@@ -219,6 +289,13 @@ KERNEL(pa_kv_cache_update)(
219289
}
220290
}
221291

292+
#else // IS_KV_COMPRESSED
293+
// key processing
294+
FUNC_CALL(quantize_and_save)(key_data, key_in_offset, key_cache_data, key_out_offset, PAGED_ATTENTION_BLOCK_SIZE, comp_offset, token_start_pos + token_num, sglid);
295+
296+
// value processing
297+
FUNC_CALL(quantize_and_save)(value_data, value_in_offset, value_cache_data, value_out_offset, 1, comp_offset, token_start_pos + token_num, sglid);
298+
#endif // IS_KV_COMPRESSED
222299
key_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT0_PAD_AFTER_FEATURE_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM);
223300
value_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT1_PAD_AFTER_FEATURE_NUM + INPUT1_PAD_BEFORE_FEATURE_NUM);
224301
key_out_offset += 1;

0 commit comments

Comments
 (0)