34
34
typedef ugemm_kq_c_type s_tile_type ;
35
35
typedef ugemm_vs_c_type a_tile_type ;
36
36
37
-
38
-
39
- #define DECLARE_2D_MASK_FILL(tile_type, element_type, sg, br, bc, nbr, nbc) \
40
- __attribute__((overloadable)) void fill_mask_t(tile_type *t, \
37
+ #define DECLARE_2D_CAUSAL_MASK(tile_type, element_type, sg, br, bc, nbr, nbc) \
38
+ __attribute__((overloadable)) void apply_causal_mask_t(tile_type *t, \
41
39
int m, int n, int ld, \
42
- int offset_r, int offset_c) { \
40
+ int offset_r, int offset_c, \
41
+ float iscale) { \
43
42
if (offset_c + bc * nbc - 1 <= offset_r) { \
44
- tile_fill((*t), 0); \
45
43
return; \
46
44
} else if (offset_r + br * nbr <= offset_c) { \
47
- tile_fill((*t), -HALF_MAX); \
45
+ tile_fill((*t), ( -HALF_MAX * iscale) ); \
48
46
return; \
49
47
} \
50
48
_Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \
51
49
int i = i0 + get_sub_group_local_id(); \
52
50
_Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \
53
- tile_access(*t, i0, j, sg, br, bc, nbr) = (offset_c + j > offset_r + i) ? -HALF_MAX : 0.0f; \
51
+ tile_access(*t, i0, j, sg, br, bc, nbr) += (offset_c + j > offset_r + i) ? ( -HALF_MAX * iscale) : 0.0f; \
54
52
} \
55
53
} \
56
54
}
@@ -86,7 +84,7 @@ DECLARE_2D_TILE(
86
84
87
85
DECLARE_2D_TILE (mask_tile_type , half , SUBGROUP_SIZE , ugemm_kq_c_type_block0 , ugemm_kq_c_type_block1 , ugemm_kq_c_type_nblock0 , ugemm_kq_c_type_nblock1 )
88
86
DECLARE_2D_TILE (mask_tile_type_float , float , SUBGROUP_SIZE , ugemm_kq_c_type_block0 , ugemm_kq_c_type_block1 , ugemm_kq_c_type_nblock0 , ugemm_kq_c_type_nblock1 )
89
- DECLARE_2D_MASK_FILL (mask_tile_type_float , float , SUBGROUP_SIZE , ugemm_kq_c_type_block0 , ugemm_kq_c_type_block1 , ugemm_kq_c_type_nblock0 , ugemm_kq_c_type_nblock1 )
87
+ DECLARE_2D_CAUSAL_MASK (mask_tile_type_float , float , SUBGROUP_SIZE , ugemm_kq_c_type_block0 , ugemm_kq_c_type_block1 , ugemm_kq_c_type_nblock0 , ugemm_kq_c_type_nblock1 )
90
88
91
89
#ifdef BLOCK_A
92
90
DECLARE_2D_TILE_BLOCK_OPS (a_tile_type_half , half , SUBGROUP_SIZE ,
@@ -472,12 +470,12 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
472
470
473
471
#if IS_PAGED_ATTENTION
474
472
#define unscale(x) ((x)*iscale)
475
- mask_tile_type_float mask_tile_float ;
473
+ // mask_tile_type_float mask_tile_float;
476
474
// tile_fill(mask_tile_float, -66.0f);
477
475
#if DEBUG_PRINT
478
476
if (b0 == 0 && b1 == 0 && k0 == 0 && sg_i_kq == 0 && sg_j_kq == 0 && get_sub_group_local_id () == 0 ) {
479
477
printf ("mask after init:\n" );
480
- printf ("fill_mask_t (q=%d, k=%d, col_r=%d, col_c=%d; br=%d bc=%d nbr=%d nbc=%d)\n" , q , k , sg_j0_kq + wg_j0 , k0 + sg_i0_kq , ugemm_kq_c_type_block0 , ugemm_kq_c_type_block1 , ugemm_kq_c_type_nblock0 , ugemm_kq_c_type_nblock1 );
478
+ printf ("apply_causal_mask_t (q=%d, k=%d, col_r=%d, col_c=%d; br=%d bc=%d nbr=%d nbc=%d)\n" , q , k , sg_j0_kq + wg_j0 , k0 + sg_i0_kq , ugemm_kq_c_type_block0 , ugemm_kq_c_type_block1 , ugemm_kq_c_type_nblock0 , ugemm_kq_c_type_nblock1 );
481
479
for (int i = 0 ; i < 8 ; i ++ ) {
482
480
for (int j = 0 ; j < 8 ; j ++ )
483
481
printf ("\t%f" , xlane_tile_access (mask_tile_float , i , j , SUBGROUP_SIZE , ugemm_kq_c_type_block0 ,
@@ -487,7 +485,7 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
487
485
}
488
486
#endif
489
487
490
- fill_mask_t ( & mask_tile_float , q , k , q , sg_j0_kq + wg_j0 , k0 + sg_i0_kq );
488
+ apply_causal_mask_t ( & S_tile , q , k , q , sg_j0_kq + wg_j0 , k0 + sg_i0_kq , iscale );
491
489
492
490
#if DEBUG_PRINT
493
491
if (b0 == 0 && b1 == 0 && k0 == 0 && sg_i_kq == 0 && sg_j_kq == 0 && (get_sub_group_local_id () < 4 )) {
@@ -501,7 +499,7 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
501
499
#if DEBUG_PRINT
502
500
if (b0 == 0 && b1 == 0 && k0 == 0 && sg_i_kq == 0 && sg_j_kq == 0 && get_sub_group_local_id () == 0 ) {
503
501
printf ("updated mask before scale:\n" );
504
- printf ("fill_mask_t (q=%d, k=%d, col_r=%d, col_c=%d; br=%d bc=%d nbr=%d nbc=%d)\n" , q , k , sg_j0_kq + wg_j0 , k0 + sg_i0_kq , ugemm_kq_c_type_block0 , ugemm_kq_c_type_block1 , ugemm_kq_c_type_nblock0 , ugemm_kq_c_type_nblock1 );
502
+ printf ("apply_causal_mask_t (q=%d, k=%d, col_r=%d, col_c=%d; br=%d bc=%d nbr=%d nbc=%d)\n" , q , k , sg_j0_kq + wg_j0 , k0 + sg_i0_kq , ugemm_kq_c_type_block0 , ugemm_kq_c_type_block1 , ugemm_kq_c_type_nblock0 , ugemm_kq_c_type_nblock1 );
505
503
for (int i = 0 ; i < 8 ; i ++ ) {
506
504
for (int j = 0 ; j < 8 ; j ++ )
507
505
printf ("\t%f" , xlane_tile_access (mask_tile_float , i , j , SUBGROUP_SIZE , ugemm_kq_c_type_block0 ,
@@ -511,8 +509,8 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
511
509
}
512
510
#endif
513
511
514
- tile_elementwise (mask_tile_float , unscale );
515
- tile_binary (S_tile , mask_tile_float , binary_add );
512
+ // tile_elementwise(mask_tile_float, unscale);
513
+ // tile_binary(S_tile, mask_tile_float, binary_add);
516
514
#endif
517
515
518
516
#if DEBUG_PRINT
0 commit comments