Skip to content

Commit f9dd9fd

Browse files
committed
TEST: [GPU] Use FP32 accumulator for QK multiplication for 2nd+ token calculation in PagedAttention
1 parent 69ff32d commit f9dd9fd

File tree

1 file changed

+7
-7
lines changed

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl

+7-7
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ KERNEL(pa_sdpa_opt)(
107107
#endif
108108

109109
// SLM for intermediate QK results
110-
__local OUTPUT_TYPE slm_qk_vals[SEQ_LEN_PARTITION_SIZE];
110+
__local SOFTMAX_ACCUMULATOR_TYPE slm_qk_vals[SEQ_LEN_PARTITION_SIZE];
111111

112112
// SLM buffers for SoftMax calculation and qk_max/qk_sums results aggregation across all WGs
113113
__local SOFTMAX_ACCUMULATOR_TYPE slm_qk_max_vals[SUBGROUPS_PER_WG];
@@ -168,7 +168,7 @@ KERNEL(pa_sdpa_opt)(
168168
#endif
169169
const uint block_offset = block_indices[start_block_idx + block_num * SUBGROUPS_PER_WG] * HEAD_SIZE * KV_HEADS_NUM * SUBGROUP_SIZE + head_idx * HEAD_SIZE * SUBGROUP_SIZE;
170170

171-
INPUT0_TYPE qk_acc = INPUT0_VAL_ZERO;
171+
SOFTMAX_ACCUMULATOR_TYPE qk_acc = SOFTMAX_ACCUMULATOR_VAL_ZERO;
172172

173173
#define KEY_VEC_SIZE SUBGROUP_SIZE
174174
unroll_for (uint qk_idx = 0; qk_idx < HEAD_SIZE / KEY_VEC_SIZE; qk_idx++) {
@@ -183,9 +183,9 @@ KERNEL(pa_sdpa_opt)(
183183

184184
unroll_for (uint i = 0; i < KEY_VEC_SIZE; i++) {
185185
#if STORE_QUERY_TO_SLM
186-
qk_acc = mad(sub_group_broadcast(q_val, i), k_vals[i], qk_acc);
186+
qk_acc = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(q_val, i)), TO_SOFTMAX_ACCUMULATOR_TYPE(k_vals[i]), qk_acc);
187187
#else
188-
qk_acc = mad(sub_group_broadcast(q_val[qk_idx], i), k_vals[i], qk_acc);
188+
qk_acc = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(q_val[qk_idx], i)), TO_SOFTMAX_ACCUMULATOR_TYPE(k_vals[i]), qk_acc);
189189
#endif
190190
}
191191
}
@@ -198,7 +198,7 @@ KERNEL(pa_sdpa_opt)(
198198
#endif
199199

200200
if (token_idx >= seq_len)
201-
qk_acc = INPUT0_VAL_MIN;
201+
qk_acc = SOFTMAX_ACCUMULATOR_VAL_MIN;
202202

203203
qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max, TO_SOFTMAX_ACCUMULATOR_TYPE(qk_acc));
204204

@@ -237,7 +237,7 @@ KERNEL(pa_sdpa_opt)(
237237
if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE) {
238238
#endif
239239
SOFTMAX_ACCUMULATOR_TYPE qk_new = native_exp(TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[local_data_idx]) - qk_max);
240-
slm_qk_vals[local_data_idx] = TO_OUTPUT_TYPE(qk_new);
240+
slm_qk_vals[local_data_idx] = qk_new;
241241

242242
exp_sum += qk_new;
243243
}
@@ -268,7 +268,7 @@ KERNEL(pa_sdpa_opt)(
268268
if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE) {
269269
#endif
270270
SOFTMAX_ACCUMULATOR_TYPE qk_new = TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[local_data_idx]) / exp_sum;
271-
slm_qk_vals[local_data_idx] = TO_OUTPUT_TYPE(qk_new);
271+
slm_qk_vals[local_data_idx] = qk_new;
272272
}
273273
}
274274

0 commit comments

Comments (0)