@@ -124,24 +124,26 @@ KERNEL(pa_sdpa_opt)(
124
124
125
125
{
126
126
#if STORE_QUERY_TO_SLM
127
- const uint query_idx_local = sgid * SUBGROUP_SIZE + sglid ;
128
- const uint query_idx = INPUT0_OFFSET +
129
- seq_idx * (HEAD_SIZE * HEADS_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM ) +
130
- head_num_idx * HEAD_SIZE +
131
- query_idx_local ;
127
+ if (sgid < HEAD_SIZE / SUBGROUP_SIZE ) {
128
+ const uint query_idx_local = sgid * SUBGROUP_SIZE + sglid ;
129
+ const uint query_idx = INPUT0_OFFSET +
130
+ seq_idx * (HEAD_SIZE * HEADS_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM ) +
131
+ head_num_idx * HEAD_SIZE +
132
+ query_idx_local ;
132
133
133
- INPUT0_TYPE q_val = BLOCK_READN (INPUT0_TYPE , 1 , query , query_idx );
134
+ INPUT0_TYPE q_val = BLOCK_READN (INPUT0_TYPE , 1 , query , query_idx );
134
135
135
- // Apply scale value directly to the query input to improve accuracy in case of a high range of input data
136
+ // Apply scale value directly to the query input to improve accuracy in case of a high range of input data
136
137
#ifdef SCALE_VAL
137
- q_val = TO_INPUT0_TYPE (SCALE_VAL ) * q_val ;
138
+ q_val = TO_INPUT0_TYPE (SCALE_VAL ) * q_val ;
138
139
#else
139
- q_val = * scale * q_val ;
140
+ q_val = * scale * q_val ;
140
141
#endif
141
142
142
- slm_query [query_idx_local ] = q_val ;
143
+ slm_query [query_idx_local ] = q_val ;
143
144
144
- barrier (CLK_LOCAL_MEM_FENCE );
145
+ barrier (CLK_LOCAL_MEM_FENCE );
146
+ }
145
147
#else
146
148
INPUT0_TYPE q_val [HEAD_SIZE / SUBGROUP_SIZE ];
147
149
unroll_for (uint i = 0 ; i < HEAD_SIZE / SUBGROUP_SIZE ; i ++ ) {
0 commit comments