@@ -58,10 +58,9 @@ KERNEL(pa_sdpa_ref)(
58
58
__global OUTPUT_TYPE * tmp_out ,
59
59
const uint num_of_portions
60
60
#else
61
- __global OUTPUT_TYPE * output ,
61
+ __global OUTPUT_TYPE * output
62
62
#endif
63
- )
64
- {
63
+ ) {
65
64
const uint seq_idx = get_global_id (0 );
66
65
const uint head_num_idx = get_global_id (1 );
67
66
const uint head_idx = get_global_id (2 );
@@ -73,7 +72,7 @@ KERNEL(pa_sdpa_ref)(
73
72
74
73
const uint context_len = context_lens [batch_idx ];
75
74
76
- const uint total_blocks_num = INPUT5_FEATURE_NUM ;
75
+ const uint blocks_pitch = INPUT5_FEATURE_NUM ;
77
76
78
77
#ifdef USE_SPLIT_ACROSS_SEQ_LEN
79
78
const uint portion_id = get_group_id (2 );
@@ -86,6 +85,8 @@ KERNEL(pa_sdpa_ref)(
86
85
const uint block_start_idx = 0 ;
87
86
#endif
88
87
88
+ const uint total_blocks_num = CEIL_DIV (context_len , BLOCK_SIZE );
89
+
89
90
// if (seq_idx < 2 && head_num_idx < 2 && sgid < 2 && sglid < 2) {
90
91
// if (INPUT5_BATCH_NUM == 2) {
91
92
// if (INPUT5_FEATURE_NUM == 0) {
@@ -159,12 +160,13 @@ KERNEL(pa_sdpa_ref)(
159
160
160
161
#ifdef USE_SPLIT_ACROSS_SEQ_LEN
161
162
// FINAL: Compile time restriction: devisible SEQ_LEN_PORTION_SIZE / BLOCK_SIZE
162
- const uint blocks_num = SEQ_LEN_PORTION_SIZE / BLOCK_SIZE ;
163
+ const uint blocks_num = (portion_id == num_of_portions - 1 ) ? (total_blocks_num - (portion_id * SEQ_LEN_PORTION_SIZE / BLOCK_SIZE ))
164
+ : (SEQ_LEN_PORTION_SIZE / BLOCK_SIZE );
163
165
#else
164
166
const uint blocks_num = total_blocks_num ;
165
167
#endif
166
168
for (uint block_num = 0 ; block_num < blocks_num ; block_num ++ ) {
167
- const uint block_idx = batch_idx * total_blocks_num + block_start_idx + block_num ;
169
+ const uint block_idx = batch_idx * blocks_pitch + block_start_idx + block_num ;
168
170
const uint block_offset = block_tables [block_idx ] * KV_CACHE_BLOCK_STRIDE ;
169
171
170
172
OUTPUT_TYPE qk [QK_VALS_PER_SG_PER_ITER ] = {0 };
@@ -263,8 +265,9 @@ KERNEL(pa_sdpa_ref)(
263
265
ulong timer_end = intel_get_cycle_counter ();
264
266
ulong total_time = timer_end - timer_start ;
265
267
266
- // if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0)
267
- // printf("SDPA kernel GEMM1: %d; qk_max=%f\n", (uint)total_time, qk_max);
268
+ // if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_local_id(2) == 0 && context_len >= 496)
269
+ // printf("%d. %d. SDPA kernel GEMM1: %d; qk_max=%f, blocks_num=%d, total_blocks_num=%d, portion_id=%d, num_of_portions=%d\n",
270
+ // context_len, get_global_id(2), (uint)total_time, qk_max, blocks_num, total_blocks_num, portion_id, num_of_portions);
268
271
}
269
272
270
273
// barrier(CLK_LOCAL_MEM_FENCE);
@@ -311,10 +314,12 @@ KERNEL(pa_sdpa_ref)(
311
314
312
315
// // temp test
313
316
// barrier(CLK_LOCAL_MEM_FENCE);
317
+ ulong timer_start2 = intel_get_cycle_counter ();
314
318
315
319
ACCUMULATOR_TYPE exp_sum = ACCUMULATOR_VAL_ZERO ;
316
320
#ifdef USE_SPLIT_ACROSS_SEQ_LEN
317
- const uint qk_num = CEIL_DIV (SEQ_LEN_PORTION_SIZE , SUBGROUPS_PER_WG * SUB_GROUP_SIZE );
321
+ const uint qk_num = (num_of_portions == 1 ) ? CEIL_DIV (context_len , SUBGROUPS_PER_WG * SUB_GROUP_SIZE )
322
+ : CEIL_DIV (SEQ_LEN_PORTION_SIZE , SUBGROUPS_PER_WG * SUB_GROUP_SIZE );
318
323
#else
319
324
const uint qk_num = CEIL_DIV (context_len , SUBGROUPS_PER_WG * SUB_GROUP_SIZE );
320
325
#endif
@@ -338,6 +343,7 @@ KERNEL(pa_sdpa_ref)(
338
343
}
339
344
}
340
345
346
+ ulong timer_start3 = intel_get_cycle_counter ();
341
347
342
348
// // temp test
343
349
// barrier(CLK_LOCAL_MEM_FENCE);
@@ -365,6 +371,7 @@ KERNEL(pa_sdpa_ref)(
365
371
366
372
exp_sum = ACCUMULATOR_VAL_ZERO ;
367
373
374
+ ulong timer_start4 = intel_get_cycle_counter ();
368
375
369
376
// FINAL FIX: Compile time restiction SUBGROUPS_PER_WG <= SG_SIZE
370
377
if (sglid < SUBGROUPS_PER_WG )
@@ -391,9 +398,7 @@ KERNEL(pa_sdpa_ref)(
391
398
}
392
399
393
400
barrier (CLK_LOCAL_MEM_FENCE );
394
-
395
- ulong timer_end = intel_get_cycle_counter ();
396
- ulong total_time = timer_end - timer_start ;
401
+ ulong timer_start5 = intel_get_cycle_counter ();
397
402
398
403
#ifdef USE_SPLIT_ACROSS_SEQ_LEN
399
404
{
@@ -417,8 +422,17 @@ KERNEL(pa_sdpa_ref)(
417
422
}
418
423
#endif
419
424
420
- // if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0)
421
- // printf("SDPA kernel Softmax: %d\n", (uint)total_time);
425
+ ulong timer_end = intel_get_cycle_counter ();
426
+
427
+ ulong total_time1 = timer_start2 - timer_start ;
428
+ ulong total_time2 = timer_start3 - timer_start2 ;
429
+ ulong total_time3 = timer_start4 - timer_start3 ;
430
+ ulong total_time4 = timer_start5 - timer_start4 ;
431
+ ulong total_time5 = timer_end - timer_start5 ;
432
+
433
+ // if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_local_id(2) == 0)
434
+ // printf("%d. SDPA kernel Softmax: qk_max calc: %d, exp_sum_loc calc: %d, exp_sum calc: %d, qk_vals recalc: %d, save: %d\n",
435
+ // get_global_id(2), (uint)total_time1, (uint)total_time2, (uint)total_time3, (uint)total_time4, (uint)total_time5);
422
436
}
423
437
424
438
// if (seq_idx == 0 && sgid == 0 && sglid == 0) {
@@ -433,17 +447,18 @@ KERNEL(pa_sdpa_ref)(
433
447
434
448
#ifdef USE_SPLIT_ACROSS_SEQ_LEN
435
449
// FINAL: Compile time restriction: devisible SEQ_LEN_PORTION_SIZE / BLOCK_SIZE
436
- const uint qk_num = SEQ_LEN_PORTION_SIZE / BLOCK_SIZE * SUB_GROUP_SIZE ;
450
+ const uint qk_num = (portion_id == num_of_portions - 1 ) ? (context_len - (portion_id * SEQ_LEN_PORTION_SIZE ))
451
+ : (SEQ_LEN_PORTION_SIZE );
437
452
#else
438
- const uint qk_num = ALIGN ( context_len , SUB_GROUP_SIZE ) ;
453
+ const uint qk_num = context_len ;
439
454
#endif
440
455
for (uint qk_idx = 0 ; qk_idx < qk_num ; qk_idx += SUB_GROUP_SIZE ) {
441
456
const uint qk_offset_local = qk_idx + sglid ;
442
457
const uint qk_offset_global = block_start_idx * BLOCK_SIZE + qk_offset_local ;
443
458
444
459
OUTPUT_TYPE qk = qk_offset_global < context_len ? qk_vals [qk_offset_local ] : OUTPUT_VAL_ZERO ;
445
460
446
- const uint block_idx = block_tables [batch_idx * total_blocks_num + block_start_idx + (qk_idx / BLOCK_SIZE )];
461
+ const uint block_idx = block_tables [batch_idx * blocks_pitch + block_start_idx + (qk_idx / BLOCK_SIZE )];
447
462
// if (block_idx == 0)
448
463
// continue;
449
464
@@ -504,8 +519,8 @@ KERNEL(pa_sdpa_ref)(
504
519
ulong timer_end = intel_get_cycle_counter ();
505
520
ulong total_time = timer_end - timer_start ;
506
521
507
- // if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id (2) == 0)
508
- // printf("SDPA kernel GEMM2: %d\n", (uint)total_time);
522
+ // if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_local_id (2) == 0)
523
+ // printf("%d. SDPA kernel GEMM2: %d\n", get_global_id(2) , (uint)total_time);
509
524
}
510
525
}
511
526
0 commit comments