14
14
// HEAD_SIZE divided by SUB_GROUP_SIZE; used as the trip count of the q_idx loop in the kernel.
// NOTE(review): presumably relies on HEAD_SIZE being a multiple of SUB_GROUP_SIZE — confirm.
#define Q_LOAD_ITERS (HEAD_SIZE / SUB_GROUP_SIZE)
15
15
16
16
// How many QK outputs each subgroup calculates per block
17
- #define QK_VALS_PER_SG_PER_ITER (BLOCK_SIZE / SUBGROUPS_PER_WG)
17
+ #define QK_VALS_PER_SG_PER_ITER CEIL_DIV (BLOCK_SIZE, SUBGROUPS_PER_WG)
18
18
19
19
// Product of HEAD_SIZE, KV_HEADS_NUM and BLOCK_SIZE.
// NOTE(review): presumably the element distance between consecutive KV-cache blocks — confirm against the indexing code.
#define KV_CACHE_BLOCK_STRIDE (HEAD_SIZE * KV_HEADS_NUM * BLOCK_SIZE)
20
20
@@ -35,6 +35,7 @@ KERNEL(pa_sdpa_ref)(
35
35
const __global INPUT4_TYPE * context_lens ,
36
36
const __global INPUT5_TYPE * block_tables ,
37
37
const __global INPUT6_TYPE * scale ,
38
+ const __global INPUT7_TYPE * is_prompt ,
38
39
#ifdef USE_SEQ_LEN_SPLIT
39
40
__global OUTPUT_TYPE * output ,
40
41
__global ACCUMULATOR_TYPE * exp_sums ,
@@ -71,6 +72,10 @@ KERNEL(pa_sdpa_ref)(
71
72
72
73
const uint total_blocks_num = CEIL_DIV (context_len , BLOCK_SIZE );
73
74
75
+ // if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0) {
76
+ // printf("context_len=%d block_start_idx=%d total_blocks_num=%d context_len=%d, SCALE_VAL=%f is_prompt=%d\n", context_len, block_start_idx, total_blocks_num, context_len, scale[0], is_prompt[0]);
77
+ // }
78
+
74
79
__local OUTPUT_TYPE qk_vals_local [SHARED_MEM_SIZE ];
75
80
ACCUMULATOR_TYPE qk_max = ACCUMULATOR_VAL_MIN ;
76
81
@@ -99,7 +104,12 @@ KERNEL(pa_sdpa_ref)(
99
104
for (uint q_idx = 0 ; q_idx < Q_LOAD_ITERS ; q_idx ++ ) {
100
105
for (uint qk_idx = 0 ; qk_idx < QK_VALS_PER_SG_PER_ITER ; qk_idx ++ ) {
101
106
uint current_token = (block_start_idx + block_num ) * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + qk_idx ;
107
+ #if BLOCK_SIZE % SUBGROUPS_PER_WG != 0
108
+ // TODO: Optimize for BLOCK_SIZE % SUBGROUPS_PER_WG != 0 case
109
+ if (current_token >= context_len || sgid >= BLOCK_SIZE / QK_VALS_PER_SG_PER_ITER )
110
+ #else
102
111
if (current_token >= context_len )
112
+ #endif
103
113
continue ;
104
114
105
115
const uint key_idx = block_offset +
@@ -120,27 +130,44 @@ KERNEL(pa_sdpa_ref)(
120
130
}
121
131
}
122
132
133
+ // if (context_len == 17 && sgid == 4 && QK_VALS_PER_SG_PER_ITER == 4 && (head_num_idx == 0 || head_num_idx == 1 || head_num_idx == 28)) {
134
+ // printf("FROM SGID=4; token_idx=%d, head_num=%d block_num=%d, sglid=%d: %f %f %f %f \n", token_idx, head_num_idx, block_num, sglid,
135
+ // qk[0], qk[1], qk[2], qk[3]);
136
+ // }
137
+
123
138
// Summarize qk calculation across all WIs and apply scale
124
139
for (uint qk_idx = 0 ; qk_idx < QK_VALS_PER_SG_PER_ITER ; qk_idx ++ ) {
125
140
const uint current_token = (block_start_idx + block_num ) * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + qk_idx ;
141
+ #if BLOCK_SIZE % SUBGROUPS_PER_WG != 0
142
+ if (current_token < context_len && sgid < BLOCK_SIZE / QK_VALS_PER_SG_PER_ITER ) {
143
+ #else
126
144
if (current_token < context_len ) {
145
+ #endif
127
146
qk [qk_idx ] = sub_group_reduce_add (qk [qk_idx ]);
128
147
129
148
// Apply scale
130
149
qk [qk_idx ] = scale [0 ] * qk [qk_idx ];
131
150
132
151
// Apply attention mask for context processing stage
133
- const bool is_prefill_stage = INPUT0_FEATURE_NUM > 1 ;
134
- if (is_prefill_stage && current_token > token_idx ) {
135
- qk [qk_idx ] = qk [qk_idx ] + OUTPUT_VAL_MIN ;
152
+ const unsigned char is_prefill_stage = is_prompt [0 ];
153
+ if (is_prefill_stage == 1 ) {
154
+ if (current_token > token_idx )
155
+ qk [qk_idx ] = qk [qk_idx ] + OUTPUT_VAL_MIN ;
156
+ } else if (is_prefill_stage == 2 ) {
157
+ if (current_token > context_len - INPUT0_FEATURE_NUM + token_idx )
158
+ qk [qk_idx ] = qk [qk_idx ] + OUTPUT_VAL_MIN ;
136
159
}
137
160
138
161
qk_max = ACCUMULATOR_MAX_FUNC (qk_max , TO_ACCUMULATOR_TYPE (qk [qk_idx ]));
139
162
}
140
163
}
141
164
142
165
// Save QK results to local memory
166
+ #if BLOCK_SIZE % SUBGROUPS_PER_WG != 0
167
+ if (sglid < QK_VALS_PER_SG_PER_ITER && sgid < BLOCK_SIZE / QK_VALS_PER_SG_PER_ITER ) {
168
+ #else
143
169
if (sglid < QK_VALS_PER_SG_PER_ITER ) {
170
+ #endif
144
171
const uint current_token_global_idx = (block_start_idx + block_num ) * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + sglid ;
145
172
#ifdef USE_SEQ_LEN_SPLIT
146
173
const uint current_token_local = block_num * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + sglid ;
@@ -152,6 +179,33 @@ KERNEL(pa_sdpa_ref)(
152
179
}
153
180
}
154
181
182
+ // barrier(CLK_LOCAL_MEM_FENCE);
183
+ // if (get_global_id(1) == 0 && get_global_id(2) == 0) {
184
+ // if (context_len == 15)
185
+ // printf("token_idx=%d, qk_vals_local: %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f: %d\n",
186
+ // token_idx, qk_vals_local[0], qk_vals_local[1], qk_vals_local[2], qk_vals_local[3], qk_vals_local[4],
187
+ // qk_vals_local[5], qk_vals_local[6], qk_vals_local[7], qk_vals_local[8], qk_vals_local[9],
188
+ // qk_vals_local[10], qk_vals_local[11], qk_vals_local[12], qk_vals_local[13], qk_vals_local[14], is_prompt[0]);
189
+ // else if (context_len == 16)
190
+ // printf("token_idx=%d, qk_vals_local: %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f: %d\n",
191
+ // token_idx, qk_vals_local[0], qk_vals_local[1], qk_vals_local[2], qk_vals_local[3], qk_vals_local[4],
192
+ // qk_vals_local[5], qk_vals_local[6], qk_vals_local[7], qk_vals_local[8], qk_vals_local[9],
193
+ // qk_vals_local[10], qk_vals_local[11], qk_vals_local[12], qk_vals_local[13], qk_vals_local[14], qk_vals_local[15], is_prompt[0]);
194
+ // else if (context_len == 17)
195
+ // printf("token_idx=%d, qk_vals_local: %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f: %d\n",
196
+ // token_idx, qk_vals_local[0], qk_vals_local[1], qk_vals_local[2], qk_vals_local[3], qk_vals_local[4],
197
+ // qk_vals_local[5], qk_vals_local[6], qk_vals_local[7], qk_vals_local[8], qk_vals_local[9],
198
+ // qk_vals_local[10], qk_vals_local[11], qk_vals_local[12], qk_vals_local[13], qk_vals_local[14], qk_vals_local[15], qk_vals_local[16], is_prompt[0]);
199
+ // }
200
+
201
+ // barrier(CLK_LOCAL_MEM_FENCE);
202
+ // if (context_len == 17 && sgid == 4 && sglid == 0) {
203
+ // printf("FROM SGID=4; token_idx=%d, head_num=%d qk_vals_local: %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f: %d. qk_max=%f\n",
204
+ // token_idx, head_num_idx, qk_vals_local[0], qk_vals_local[1], qk_vals_local[2], qk_vals_local[3], qk_vals_local[4],
205
+ // qk_vals_local[5], qk_vals_local[6], qk_vals_local[7], qk_vals_local[8], qk_vals_local[9],
206
+ // qk_vals_local[10], qk_vals_local[11], qk_vals_local[12], qk_vals_local[13], qk_vals_local[14], qk_vals_local[15], qk_vals_local[16], is_prompt[0], qk_max);
207
+ // }
208
+
155
209
// Apply SoftMax operation
156
210
__local ACCUMULATOR_TYPE qk_max_vals [SUBGROUPS_PER_WG ];
157
211
__local ACCUMULATOR_TYPE qk_sum_vals [SUBGROUPS_PER_WG ];
@@ -168,6 +222,16 @@ KERNEL(pa_sdpa_ref)(
168
222
// Final max value after reduction across all SGs and WIs
169
223
qk_max = sub_group_reduce_max (qk_max );
170
224
225
+ // barrier(CLK_LOCAL_MEM_FENCE);
226
+ // if (context_len == 17 && get_global_id(2) == 0 && (head_num_idx == 1 || head_num_idx == 28) && SUBGROUPS_PER_WG == 5) {
227
+ // printf("Calculation QK_VALS token_idx=%d, head_num=%d qk_vals_local: %f (-qk_max = %f, native_exp = %f), %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f(-qk_max = %f, native_exp = %f): %d. qk_max=%f (%f %f %f %f %f)\n",
228
+ // token_idx, head_num_idx, qk_vals_local[0], TO_ACCUMULATOR_TYPE(qk_vals_local[0] - qk_max), native_exp(TO_ACCUMULATOR_TYPE(qk_vals_local[0]) - qk_max), qk_vals_local[1], qk_vals_local[2], qk_vals_local[3], qk_vals_local[4],
229
+ // qk_vals_local[5], qk_vals_local[6], qk_vals_local[7], qk_vals_local[8], qk_vals_local[9],
230
+ // qk_vals_local[10], qk_vals_local[11], qk_vals_local[12], qk_vals_local[13], qk_vals_local[14], qk_vals_local[15],
231
+ // qk_vals_local[16], TO_ACCUMULATOR_TYPE(qk_vals_local[16] - qk_max), native_exp(TO_ACCUMULATOR_TYPE(qk_vals_local[16]) - qk_max),
232
+ // is_prompt[0], qk_max, qk_max_vals[0], qk_max_vals[1], qk_max_vals[2], qk_max_vals[3], qk_max_vals[4]);
233
+ // }
234
+
171
235
ACCUMULATOR_TYPE exp_sum = ACCUMULATOR_VAL_ZERO ;
172
236
#ifdef USE_SEQ_LEN_SPLIT
173
237
const uint qk_num = (num_of_portions == 1 ) ? CEIL_DIV (context_len , SUBGROUPS_PER_WG * SUB_GROUP_SIZE )
@@ -189,6 +253,15 @@ KERNEL(pa_sdpa_ref)(
189
253
}
190
254
}
191
255
256
+
257
+ // barrier(CLK_LOCAL_MEM_FENCE);
258
+ // if (context_len == 17 && get_global_id(2) == 0) {
259
+ // printf("UPDATED QK_VALS token_idx=%d, head_num=%d qk_vals_local: %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f: %d. qk_max=%f\n",
260
+ // token_idx, head_num_idx, qk_vals_local[0], qk_vals_local[1], qk_vals_local[2], qk_vals_local[3], qk_vals_local[4],
261
+ // qk_vals_local[5], qk_vals_local[6], qk_vals_local[7], qk_vals_local[8], qk_vals_local[9],
262
+ // qk_vals_local[10], qk_vals_local[11], qk_vals_local[12], qk_vals_local[13], qk_vals_local[14], qk_vals_local[15], qk_vals_local[16], is_prompt[0], qk_max);
263
+ // }
264
+
192
265
exp_sum = sub_group_reduce_add (exp_sum );
193
266
194
267
if (sglid == 0 )
@@ -236,6 +309,16 @@ KERNEL(pa_sdpa_ref)(
236
309
}
237
310
}
238
311
#endif
312
+
313
+
314
+ // if (context_len == 17 && get_global_id(2) == 0 && SUBGROUPS_PER_WG == 5) {
315
+ // printf("SF result: token_idx=%d, head_num=%d qk_vals_local: %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f; Total qk_max=%f total sum=%f (%f %f %f %f %f)\n",
316
+ // token_idx, head_num_idx, qk_vals_local[0], qk_vals_local[1], qk_vals_local[2], qk_vals_local[3], qk_vals_local[4],
317
+ // qk_vals_local[5], qk_vals_local[6], qk_vals_local[7], qk_vals_local[8], qk_vals_local[9],
318
+ // qk_vals_local[10], qk_vals_local[11], qk_vals_local[12], qk_vals_local[13], qk_vals_local[14], qk_vals_local[15], qk_vals_local[16], qk_max, exp_sum,
319
+ // qk_sum_vals[0], qk_sum_vals[1], qk_sum_vals[2], qk_sum_vals[3], qk_sum_vals[4]);
320
+
321
+ // }
239
322
}
240
323
241
324
{
0 commit comments