Skip to content

Commit af77957

Browse files
committed
TEST: [GPU] Use FP32 accumulator for QK multiplication for 2nd+ token calculation in PagedAttention
1 parent cfbc998 commit af77957

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl

+7-7
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ KERNEL(pa_sdpa_opt)(
107107
#endif
108108

109109
// SLM for intermediate QK results
110-
__local OUTPUT_TYPE slm_qk_vals[SEQ_LEN_PARTITION_SIZE];
110+
__local SOFTMAX_ACCUMULATOR_TYPE slm_qk_vals[SEQ_LEN_PARTITION_SIZE];
111111

112112
// SLM buffers for SoftMax calculation and qk_max/qk_sums results aggregation across all WGs
113113
__local SOFTMAX_ACCUMULATOR_TYPE slm_qk_max_vals[SUBGROUPS_PER_WG];
@@ -166,7 +166,7 @@ KERNEL(pa_sdpa_opt)(
166166
#endif
167167
const uint block_offset = block_indices[start_block_idx + block_num * SUBGROUPS_PER_WG] * HEAD_SIZE * KV_HEADS_NUM * SUBGROUP_SIZE + head_idx * HEAD_SIZE * SUBGROUP_SIZE;
168168

169-
INPUT0_TYPE qk_acc = INPUT0_VAL_ZERO;
169+
SOFTMAX_ACCUMULATOR_TYPE qk_acc = SOFTMAX_ACCUMULATOR_VAL_ZERO;
170170

171171
#define KEY_VEC_SIZE SUBGROUP_SIZE
172172
unroll_for (uint qk_idx = 0; qk_idx < HEAD_SIZE / KEY_VEC_SIZE; qk_idx++) {
@@ -181,9 +181,9 @@ KERNEL(pa_sdpa_opt)(
181181

182182
unroll_for (uint i = 0; i < KEY_VEC_SIZE; i++) {
183183
#if STORE_QUERY_TO_SLM
184-
qk_acc = mad(sub_group_broadcast(q_val, i), k_vals[i], qk_acc);
184+
qk_acc = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(q_val, i)), TO_SOFTMAX_ACCUMULATOR_TYPE(k_vals[i]), qk_acc);
185185
#else
186-
qk_acc = mad(sub_group_broadcast(q_val[qk_idx], i), k_vals[i], qk_acc);
186+
qk_acc = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(q_val[qk_idx], i)), TO_SOFTMAX_ACCUMULATOR_TYPE(k_vals[i]), qk_acc);
187187
#endif
188188
}
189189
}
@@ -196,7 +196,7 @@ KERNEL(pa_sdpa_opt)(
196196
#endif
197197

198198
if (token_idx >= seq_len)
199-
qk_acc = INPUT0_VAL_MIN;
199+
qk_acc = SOFTMAX_ACCUMULATOR_VAL_MIN;
200200

201201
qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max, TO_SOFTMAX_ACCUMULATOR_TYPE(qk_acc));
202202

@@ -235,7 +235,7 @@ KERNEL(pa_sdpa_opt)(
235235
if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE) {
236236
#endif
237237
SOFTMAX_ACCUMULATOR_TYPE qk_new = native_exp(TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[local_data_idx]) - qk_max);
238-
slm_qk_vals[local_data_idx] = TO_OUTPUT_TYPE(qk_new);
238+
slm_qk_vals[local_data_idx] = qk_new;
239239

240240
exp_sum += qk_new;
241241
}
@@ -266,7 +266,7 @@ KERNEL(pa_sdpa_opt)(
266266
if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE) {
267267
#endif
268268
SOFTMAX_ACCUMULATOR_TYPE qk_new = TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[local_data_idx]) / exp_sum;
269-
slm_qk_vals[local_data_idx] = TO_OUTPUT_TYPE(qk_new);
269+
slm_qk_vals[local_data_idx] = qk_new;
270270
}
271271
}
272272

0 commit comments

Comments
 (0)