Skip to content

Commit d1b450d

Browse files
committed
Fix large head_size configurations when SG_SCALE is applied
1 parent d8edd80 commit d1b450d

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl

+13-11
Original file line number | Diff line number | Diff line change
@@ -124,24 +124,26 @@ KERNEL(pa_sdpa_opt)(
124124

125125
{
126126
#if STORE_QUERY_TO_SLM
127-
const uint query_idx_local = sgid * SUBGROUP_SIZE + sglid;
128-
const uint query_idx = INPUT0_OFFSET +
129-
seq_idx * (HEAD_SIZE * HEADS_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM) +
130-
head_num_idx * HEAD_SIZE +
131-
query_idx_local;
127+
if (sgid < HEAD_SIZE / SUBGROUP_SIZE) {
128+
const uint query_idx_local = sgid * SUBGROUP_SIZE + sglid;
129+
const uint query_idx = INPUT0_OFFSET +
130+
seq_idx * (HEAD_SIZE * HEADS_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM) +
131+
head_num_idx * HEAD_SIZE +
132+
query_idx_local;
132133

133-
INPUT0_TYPE q_val = BLOCK_READN(INPUT0_TYPE, 1, query, query_idx);
134+
INPUT0_TYPE q_val = BLOCK_READN(INPUT0_TYPE, 1, query, query_idx);
134135

135-
// Apply scale value directly to the query input to improve accuracy in case of a high range of input data
136+
// Apply scale value directly to the query input to improve accuracy in case of a high range of input data
136137
#ifdef SCALE_VAL
137-
q_val = TO_INPUT0_TYPE(SCALE_VAL) * q_val;
138+
q_val = TO_INPUT0_TYPE(SCALE_VAL) * q_val;
138139
#else
139-
q_val = *scale * q_val;
140+
q_val = *scale * q_val;
140141
#endif
141142

142-
slm_query[query_idx_local] = q_val;
143+
slm_query[query_idx_local] = q_val;
143144

144-
barrier(CLK_LOCAL_MEM_FENCE);
145+
barrier(CLK_LOCAL_MEM_FENCE);
146+
}
145147
#else
146148
INPUT0_TYPE q_val[HEAD_SIZE / SUBGROUP_SIZE];
147149
unroll_for (uint i = 0; i < HEAD_SIZE / SUBGROUP_SIZE; i++) {

0 commit comments

Comments
 (0)