4
4
5
5
#include "include/batch_headers/common.cl"
6
6
7
// Per-token int8 quantization of one head row (HEAD_SIZE values) of KV-cache
// data, executed cooperatively by one subgroup. The quantized values are
// written into the int8 cache and the dequantization parameters (1/scale and
// zero-point, as INPUT0_TYPE values) are stored in the compression area of the
// same cache buffer.
//
//   in_data / in_data_offset   : source activations, read via subgroup block reads
//   out_data / out_data_offset : destination int8 cache
//   out_data_pitch             : element stride between consecutive head
//                                elements (PAGED_ATTENTION_BLOCK_SIZE for key
//                                layout, 1 for value layout — see call sites)
//   comp_offset                : element offset of the scale/zp area in out_data
//   token_pos_in_block         : token slot inside the paged-attention block
//   sglid                      : subgroup-local id of this work-item
inline void FUNC(quantize_and_save)(__global const INPUT0_TYPE* in_data,
                                    const uint in_data_offset,
                                    __global OUTPUT_TYPE* out_data,
                                    const uint out_data_offset,
                                    const uint out_data_pitch,
                                    const uint comp_offset,
                                    const uint token_pos_in_block,
                                    const uint sglid) {
    INPUT0_TYPE input_data[HEAD_SIZE / SUBGROUP_SIZE];
    INPUT0_TYPE grp_max = 0.001;
    INPUT0_TYPE max_value = INPUT0_VAL_MIN;
    INPUT0_TYPE min_value = INPUT0_VAL_MAX;

    // Each work-item loads HEAD_SIZE / SUBGROUP_SIZE strided elements and
    // tracks its private min/max while keeping the data for the second pass.
    unroll_for (uint i = 0; i < HEAD_SIZE / SUBGROUP_SIZE; i++) {
        input_data[i] = BLOCK_READN(INPUT0_TYPE, 1, in_data, in_data_offset + i * SUBGROUP_SIZE);
        max_value = fmax(max_value, input_data[i]);
        min_value = fmin(min_value, input_data[i]);
    }

    // Reduce to the subgroup-wide (whole-head) min/max.
    min_value = sub_group_reduce_min(min_value);
    max_value = sub_group_reduce_max(max_value);

    // If the range of input data is zero, it is adjusted to the minimum
    // value (0.001) so the scale computation below never divides by zero.
    #define ACCUMULATOR_TYPE float
    ACCUMULATOR_TYPE diff_value = max_value == min_value ? (grp_max) : (max_value - min_value);
    ACCUMULATOR_TYPE scale_tmp = (ACCUMULATOR_TYPE)((CHAR_MAX - CHAR_MIN) / diff_value);
    ACCUMULATOR_TYPE zp_tmp = (ACCUMULATOR_TYPE)(-min_value * scale_tmp) + CHAR_MIN;
    // NOTE(review): casts fixed to INPUT0_TYPE to match the declared variable
    // type (the original cast to INPUT1_TYPE was inconsistent with the
    // declaration; key and value inputs are expected to share one element
    // type — confirm against the kernel's JIT constants).
    INPUT0_TYPE scale = (INPUT0_TYPE)(scale_tmp);
    INPUT0_TYPE zp = (INPUT0_TYPE)(zp_tmp);
    #undef ACCUMULATOR_TYPE

    // Quantize and scatter with the layout-specific pitch.
    unroll_for (uint i = 0; i < HEAD_SIZE / SUBGROUP_SIZE; i++) {
        OUTPUT_TYPE res = convert_char_rte(input_data[i] * scale + zp);

        uint offset = out_data_offset + (i * SUBGROUP_SIZE + sglid) * out_data_pitch;
        out_data[offset] = res;
    }

    // The compression parameters live in the same cache buffer but are written
    // as INPUT0_TYPE values, so the pointer must be reinterpreted explicitly.
    // The __global qualifier is required: without it the pointer would default
    // to the private address space and the __global-to-private conversion is
    // ill-formed in OpenCL C.
    __global INPUT0_TYPE* comp_ptr = (__global INPUT0_TYPE*)(out_data + comp_offset);

    if (sglid == 0) {
        // Store dequantization scale (1/scale) and zero-point for this token.
        // 1.0f (not 1.0) keeps the kernel buildable on devices without fp64.
        comp_ptr[token_pos_in_block] = 1.0f / scale;
        comp_ptr[PAGED_ATTENTION_BLOCK_SIZE + token_pos_in_block] = zp;
    }
}
52
+
7
53
REQD_SUB_GROUP_SIZE (SUBGROUP_SIZE )
8
54
__attribute__((reqd_work_group_size (1 , 1 , SUBGROUP_SIZE )))
9
55
KERNEL (pa_kv_cache_update )(
@@ -41,8 +87,12 @@ KERNEL(pa_kv_cache_update)(
41
87
seq_idx * (KV_HEADS_NUM * HEAD_SIZE + INPUT1_PAD_BEFORE_FEATURE_NUM + INPUT1_PAD_AFTER_FEATURE_NUM ) +
42
88
head_idx * HEAD_SIZE ;
43
89
44
- uint key_out_offset = block_idx * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + current_token_pos_in_block ;
45
- uint value_out_offset = block_idx * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + current_token_pos_in_block * HEAD_SIZE ;
90
+ uint block_base_offset = block_idx * KV_HEADS_NUM * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + head_idx * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE ;
91
+ uint key_out_offset = block_base_offset + current_token_pos_in_block ;
92
+ uint value_out_offset = block_base_offset + current_token_pos_in_block * HEAD_SIZE ;
93
+ const uint comp_offset = block_base_offset + HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE ;
94
+
95
+ #if !IS_KV_COMPRESSED
46
96
47
97
#define READ_BLOCK_SIZE GENERATE_STAGE_BLOCK_SIZE
48
98
for (uint head_idx_index = 0 ; head_idx_index < HEAD_SIZE ; head_idx_index += SUBGROUP_SIZE * READ_BLOCK_SIZE ) {
@@ -71,6 +121,14 @@ KERNEL(pa_kv_cache_update)(
71
121
#endif
72
122
}
73
123
}
124
+
125
+ #else // IS_KV_COMPRESSED
126
+ // key processing
127
+ FUNC_CALL (quantize_and_save )(key_data , key_in_offset , key_cache_data , key_out_offset , PAGED_ATTENTION_BLOCK_SIZE , comp_offset , current_token_pos_in_block , sglid );
128
+
129
+ // value processing
130
+ FUNC_CALL (quantize_and_save )(value_data , value_in_offset , value_cache_data , value_out_offset , 1 , comp_offset , current_token_pos_in_block , sglid );
131
+ #endif // IS_KV_COMPRESSED
74
132
} else {
75
133
// 1st token
76
134
const uint block_idx = get_global_id (0 );
@@ -99,17 +157,20 @@ KERNEL(pa_kv_cache_update)(
99
157
100
158
const uint block_offset = block_indices_begins [subsequence_idx ] + current_block_idx ;
101
159
102
- uint key_out_offset = block_indices [block_offset ] * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
103
- head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE ;
104
-
105
- uint value_out_offset = key_out_offset ;
160
+ uint block_base_offset = block_indices [block_offset ] * KV_HEADS_NUM * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
161
+ head_idx * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE ;
162
+ uint key_out_offset = block_base_offset ;
163
+ uint value_out_offset = block_base_offset ;
164
+ const uint comp_offset = block_base_offset + HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE ;
106
165
107
166
key_out_offset += token_start_pos ;
108
167
value_out_offset += token_start_pos * HEAD_SIZE ;
109
168
110
169
if (tokens_num == PAGED_ATTENTION_BLOCK_SIZE ) {
111
170
unroll_for (uint token_num = 0 ; token_num < PAGED_ATTENTION_BLOCK_SIZE ; token_num ++ ) {
112
171
uint head_idx_index = 0 ;
172
+
173
+ #if !IS_KV_COMPRESSED
113
174
#define READ_BLOCK_SIZE 8
114
175
for (; head_idx_index + (READ_BLOCK_SIZE * SUBGROUP_SIZE ) <= HEAD_SIZE ; head_idx_index += SUBGROUP_SIZE * READ_BLOCK_SIZE ) {
115
176
#define BLOCK_READ (ptr , offset ) BLOCK_READN(INPUT0_TYPE, READ_BLOCK_SIZE, ptr, offset);
@@ -190,15 +251,24 @@ KERNEL(pa_kv_cache_update)(
190
251
}
191
252
}
192
253
254
+ #else // IS_KV_COMPRESSED
255
+ // key processing
256
+ FUNC_CALL (quantize_and_save )(key_data , key_in_offset , key_cache_data , key_out_offset , PAGED_ATTENTION_BLOCK_SIZE , comp_offset , token_num , sglid );
257
+
258
+ // value processing
259
+ FUNC_CALL (quantize_and_save )(value_data , value_in_offset , value_cache_data , value_out_offset , 1 , comp_offset , token_num , sglid );
260
+ #endif // IS_KV_COMPRESSED
261
+
193
262
key_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT0_PAD_AFTER_FEATURE_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM );
194
263
value_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT1_PAD_AFTER_FEATURE_NUM + INPUT1_PAD_BEFORE_FEATURE_NUM );
195
264
key_out_offset += 1 ;
196
265
value_out_offset += HEAD_SIZE ;
197
266
}
198
267
} else {
199
- for (uint i = 0 ; i < tokens_num ; i ++ ) {
268
+ for (uint token_num = 0 ; token_num < tokens_num ; token_num ++ ) {
200
269
uint head_idx_index = 0 ;
201
270
271
+ #if !IS_KV_COMPRESSED
202
272
#define READ_BLOCK_SIZE 1
203
273
for (; head_idx_index + (READ_BLOCK_SIZE * SUBGROUP_SIZE ) <= HEAD_SIZE ; head_idx_index += SUBGROUP_SIZE * READ_BLOCK_SIZE ) {
204
274
#define BLOCK_READ (ptr , offset ) BLOCK_READN(INPUT0_TYPE, READ_BLOCK_SIZE, ptr, offset);
@@ -219,6 +289,13 @@ KERNEL(pa_kv_cache_update)(
219
289
}
220
290
}
221
291
292
+ #else // IS_KV_COMPRESSED
293
+ // key processing
294
+ FUNC_CALL (quantize_and_save )(key_data , key_in_offset , key_cache_data , key_out_offset , PAGED_ATTENTION_BLOCK_SIZE , comp_offset , token_start_pos + token_num , sglid );
295
+
296
+ // value processing
297
+ FUNC_CALL (quantize_and_save )(value_data , value_in_offset , value_cache_data , value_out_offset , 1 , comp_offset , token_start_pos + token_num , sglid );
298
+ #endif // IS_KV_COMPRESSED
222
299
key_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT0_PAD_AFTER_FEATURE_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM );
223
300
value_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT1_PAD_AFTER_FEATURE_NUM + INPUT1_PAD_BEFORE_FEATURE_NUM );
224
301
key_out_offset += 1 ;
0 commit comments