Skip to content

Commit 87907da

Browse files
committed
Micro sdpa causal mask
1 parent 1fdc831 commit 87907da

File tree

1 file changed

+13
-15
lines changed
  • src/plugins/intel_gpu/src/kernel_selector/cl_kernels

1 file changed

+13
-15
lines changed

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_micro.cl

+13 -15
Original file line number | Diff line number | Diff line change
@@ -34,23 +34,21 @@
3434
typedef ugemm_kq_c_type s_tile_type;
3535
typedef ugemm_vs_c_type a_tile_type;
3636

37-
38-
39-
#define DECLARE_2D_MASK_FILL(tile_type, element_type, sg, br, bc, nbr, nbc) \
40-
__attribute__((overloadable)) void fill_mask_t(tile_type *t, \
37+
#define DECLARE_2D_CAUSAL_MASK(tile_type, element_type, sg, br, bc, nbr, nbc) \
38+
__attribute__((overloadable)) void apply_causal_mask_t(tile_type *t, \
4139
int m, int n, int ld, \
42-
int offset_r, int offset_c) { \
40+
int offset_r, int offset_c, \
41+
float iscale) { \
4342
if (offset_c + bc * nbc - 1 <= offset_r) { \
44-
tile_fill((*t), 0); \
4543
return; \
4644
} else if (offset_r + br * nbr <= offset_c) { \
47-
tile_fill((*t), -HALF_MAX); \
45+
tile_fill((*t), (-HALF_MAX * iscale)); \
4846
return; \
4947
} \
5048
_Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \
5149
int i = i0 + get_sub_group_local_id(); \
5250
_Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \
53-
tile_access(*t, i0, j, sg, br, bc, nbr) = (offset_c + j > offset_r + i) ? -HALF_MAX : 0.0f; \
51+
tile_access(*t, i0, j, sg, br, bc, nbr) += (offset_c + j > offset_r + i) ? (-HALF_MAX * iscale) : 0.0f; \
5452
} \
5553
} \
5654
}
@@ -86,7 +84,7 @@ DECLARE_2D_TILE(
8684

8785
DECLARE_2D_TILE(mask_tile_type, half, SUBGROUP_SIZE, ugemm_kq_c_type_block0, ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, ugemm_kq_c_type_nblock1)
8886
DECLARE_2D_TILE(mask_tile_type_float, float, SUBGROUP_SIZE, ugemm_kq_c_type_block0, ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, ugemm_kq_c_type_nblock1)
89-
DECLARE_2D_MASK_FILL(mask_tile_type_float, float, SUBGROUP_SIZE, ugemm_kq_c_type_block0, ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, ugemm_kq_c_type_nblock1)
87+
DECLARE_2D_CAUSAL_MASK(mask_tile_type_float, float, SUBGROUP_SIZE, ugemm_kq_c_type_block0, ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, ugemm_kq_c_type_nblock1)
9088

9189
#ifdef BLOCK_A
9290
DECLARE_2D_TILE_BLOCK_OPS(a_tile_type_half, half, SUBGROUP_SIZE,
@@ -472,12 +470,12 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
472470

473471
#if IS_PAGED_ATTENTION
474472
#define unscale(x) ((x)*iscale)
475-
mask_tile_type_float mask_tile_float;
473+
// mask_tile_type_float mask_tile_float;
476474
// tile_fill(mask_tile_float, -66.0f);
477475
#if DEBUG_PRINT
478476
if (b0 == 0 && b1 == 0 && k0 == 0 && sg_i_kq == 0 && sg_j_kq == 0 && get_sub_group_local_id() == 0) {
479477
printf("mask after init:\n");
480-
printf("fill_mask_t(q=%d, k=%d, col_r=%d, col_c=%d; br=%d bc=%d nbr=%d nbc=%d)\n", q, k, sg_j0_kq + wg_j0, k0 + sg_i0_kq, ugemm_kq_c_type_block0, ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, ugemm_kq_c_type_nblock1);
478+
printf("apply_causal_mask_t(q=%d, k=%d, col_r=%d, col_c=%d; br=%d bc=%d nbr=%d nbc=%d)\n", q, k, sg_j0_kq + wg_j0, k0 + sg_i0_kq, ugemm_kq_c_type_block0, ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, ugemm_kq_c_type_nblock1);
481479
for (int i = 0; i < 8; i++) {
482480
for (int j = 0; j < 8; j++)
483481
printf("\t%f", xlane_tile_access(mask_tile_float, i, j, SUBGROUP_SIZE, ugemm_kq_c_type_block0,
@@ -487,7 +485,7 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
487485
}
488486
#endif
489487

490-
fill_mask_t(&mask_tile_float, q, k, q, sg_j0_kq + wg_j0, k0 + sg_i0_kq);
488+
apply_causal_mask_t(&S_tile, q, k, q, sg_j0_kq + wg_j0, k0 + sg_i0_kq, iscale);
491489

492490
#if DEBUG_PRINT
493491
if (b0 == 0 && b1 == 0 && k0 == 0 && sg_i_kq == 0 && sg_j_kq == 0 && (get_sub_group_local_id() < 4)) {
@@ -501,7 +499,7 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
501499
#if DEBUG_PRINT
502500
if (b0 == 0 && b1 == 0 && k0 == 0 && sg_i_kq == 0 && sg_j_kq == 0 && get_sub_group_local_id() == 0) {
503501
printf("updated mask before scale:\n");
504-
printf("fill_mask_t(q=%d, k=%d, col_r=%d, col_c=%d; br=%d bc=%d nbr=%d nbc=%d)\n", q, k, sg_j0_kq + wg_j0, k0 + sg_i0_kq, ugemm_kq_c_type_block0, ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, ugemm_kq_c_type_nblock1);
502+
printf("apply_causal_mask_t(q=%d, k=%d, col_r=%d, col_c=%d; br=%d bc=%d nbr=%d nbc=%d)\n", q, k, sg_j0_kq + wg_j0, k0 + sg_i0_kq, ugemm_kq_c_type_block0, ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, ugemm_kq_c_type_nblock1);
505503
for (int i = 0; i < 8; i++) {
506504
for (int j = 0; j < 8; j++)
507505
printf("\t%f", xlane_tile_access(mask_tile_float, i, j, SUBGROUP_SIZE, ugemm_kq_c_type_block0,
@@ -511,8 +509,8 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
511509
}
512510
#endif
513511

514-
tile_elementwise(mask_tile_float, unscale);
515-
tile_binary(S_tile, mask_tile_float, binary_add);
512+
// tile_elementwise(mask_tile_float, unscale);
513+
// tile_binary(S_tile, mask_tile_float, binary_add);
516514
#endif
517515

518516
#if DEBUG_PRINT

0 commit comments

Comments
 (0)