// Reference values of the compile-time configuration (actual values are
// injected by the host via build options):
// constexpr size_t HEAD_SIZE = 64;
// constexpr size_t HEADS_NUM = 32;
// constexpr size_t KV_HEADS_NUM = 4;
// constexpr NUM_QUERIES_PER_KV_HEAD (HEADS_NUM / KV_HEADS_NUM)
// constexpr size_t BLOCK_SIZE = 16;
// constexpr size_t X_SIZE = 4;
29
30
// How many QK outputs each subgroup computes per iteration.
#define QK_PER_SG 4

// Stride (in elements) between consecutive physical blocks of the paged
// KV-cache. The cache stores one copy per KV head (grouped-query attention),
// hence KV_HEADS_NUM rather than HEADS_NUM.
#define KV_CACHE_BLOCK_STRIDE (HEAD_SIZE * KV_HEADS_NUM * BLOCK_SIZE)

// Sub-group block read of a single query element at the given offset.
#define QUERY_BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, 1, ptr, offset)

// Number of sub-groups per work-group. Parenthesized so the macro expands
// safely inside larger expressions (e.g. SUBGROUPS_PER_WG * SUB_GROUP_SIZE,
// or division by SUBGROUPS_PER_WG) — the unparenthesized form silently
// reassociated with neighboring operators.
#define SUBGROUPS_PER_WG (HEAD_SIZE / SUB_GROUP_SIZE)
37
38
38
39
REQD_SUB_GROUP_SIZE (SUB_GROUP_SIZE )
39
- __attribute__((reqd_work_group_size (1 , 1 , SUB_GROUP_SIZE )))
40
+ __attribute__((reqd_work_group_size (1 , 1 , 64 )))
40
41
KERNEL (pa_sdpa_ref )(
41
42
OPTIONAL_SHAPE_INFO_ARG
42
43
__global const INPUT0_TYPE * query ,
@@ -45,6 +46,7 @@ KERNEL(pa_sdpa_ref)(
45
46
__global const INPUT3_TYPE * max_context_len ,
46
47
__global const INPUT4_TYPE * context_lens ,
47
48
__global const INPUT5_TYPE * block_tables ,
49
+ __global const INPUT6_TYPE * scale ,
48
50
__global OUTPUT_TYPE * output )
49
51
{
50
52
const uint seq_idx = get_global_id (0 );
@@ -60,6 +62,30 @@ KERNEL(pa_sdpa_ref)(
60
62
61
63
const uint blocks_num = INPUT5_FEATURE_NUM ;
62
64
65
+ // if (seq_idx < 2 && head_num_idx < 2 && sgid < 2 && sglid < 2) {
66
+ // if (INPUT5_FEATURE_NUM == 0) {
67
+ // printf("Empty blocks. Seq_idx=%d, head_num_idx=%d, head_idx=%d, sglid=%d, sgid=%d, batch_idx=%d, token_idx=%d, context_len=%d, scale=%f\n",
68
+ // seq_idx, head_num_idx, head_idx, sglid, sgid, batch_idx, token_idx, context_len, scale[0]);
69
+ // } else if (INPUT5_FEATURE_NUM == 1) {
70
+ // printf("Blocks table[b=0]: %d. Seq_idx=%d, head_num_idx=%d, head_idx=%d, sglid=%d, sgid=%d, batch_idx=%d, token_idx=%d, context_len=%d, scale=%f\n", block_tables[0],
71
+ // seq_idx, head_num_idx, head_idx, sglid, sgid, batch_idx, token_idx, context_len, scale[0]);
72
+ // } else if (INPUT5_FEATURE_NUM == 2) {
73
+ // printf("Blocks table[b=0]: %d %d. Seq_idx=%d, head_num_idx=%d, head_idx=%d, sglid=%d, sgid=%d, batch_idx=%d, token_idx=%d, context_len=%d, scale=%f\n", block_tables[0], block_tables[1],
74
+ // seq_idx, head_num_idx, head_idx, sglid, sgid, batch_idx, token_idx, context_len, scale[0]);
75
+ // } else if (INPUT5_FEATURE_NUM == 3) {
76
+ // printf("Blocks table[b=0]: %d %d %d. Seq_idx=%d, head_num_idx=%d, head_idx=%d, sglid=%d, sgid=%d, batch_idx=%d, token_idx=%d, context_len=%d, scale=%f\n", block_tables[0], block_tables[1], block_tables[2],
77
+ // seq_idx, head_num_idx, head_idx, sglid, sgid, batch_idx, token_idx, context_len, scale[0]);
78
+ // } else if (INPUT5_FEATURE_NUM == 4) {
79
+ // printf("Blocks table[b=0]: %d %d %d %d. Seq_idx=%d, head_num_idx=%d, head_idx=%d, sglid=%d, sgid=%d, batch_idx=%d, token_idx=%d, context_len=%d, scale=%f\n", block_tables[0], block_tables[1], block_tables[2], block_tables[3],
80
+ // seq_idx, head_num_idx, head_idx, sglid, sgid, batch_idx, token_idx, context_len, scale[0]);
81
+ // }
82
+
83
+ // if (seq_idx == 0 && head_num_idx == 0 && sgid == 0 && sglid == 0) {
84
+ // printf("key_cache[405504]=%f\n", key_cache[405504]);
85
+ // printf("value_cache[405504]=%f\n", value_cache[405504]);
86
+ // }
87
+ // }
88
+
63
89
// sgid0: 0..3
64
90
// sgid1: 4..7
65
91
// sgid2: 8..11
@@ -84,7 +110,9 @@ KERNEL(pa_sdpa_ref)(
84
110
OUTPUT_TYPE qk [QK_PER_SG ] = {0 };
85
111
86
112
for (uint hs = 0 ; hs < HEAD_ITEMS_PER_WI ; hs ++ ) {
87
- const uint query_idx = seq_idx * HEAD_SIZE * HEADS_NUM + hs * SUB_GROUP_SIZE ;
113
+ const uint query_idx = seq_idx * HEAD_SIZE * HEADS_NUM +
114
+ head_num_idx * HEAD_SIZE +
115
+ hs * SUB_GROUP_SIZE ;
88
116
89
117
// TODO: can be preloaded outside HEAD_ITEMS_PER_WI loop - need to check perf
90
118
INPUT0_TYPE q = QUERY_BLOCK_READ (query , query_idx );
@@ -94,34 +122,53 @@ KERNEL(pa_sdpa_ref)(
94
122
continue ;
95
123
96
124
const uint key_idx = block_offset +
125
+ (head_num_idx / NUM_QUERIES_PER_KV_HEAD ) * (HEAD_SIZE / X_SIZE * BLOCK_SIZE * X_SIZE ) +
97
126
(X_SIZE * QK_PER_SG ) * sgid +
98
127
(HEAD_ITEMS_PER_WI * BLOCK_SIZE * X_SIZE ) * hs +
99
128
(sglid / X_SIZE ) * X_SIZE * BLOCK_SIZE +
100
129
(sglid % X_SIZE ) + qk_idx * X_SIZE ;
130
+
101
131
// TODO1: try block loading and shuffling
102
132
// TODO2: try to load k*4 times and then calculate
103
133
// TODO3: try bigger X block
104
134
INPUT1_TYPE k = key_cache [key_idx ];
105
135
136
+
137
+ // if (seq_idx == 0 && head_num_idx == 0) {
138
+ // printf("main_calc: seq_idx=%d, head_num_idx=%d, sgid=%d, sglid=%d, block=%d, hs=%d, qk_idx=%d, current_token=%d, query_idx=%d, key_idx=%d (block_offset=%d): %f * %f\n",
139
+ // seq_idx, head_num_idx, sgid, sglid, block, hs, qk_idx, current_token, query_idx, key_idx - block_offset, block_offset, q, k);
140
+ // }
141
+
106
142
qk [qk_idx ] = mad (q , k , qk [qk_idx ]);
107
143
}
108
144
}
109
145
110
- // Summurize qk calculation across all WIs
146
+ // Summurize qk calculation across all WIs and apply scale
111
147
for (uint qk_idx = 0 ; qk_idx < QK_PER_SG ; qk_idx ++ ) {
112
- qk [QK_PER_SG ] = sub_group_reduce_add (qk [QK_PER_SG ]);
113
- qk_max = OUTPUT_MAX_FUNC (qk_max , qk [QK_PER_SG ]);
148
+ const uint current_token = block * BLOCK_SIZE + sgid * QK_PER_SG + qk_idx ;
149
+ if (current_token < context_len ) {
150
+ OUTPUT_TYPE tmp_print = qk [qk_idx ];
151
+ qk [qk_idx ] = sub_group_reduce_add (qk [qk_idx ]);
152
+ // if (head_num_idx < 4)
153
+ // printf("final_calc: seq_idx=%d, head_num_idx=%d, sgid=%d, sglid=%d: before qk[%d]=%f, after=%f\n",
154
+ // seq_idx, head_num_idx, sgid, sglid, qk_idx, tmp_print, qk[qk_idx]);
155
+ qk [qk_idx ] = scale [0 ] * qk [qk_idx ];
156
+ qk_max = OUTPUT_MAX_FUNC (qk_max , qk [qk_idx ]);
157
+ }
114
158
}
115
159
116
160
// Save QK results to local memory
117
161
if (sglid < QK_PER_SG ) {
118
- const uint qk_local_idx = block * BLOCK_SIZE * sgid * QK_PER_SG + sglid ;
119
- qk_vals [qk_local_idx ] = qk [sglid ];
162
+ const uint current_token = block * BLOCK_SIZE + sgid * QK_PER_SG + sglid ;
163
+ // Fixed -> // const uint qk_local_idx = block * BLOCK_SIZE * sgid * QK_PER_SG + sglid;
164
+ // OUTPUT_TYPE tmp_print = (current_token >= context_len ? 0 : qk[sglid]);
165
+ // if (head_num_idx < 4 || head_num_idx == 31)
166
+ // printf("slm save: seq_idx=%d, head_num_idx=%d, sgid=%d, sglid=%d: qk_vals[%d]=%f. Max=%f\n",
167
+ // seq_idx, head_num_idx, sgid, sglid, current_token, tmp_print, qk_max);
168
+ qk_vals [current_token ] = current_token >= context_len ? 0 : qk [sglid ];
120
169
}
121
170
}
122
171
123
- /* WARNING NEED TO ADD BIAS BEFORE SOFTMAX */
124
-
125
172
// Apply SoftMax operation
126
173
__local OUTPUT_TYPE qk_max_vals [SUBGROUPS_PER_WG ];
127
174
__local OUTPUT_TYPE qk_sum_vals [SUBGROUPS_PER_WG ];
@@ -138,23 +185,35 @@ KERNEL(pa_sdpa_ref)(
138
185
// Final max value after reduction across of all SG and WI
139
186
qk_max = sub_group_reduce_max (qk_max );
140
187
188
+ // if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0) {
189
+ // printf("QK max value = %f\n", qk_max);
190
+ // }
191
+
141
192
OUTPUT_TYPE exp_sum = OUTPUT_VAL_ZERO ;
142
193
for (uint qk_idx = 0 ; qk_idx < CEIL_DIV (context_len , SUBGROUPS_PER_WG * SUB_GROUP_SIZE ); qk_idx ++ ) {
143
194
const uint data_idx = qk_idx * (SUBGROUPS_PER_WG * SUB_GROUP_SIZE ) + sgid * SUB_GROUP_SIZE + sglid ;
144
195
if (data_idx < context_len ) {
145
196
OUTPUT_TYPE val = native_exp (qk_vals [data_idx ] - qk_max );
146
197
exp_sum += val ;
147
198
qk_vals [data_idx ] = val ;
199
+ // if (head_num_idx < 4 || head_num_idx == 31)
200
+ // printf("head_num %d, sgid = %d, sglid = %d, exp_sum = %f\n", head_num_idx, sgid, sglid, exp_sum);
148
201
}
149
202
}
150
203
151
204
exp_sum = sub_group_reduce_add (exp_sum );
152
205
206
+ // if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0) {
207
+ // printf("exp_sum final value = %f\n", exp_sum);
208
+ // }
209
+
153
210
if (sglid == 0 )
154
211
qk_sum_vals [sgid ] = exp_sum ;
155
212
156
213
barrier (CLK_LOCAL_MEM_FENCE );
157
214
215
+ exp_sum = OUTPUT_VAL_ZERO ;
216
+
158
217
if (sglid < SUBGROUPS_PER_WG )
159
218
exp_sum = qk_sum_vals [sglid ];
160
219
@@ -163,6 +222,8 @@ KERNEL(pa_sdpa_ref)(
163
222
164
223
const OUTPUT_TYPE inv_sum = OUTPUT_VAL_ONE / exp_sum ;
165
224
225
+
226
+ // TODO: replace CEIL_DIV with ALIGN and use += SUBGROUPS_PER_WG * SUB_GROUP_SIZE increment
166
227
for (uint qk_idx = 0 ; qk_idx < CEIL_DIV (context_len , SUBGROUPS_PER_WG * SUB_GROUP_SIZE ); qk_idx ++ ) {
167
228
const uint data_idx = qk_idx * (SUBGROUPS_PER_WG * SUB_GROUP_SIZE ) + sgid * SUB_GROUP_SIZE + sglid ;
168
229
if (data_idx < context_len ) {
@@ -174,5 +235,61 @@ KERNEL(pa_sdpa_ref)(
174
235
barrier (CLK_LOCAL_MEM_FENCE );
175
236
}
176
237
177
- output [seq_idx + sglid ] = qk_vals [sglid % context_len ];
238
+ // if (seq_idx == 0 && sgid == 0 && sglid == 0) {
239
+ // for (uint i = 0; i < context_len; i++) {
240
+ // printf("Softmax res for %d head: %d. %f\n", head_num_idx, i, qk_vals[i]);
241
+ // }
242
+ // }
243
+
244
+ {
245
+ OUTPUT_TYPE acc = OUTPUT_VAL_ZERO ;
246
+
247
+ for (uint qk_idx = 0 ; qk_idx < ALIGN (context_len , SUB_GROUP_SIZE ); qk_idx += SUB_GROUP_SIZE ) {
248
+ const uint qk_offset = qk_idx + sglid ;
249
+
250
+ OUTPUT_TYPE qk = qk_offset < context_len ? qk_vals [qk_offset ] : OUTPUT_VAL_ZERO ;
251
+
252
+ const uint block_idx = block_tables [batch_idx * blocks_num + (qk_idx / SUB_GROUP_SIZE )];
253
+ if (block_idx == 0 )
254
+ continue ;
255
+
256
+ const uint value_cache_offset = block_idx * KV_CACHE_BLOCK_STRIDE +
257
+ (head_num_idx / NUM_QUERIES_PER_KV_HEAD ) * (HEAD_SIZE * BLOCK_SIZE ) +
258
+ sgid * (SUB_GROUP_SIZE * BLOCK_SIZE ) +
259
+ sglid * BLOCK_SIZE ;
260
+
261
+ #define VALUE_VEC_TYPE MAKE_VECTOR_TYPE(OUTPUT_TYPE, BLOCK_SIZE)
262
+ #define VALUE_VLOAD (offset , ptr ) CAT(vload, BLOCK_SIZE)(offset, ptr)
263
+
264
+ ushort16 v_tmp = vload16 (0 , (__global ushort * )(value_cache + value_cache_offset ));
265
+ OUTPUT_TYPE * v = (OUTPUT_TYPE * )& v_tmp ;
266
+
267
+ // VALUE_VEC_TYPE* tmp_print = v;
268
+
269
+ // if (seq_idx == 0 && head_num_idx == 0) {
270
+ // printf("gemm2: seq_idx=%d, head_num_idx=%d, sgid=%d, sglid=%d, block_idx=%d, qk_idx=%d, qk_offset=%d, value_offset=%d (block_offset=%d): %v8f\n",
271
+ // seq_idx, head_num_idx, sgid, sglid, block_idx, qk_idx, qk_offset, value_cache_offset - (block_idx * KV_CACHE_BLOCK_STRIDE), block_idx * KV_CACHE_BLOCK_STRIDE, *tmp_print);
272
+ // }
273
+
274
+ for (uint token = 0 ; token < BLOCK_SIZE ; token ++ ) {
275
+ OUTPUT_TYPE qk_tmp = sub_group_broadcast (qk , token );
276
+ if (qk_idx + token < context_len ) {
277
+ acc = mad (qk_tmp , v [token ], acc );
278
+ }
279
+ }
280
+ }
281
+
282
+
283
+ const uint output_offset = seq_idx * (HEADS_NUM * HEAD_SIZE ) +
284
+ head_num_idx * HEAD_SIZE +
285
+ sgid * SUB_GROUP_SIZE +
286
+ sglid ;
287
+
288
+ // if (seq_idx == 0 && head_num_idx < 2 || head_num_idx == 31) {
289
+ // printf("output res: seq_idx=%d, head_num_idx=%d, sgid=%d, sglid=%d: output[%d] = %f\n",
290
+ // seq_idx, head_num_idx, sgid, sglid, output_offset, acc);
291
+ // }
292
+
293
+ output [output_offset ] = acc ;
294
+ }
178
295
}
0 commit comments