7
7
#include "include/batch_headers/sub_group_block_write.cl"
8
8
#include "include/batch_headers/sub_group_shuffle.cl"
9
9
10
- #define SUBGROUPS_PER_WG (HEAD_SIZE / SUBGROUP_SIZE)
10
+ #define SUBGROUPS_PER_WG (( HEAD_SIZE / SUBGROUP_SIZE) * SG_SCALE_FACTOR )
11
11
#define PAGED_ATTENTION_BLOCKS_PER_PARTITION (SEQ_LEN_PARTITION_SIZE / PAGED_ATTENTION_BLOCK_SIZE)
12
12
13
13
#if HEAD_SIZE > 128
@@ -75,7 +75,7 @@ KERNEL(pa_sdpa_opt)(
75
75
76
76
const uint seq_idx = get_global_id (0 );
77
77
const uint head_num_idx = get_global_id (1 );
78
- const uint head_size_idx = get_global_id (2 );
78
+ const uint head_size_idx = get_local_id (2 );
79
79
const uint sglid = get_sub_group_local_id ();
80
80
const uint sgid = get_sub_group_id ();
81
81
const uint total_partitions_num = get_num_groups (2 );
@@ -93,7 +93,6 @@ KERNEL(pa_sdpa_opt)(
93
93
#endif
94
94
95
95
const uint partition_idx = get_group_id (2 );
96
- const uint block_start_idx = partition_idx * SEQ_LEN_PARTITION_SIZE / PAGED_ATTENTION_BLOCK_SIZE ;
97
96
98
97
if (partition_idx * SEQ_LEN_PARTITION_SIZE >= seq_len ) {
99
98
return ;
@@ -336,6 +335,15 @@ KERNEL(pa_sdpa_opt)(
336
335
OUTPUT_TYPE acc = OUTPUT_VAL_ZERO ;
337
336
338
337
const uint partition_seq_len = min (seq_len - partition_idx * SEQ_LEN_PARTITION_SIZE , (uint )SEQ_LEN_PARTITION_SIZE );
338
+
339
+ #if SG_SCALE_FACTOR > 1
340
+ const uint block_start_idx = (sgid / (SUBGROUPS_PER_WG / SG_SCALE_FACTOR )) * (SEQ_LEN_PARTITION_SIZE / SG_SCALE_FACTOR / SUBGROUP_SIZE );
341
+ const uint block_end_idx = min (block_start_idx + (SEQ_LEN_PARTITION_SIZE / SG_SCALE_FACTOR / SUBGROUP_SIZE ), partition_seq_len / SUBGROUP_SIZE );
342
+ #else
343
+ const uint block_start_idx = 0 ;
344
+ const uint block_end_idx = partition_seq_len / SUBGROUP_SIZE ;
345
+ #endif
346
+
339
347
uint blocks_num_per_partition = min (total_blocks_num - partition_idx * PAGED_ATTENTION_BLOCKS_PER_PARTITION , (uint )PAGED_ATTENTION_BLOCKS_PER_PARTITION );
340
348
341
349
uint leftovers = blocks_num_per_partition * PAGED_ATTENTION_BLOCK_SIZE - partition_seq_len ;
@@ -346,7 +354,7 @@ KERNEL(pa_sdpa_opt)(
346
354
347
355
const uint start_block_idx = block_indices_begins [subsequence_idx ] + partition_idx * PAGED_ATTENTION_BLOCKS_PER_PARTITION ;
348
356
349
- for (uint block_num = 0 ; block_num < blocks_num_per_partition ; block_num ++ ) {
357
+ for (uint block_num = block_start_idx ; block_num < block_end_idx ; block_num ++ ) {
350
358
#ifdef BROADCAST_GROUP_SIZE
351
359
const uint head_idx = head_num_idx / BROADCAST_GROUP_SIZE ;
352
360
#else
@@ -389,6 +397,10 @@ KERNEL(pa_sdpa_opt)(
389
397
}
390
398
}
391
399
400
+
401
+ #if SG_SCALE_FACTOR > 1
402
+ if (sgid >= SUBGROUPS_PER_WG / SG_SCALE_FACTOR ) {
403
+ #endif
392
404
if (leftovers != 0 ) {
393
405
#ifdef BROADCAST_GROUP_SIZE
394
406
const uint head_idx = head_num_idx / BROADCAST_GROUP_SIZE ;
@@ -429,6 +441,32 @@ KERNEL(pa_sdpa_opt)(
429
441
}
430
442
}
431
443
444
+
445
+ #if SG_SCALE_FACTOR > 1
446
+ }
447
+ #endif
448
+
449
+ #if SG_SCALE_FACTOR > 1
450
+ if ((partition_seq_len > (SEQ_LEN_PARTITION_SIZE / SG_SCALE_FACTOR )) || (leftovers != 0 )) {
451
+ barrier (CLK_LOCAL_MEM_FENCE );
452
+
453
+ if (sgid >= SUBGROUPS_PER_WG / SG_SCALE_FACTOR ) {
454
+ // Reuse slm_qk_vals SLM to sum-up results between two groups of subgroups
455
+ slm_qk_vals [head_size_idx ] = acc ;
456
+ }
457
+
458
+ barrier (CLK_LOCAL_MEM_FENCE );
459
+
460
+ if (sgid < SUBGROUPS_PER_WG / SG_SCALE_FACTOR ) {
461
+ acc += slm_qk_vals [head_size_idx ];
462
+ }
463
+ }
464
+ #endif
465
+
466
+ #if SG_SCALE_FACTOR > 1
467
+ if (sgid < SUBGROUPS_PER_WG / SG_SCALE_FACTOR ) {
468
+ #endif
469
+
432
470
if (seq_len > SEQ_LEN_PARTITION_SIZE ) {
433
471
const uint tmp_out_offset = seq_idx * (HEADS_NUM * HEAD_SIZE * total_partitions_num ) +
434
472
head_num_idx * (HEAD_SIZE * total_partitions_num ) +
@@ -446,6 +484,11 @@ KERNEL(pa_sdpa_opt)(
446
484
output [output_offset ] = acc ;
447
485
}
448
486
487
+ #if SG_SCALE_FACTOR > 1
488
+ }
489
+ #endif
490
+
491
+
449
492
}
450
493
}
451
494
0 commit comments