38
38
39
39
ulong __attribute__((overloadable )) intel_get_cycle_counter ( void );
40
40
41
+ #ifdef SDPA_STAGE_0
42
+
41
43
REQD_SUB_GROUP_SIZE (SUB_GROUP_SIZE )
42
44
__attribute__((reqd_work_group_size (1 , 1 , HEAD_SIZE )))
43
45
KERNEL (pa_sdpa_ref )(
@@ -49,7 +51,11 @@ KERNEL(pa_sdpa_ref)(
49
51
__global const INPUT4_TYPE * context_lens ,
50
52
__global const INPUT5_TYPE * block_tables ,
51
53
__global const INPUT6_TYPE * scale ,
52
- __global OUTPUT_TYPE * output )
54
+ __global OUTPUT_TYPE * output ,
55
+ __global OUTPUT_TYPE * exp_sums ,
56
+ __global OUTPUT_TYPE * max_logits ,
57
+ __global OUTPUT_TYPE * tmp_out ,
58
+ uint num_of_portions )
53
59
{
54
60
const uint seq_idx = get_global_id (0 );
55
61
const uint head_num_idx = get_global_id (1 );
@@ -64,6 +70,11 @@ KERNEL(pa_sdpa_ref)(
64
70
65
71
const uint blocks_num = INPUT5_FEATURE_NUM ;
66
72
73
+ const uint portion_id = get_group_id (2 );
74
+ const uint block_start_idx = portion_id * SEQ_LEN_PORTION_SIZE / BLOCK_SIZE ;
75
+ const uint block_end_idx = min (block_start_idx + (SEQ_LEN_PORTION_SIZE / BLOCK_SIZE ), blocks_num );
76
+
77
+
67
78
// if (seq_idx < 2 && head_num_idx < 2 && sgid < 2 && sglid < 2) {
68
79
// if (INPUT5_BATCH_NUM == 2) {
69
80
// if (INPUT5_FEATURE_NUM == 0) {
@@ -135,16 +146,17 @@ KERNEL(pa_sdpa_ref)(
135
146
q [i ] = QUERY_BLOCK_READ (query , query_idx );
136
147
}
137
148
138
- for (uint block = 0 ; block < blocks_num ; block ++ ) {
139
- const uint block_idx = batch_idx * blocks_num + block ;
149
+ // JIT: Compile time restriction: devisible SEQ_LEN_PORTION_SIZE / BLOCK_SIZE
150
+ for (uint block = 0 ; block < SEQ_LEN_PORTION_SIZE / BLOCK_SIZE ; block ++ ) {
151
+ const uint block_idx = batch_idx * blocks_num + block + block_start_idx ;
140
152
const uint block_offset = block_tables [block_idx ] * KV_CACHE_BLOCK_STRIDE ;
141
153
142
154
OUTPUT_TYPE qk [QK_VALS_PER_SG_PER_ITER ] = {0 };
143
155
144
156
ulong timer2 = intel_get_cycle_counter ();
145
157
for (uint hs = 0 ; hs < Q_LOAD_ITERS ; hs ++ ) {
146
158
for (uint qk_idx = 0 ; qk_idx < QK_VALS_PER_SG_PER_ITER ; qk_idx ++ ) {
147
- uint current_token = block * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + qk_idx ;
159
+ uint current_token = ( block + block_start_idx ) * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + qk_idx ;
148
160
if (current_token >= context_len )
149
161
continue ;
150
162
@@ -185,7 +197,7 @@ KERNEL(pa_sdpa_ref)(
185
197
186
198
// Summurize qk calculation across all WIs and apply scale
187
199
for (uint qk_idx = 0 ; qk_idx < QK_VALS_PER_SG_PER_ITER ; qk_idx ++ ) {
188
- const uint current_token = block * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + qk_idx ;
200
+ const uint current_token = ( block + block_start_idx ) * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + qk_idx ;
189
201
if (current_token < context_len ) {
190
202
OUTPUT_TYPE tmp_print = qk [qk_idx ];
191
203
qk [qk_idx ] = sub_group_reduce_add (qk [qk_idx ]);
@@ -194,7 +206,7 @@ KERNEL(pa_sdpa_ref)(
194
206
// seq_idx, head_num_idx, sgid, sglid, qk_idx, tmp_print, qk[qk_idx]);
195
207
qk [qk_idx ] = scale [0 ] * qk [qk_idx ];
196
208
197
- // Apply attention mask during prefill stage
209
+ // Apply attention mask at prefill stage
198
210
if (INPUT0_FEATURE_NUM > 1 && current_token > token_idx ) {
199
211
qk [qk_idx ] = qk [qk_idx ] + OUTPUT_VAL_MIN ;
200
212
}
@@ -206,12 +218,13 @@ KERNEL(pa_sdpa_ref)(
206
218
// Save QK results to local memory
207
219
if (sglid < QK_VALS_PER_SG_PER_ITER ) {
208
220
const uint current_token = block * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + sglid ;
221
+ const uint current_token_global_idx = (block + block_start_idx ) * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + sglid ;
209
222
// Fixed -> // const uint qk_local_idx = block * BLOCK_SIZE * sgid * QK_VALS_PER_SG_PER_ITER + sglid;
210
223
// OUTPUT_TYPE tmp_print = (current_token >= context_len ? 0 : qk[sglid]);
211
224
// if (head_num_idx < 4 || head_num_idx == 31)
212
225
// printf("slm save: seq_idx=%d, head_num_idx=%d, sgid=%d, sglid=%d: qk_vals[%d]=%f. Max=%f\n",
213
226
// seq_idx, head_num_idx, sgid, sglid, current_token, tmp_print, qk_max);
214
- qk_vals [current_token ] = current_token >= context_len ? 0 : qk [sglid ];
227
+ qk_vals [current_token ] = current_token_global_idx >= context_len ? 0 : qk [sglid ];
215
228
}
216
229
ulong timer5 = intel_get_cycle_counter ();
217
230
@@ -266,12 +279,13 @@ KERNEL(pa_sdpa_ref)(
266
279
// }
267
280
268
281
OUTPUT_TYPE exp_sum = OUTPUT_VAL_ZERO ;
269
- for (uint qk_idx = 0 ; qk_idx < CEIL_DIV (context_len , SUBGROUPS_PER_WG * SUB_GROUP_SIZE ); qk_idx ++ ) {
270
- const uint data_idx = qk_idx * (SUBGROUPS_PER_WG * SUB_GROUP_SIZE ) + sgid * SUB_GROUP_SIZE + sglid ;
271
- if (data_idx < context_len ) {
272
- OUTPUT_TYPE val = native_exp (qk_vals [data_idx ] - qk_max );
282
+ for (uint qk_idx = 0 ; qk_idx < CEIL_DIV (SEQ_LEN_PORTION_SIZE , SUBGROUPS_PER_WG * SUB_GROUP_SIZE ); qk_idx ++ ) {
283
+ const uint local_data_idx = qk_idx * (SUBGROUPS_PER_WG * SUB_GROUP_SIZE ) + sgid * SUB_GROUP_SIZE + sglid ;
284
+ const uint global_data_idx = block_start_idx * BLOCK_SIZE + qk_idx * (SUBGROUPS_PER_WG * SUB_GROUP_SIZE ) + sgid * SUB_GROUP_SIZE + sglid ;
285
+ if (global_data_idx < context_len ) {
286
+ OUTPUT_TYPE val = native_exp (qk_vals [local_data_idx ] - qk_max );
273
287
exp_sum += val ;
274
- qk_vals [data_idx ] = val ;
288
+ qk_vals [local_data_idx ] = val ;
275
289
// if (head_num_idx < 4 || head_num_idx == 31)
276
290
// printf("head_num %d, sgid = %d, sglid = %d, exp_sum = %f\n", head_num_idx, sgid, sglid, exp_sum);
277
291
}
@@ -290,6 +304,8 @@ KERNEL(pa_sdpa_ref)(
290
304
291
305
exp_sum = OUTPUT_VAL_ZERO ;
292
306
307
+
308
+ // JIT: Compile time restiction SUBGROUPS_PER_WG <= SG_SIZE
293
309
if (sglid < SUBGROUPS_PER_WG )
294
310
exp_sum = qk_sum_vals [sglid ];
295
311
@@ -300,20 +316,34 @@ KERNEL(pa_sdpa_ref)(
300
316
301
317
302
318
// TODO: replace CEIL_DIV with ALIGN and use += SUBGROUPS_PER_WG * SUB_GROUP_SIZE increment
303
- for (uint qk_idx = 0 ; qk_idx < CEIL_DIV (context_len , SUBGROUPS_PER_WG * SUB_GROUP_SIZE ); qk_idx ++ ) {
304
- const uint data_idx = qk_idx * (SUBGROUPS_PER_WG * SUB_GROUP_SIZE ) + sgid * SUB_GROUP_SIZE + sglid ;
305
- if (data_idx < context_len ) {
306
- OUTPUT_TYPE val = qk_vals [data_idx ] * inv_sum ;
307
- qk_vals [data_idx ] = val ;
319
+ for (uint qk_idx = 0 ; qk_idx < CEIL_DIV (SEQ_LEN_PORTION_SIZE , SUBGROUPS_PER_WG * SUB_GROUP_SIZE ); qk_idx ++ ) {
320
+ const uint local_data_idx = qk_idx * (SUBGROUPS_PER_WG * SUB_GROUP_SIZE ) + sgid * SUB_GROUP_SIZE + sglid ;
321
+ const uint global_data_idx = block_start_idx * BLOCK_SIZE + qk_idx * (SUBGROUPS_PER_WG * SUB_GROUP_SIZE ) + sgid * SUB_GROUP_SIZE + sglid ;
322
+ if (global_data_idx < context_len ) {
323
+ OUTPUT_TYPE val = qk_vals [local_data_idx ] * inv_sum ;
324
+ qk_vals [local_data_idx ] = val ;
308
325
}
309
326
}
310
327
311
328
barrier (CLK_LOCAL_MEM_FENCE );
312
329
313
-
314
330
ulong timer_end = intel_get_cycle_counter ();
315
331
ulong total_time = timer_end - timer_start ;
316
332
333
+ {
334
+ // Save temporary exm_sums and max_logits values for each portion
335
+ if (sgid == 0 ) {
336
+ const uint num_of_portions = get_num_groups (2 );
337
+ const uint exp_sums_offset = seq_idx * HEADS_NUM * num_of_portions +
338
+ head_num_idx * num_of_portions +
339
+ portion_id ;
340
+ exp_sums [exp_sums_offset ] = exp_sum ;
341
+
342
+ const uint max_logits_offset = exp_sums_offset ;
343
+ max_logits [max_logits_offset ] = qk_max ;
344
+ }
345
+ }
346
+
317
347
// if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0)
318
348
// printf("SDPA kernel Softmax: %d\n", (uint)total_time);
319
349
}
@@ -328,12 +358,14 @@ KERNEL(pa_sdpa_ref)(
328
358
ulong timer_start = intel_get_cycle_counter ();
329
359
OUTPUT_TYPE acc = OUTPUT_VAL_ZERO ;
330
360
331
- for (uint qk_idx = 0 ; qk_idx < ALIGN (context_len , SUB_GROUP_SIZE ); qk_idx += SUB_GROUP_SIZE ) {
332
- const uint qk_offset = qk_idx + sglid ;
333
361
334
- OUTPUT_TYPE qk = qk_offset < context_len ? qk_vals [qk_offset ] : OUTPUT_VAL_ZERO ;
362
+ for (uint qk_idx = 0 ; qk_idx < SEQ_LEN_PORTION_SIZE / BLOCK_SIZE * SUB_GROUP_SIZE ; qk_idx += SUB_GROUP_SIZE ) {
363
+ const uint qk_offset_local = qk_idx + sglid ;
364
+ const uint qk_offset_global = block_start_idx * BLOCK_SIZE + qk_offset_local ;
335
365
336
- const uint block_idx = block_tables [batch_idx * blocks_num + (qk_idx / BLOCK_SIZE )];
366
+ OUTPUT_TYPE qk = qk_offset_global < context_len ? qk_vals [qk_offset_local ] : OUTPUT_VAL_ZERO ;
367
+
368
+ const uint block_idx = block_tables [batch_idx * blocks_num + block_start_idx + (qk_idx / BLOCK_SIZE )];
337
369
// if (block_idx == 0)
338
370
// continue;
339
371
@@ -356,33 +388,49 @@ KERNEL(pa_sdpa_ref)(
356
388
// seq_idx, head_num_idx, sgid, sglid, block_idx, qk_idx, qk_offset, value_cache_offset - (block_idx * KV_CACHE_BLOCK_STRIDE), block_idx * KV_CACHE_BLOCK_STRIDE, *tmp_print);
357
389
// }
358
390
359
- if (qk_idx + SUB_GROUP_SIZE <= context_len ) {
391
+ // FINAL: rename token -> value_idx
392
+ if (block_start_idx * BLOCK_SIZE + qk_idx + SUB_GROUP_SIZE <= context_len ) {
360
393
unroll_for (uint token = 0 ; token < SUB_GROUP_SIZE ; token ++ ) {
361
394
OUTPUT_TYPE qk_tmp = sub_group_broadcast (qk , token );
362
395
acc = mad (qk_tmp , v [token ], acc );
363
396
}
364
397
} else {
365
398
for (uint token = 0 ; token < SUB_GROUP_SIZE ; token ++ ) {
366
399
OUTPUT_TYPE qk_tmp = sub_group_broadcast (qk , token );
367
- if (qk_idx + token < context_len ) {
400
+ if (block_start_idx * BLOCK_SIZE + qk_idx + token < context_len ) {
368
401
acc = mad (qk_tmp , v [token ], acc );
369
402
}
370
403
}
371
404
}
372
405
}
373
406
374
407
375
- const uint output_offset = seq_idx * (HEADS_NUM * HEAD_SIZE ) +
376
- head_num_idx * HEAD_SIZE +
377
- sgid * SUB_GROUP_SIZE +
378
- sglid ;
408
+ // const uint output_offset = seq_idx * (HEADS_NUM * HEAD_SIZE) +
409
+ // head_num_idx * HEAD_SIZE +
410
+ // sgid * SUB_GROUP_SIZE +
411
+ // sglid;
379
412
380
413
// if (seq_idx == 0 && head_num_idx < 2 || head_num_idx == 31) {
381
414
// printf("output res: seq_idx=%d, head_num_idx=%d, sgid=%d, sglid=%d: output[%d] = %f\n",
382
415
// seq_idx, head_num_idx, sgid, sglid, output_offset, acc);
383
416
// }
384
417
385
- output [output_offset ] = acc ;
418
+ // output[output_offset] = acc;
419
+
420
+ {
421
+ // [num_seqs, num_heads, max_num_partitions, head_size]
422
+ const uint num_of_portions = get_num_groups (2 );
423
+ const uint tmp_out_offset = seq_idx * (HEADS_NUM * HEAD_SIZE * num_of_portions ) +
424
+ head_num_idx * (HEAD_SIZE * num_of_portions ) +
425
+ portion_id * HEAD_SIZE +
426
+ sgid * SUB_GROUP_SIZE +
427
+ sglid ;
428
+
429
+ // if (output_offset != tmp_out_offset)
430
+ // printf("Different tmp_out_offset index!! %d vs %d, for portion_id %d\n", output_offset, tmp_out_offset, portion_id);
431
+
432
+ tmp_out [tmp_out_offset ] = acc ;
433
+ }
386
434
387
435
ulong timer_end = intel_get_cycle_counter ();
388
436
ulong total_time = timer_end - timer_start ;
@@ -391,3 +439,78 @@ KERNEL(pa_sdpa_ref)(
391
439
// printf("SDPA kernel GEMM2: %d\n", (uint)total_time);
392
440
}
393
441
}
442
+
443
+ #endif
444
+
445
+ #ifdef SDPA_STAGE_1
446
+
447
+ // exp_sums, // [num_seqs, num_heads, max_num_partitions]
448
+ // max_logits, // [num_seqs, num_heads, max_num_partitions]
449
+ // tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
450
+
451
+ REQD_SUB_GROUP_SIZE (SUB_GROUP_SIZE )
452
+ KERNEL (pa_sdpa_ref )(
453
+ OPTIONAL_SHAPE_INFO_ARG
454
+ __global const INPUT0_TYPE * query ,
455
+ __global const INPUT1_TYPE * key_cache ,
456
+ __global const INPUT2_TYPE * value_cache ,
457
+ __global const INPUT3_TYPE * max_context_len ,
458
+ __global const INPUT4_TYPE * context_lens ,
459
+ __global const INPUT5_TYPE * block_tables ,
460
+ __global const INPUT6_TYPE * scale ,
461
+ __global OUTPUT_TYPE * output ,
462
+ __global OUTPUT_TYPE * exp_sums ,
463
+ __global OUTPUT_TYPE * max_logits ,
464
+ __global OUTPUT_TYPE * tmp_out ,
465
+ uint num_of_portions ) {
466
+ if (num_of_portions <= SUB_GROUP_SIZE ) {
467
+ const uint seq_idx = get_global_id (0 );
468
+ const uint head_num_idx = get_global_id (1 );
469
+ const uint head_idx = get_global_id (2 );
470
+ const uint sglid = get_sub_group_local_id ();
471
+
472
+ const uint exp_sums_offset = seq_idx * HEADS_NUM * num_of_portions +
473
+ head_num_idx * num_of_portions ;
474
+ const uint max_logit_offset = exp_sums_offset ;
475
+
476
+ OUTPUT_TYPE exp_sum = BLOCK_READN (OUTPUT_TYPE , 1 , exp_sums , exp_sums_offset );
477
+ OUTPUT_TYPE max_logit = BLOCK_READN (OUTPUT_TYPE , 1 , max_logits , max_logit_offset );
478
+ if (sglid >= num_of_portions ) {
479
+ exp_sum = 0 ;
480
+ max_logit = OUTPUT_VAL_MIN ;
481
+ }
482
+
483
+ OUTPUT_TYPE global_max = sub_group_reduce_max (max_logit );
484
+
485
+ // Update exp_sum with respect to the global maximum
486
+ OUTPUT_TYPE test_exp_sum = exp_sum ;
487
+ if (sglid < num_of_portions )
488
+ exp_sum = exp_sum * native_exp (max_logit - global_max );
489
+
490
+ OUTPUT_TYPE global_sum = sub_group_reduce_add (exp_sum );
491
+
492
+ if (get_global_id (0 ) == 0 && get_global_id (1 ) == 0 && get_global_id (2 ) == 0 )
493
+ printf ("Run second kernel for reduction: num_of_portions=%d: max_logit=%f, exp_sum = %f, global_sum = %f, global_max=%f, test = %f, %f, %f\n" , num_of_portions ,
494
+ max_logit , exp_sum , global_sum , global_max , test_exp_sum , native_exp (max_logit - global_max ), test_exp_sum * native_exp (max_logit - global_max ));
495
+
496
+ for (uint i = 0 ; i < HEAD_SIZE / SUB_GROUP_SIZE ; i ++ ) {
497
+ OUTPUT_TYPE acc = OUTPUT_VAL_ZERO ;
498
+ for (uint portion = 0 ; portion < num_of_portions ; portion ++ ) {
499
+ const uint tmp_out_offset = seq_idx * (HEADS_NUM * HEAD_SIZE * num_of_portions ) +
500
+ head_num_idx * (HEAD_SIZE * num_of_portions ) +
501
+ portion * HEAD_SIZE ;
502
+ OUTPUT_TYPE out_val = BLOCK_READN (OUTPUT_TYPE , 1 , tmp_out , tmp_out_offset );
503
+ acc += out_val * sub_group_broadcast (exp_sum , portion ) / global_sum ;
504
+ }
505
+ const uint out_offset = seq_idx * (HEADS_NUM * HEAD_SIZE ) +
506
+ head_num_idx * HEAD_SIZE +
507
+ i * SUB_GROUP_SIZE ;
508
+ output [out_offset ] = acc ;
509
+ }
510
+ } else {
511
+ if (get_global_id (0 ) == 0 && get_global_id (1 ) == 0 && get_global_id (2 ) == 0 )
512
+ printf ("run second kernel for portion >= 16\n" );
513
+ }
514
+ }
515
+
516
+ #endif
0 commit comments