Skip to content

Commit 5af74a8

Browse files
committed
[GPU] Enable FP32 accumulators for Q*K and QK*V multiplications in sdpa_opt
1 parent ed11461 commit 5af74a8

File tree

2 files changed

+64
-19
lines changed

2 files changed

+64
-19
lines changed

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl

+23-19
Original file line number | Diff line number | Diff line change
@@ -730,7 +730,11 @@ KERNEL(sdpa_opt)(
730730
#define APPLY_SCALES_TO_QUERY 1
731731
#endif
732732

733-
#define MASK_VECTOR_TYPE MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE)
733+
#if FORCE_SCALE_TO_QUERY
734+
#define APPLY_SCALES_TO_QUERY 1
735+
#endif
736+
737+
#define MASK_VECTOR_TYPE MAKE_VECTOR_TYPE(QK_ACCUMULATOR_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE)
734738

735739
inline MASK_VECTOR_TYPE FUNC(load_attn_mask)(OPTIONAL_SHAPE_INFO_ARG
736740
uint b0_idx,
@@ -880,7 +884,7 @@ KERNEL(sdpa_opt)(
880884
__local INPUT0_TYPE slm_query[HEAD_SIZE * TARGET_SEQ_LEN_BLOCK_SIZE];
881885

882886
// SLM buffer for intermediate QK results
883-
__local OUTPUT_TYPE slm_qk_vals[TARGET_SEQ_LEN_BLOCK_SIZE][SEQ_LEN_PARTITION_SIZE];
887+
__local QK_ACCUMULATOR_TYPE slm_qk_vals[TARGET_SEQ_LEN_BLOCK_SIZE][SEQ_LEN_PARTITION_SIZE];
884888

885889
// SLM buffers for SoftMax calculation and qk_max/qk_sums results aggregation across all WGs
886890
__local SOFTMAX_ACCUMULATOR_TYPE slm_qk_max_vals[TARGET_SEQ_LEN_BLOCK_SIZE][SUBGROUPS_PER_WG];
@@ -993,7 +997,7 @@ KERNEL(sdpa_opt)(
993997
}
994998

995999
// Q*K calculation loop
996-
MAKE_VECTOR_TYPE(OUTPUT_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) output_acc = OUTPUT_VAL_ZERO;
1000+
MAKE_VECTOR_TYPE(SV_ACCUMULATOR_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) output_acc = OUTPUT_VAL_ZERO;
9971001

9981002
__attribute__((opencl_unroll_hint(1)))
9991003
for (uint start_partition_idx = 0; start_partition_idx < SOURCE_SEQ_LEN; start_partition_idx += SEQ_LEN_PARTITION_SIZE) {
@@ -1004,7 +1008,7 @@ KERNEL(sdpa_opt)(
10041008
const uint partition_seq_len = min((uint)SOURCE_SEQ_LEN - start_partition_idx, (uint)SEQ_LEN_PARTITION_SIZE);
10051009
#endif
10061010

1007-
MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZERO;
1011+
MAKE_VECTOR_TYPE(QK_ACCUMULATOR_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZERO;
10081012
#if IS_CAUSAL
10091013
if (seq_len <= target_seq_idx) { // keep tril i.e. m >= n
10101014
#endif
@@ -1086,7 +1090,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
10861090
#endif
10871091

10881092
unroll_for (uint i = 0; i < SUBGROUP_SIZE; i++) {
1089-
qk_acc[key_row_idx] = mad(sub_group_broadcast(key_vals, i), queries_vec[i], qk_acc[key_row_idx]);
1093+
qk_acc[key_row_idx] = mad(TO_QK_ACCUMULATOR_TYPE(sub_group_broadcast(key_vals, i)), TO_QK_ACCUMULATOR_TYPE(queries_vec[i]), qk_acc[key_row_idx]);
10901094
}
10911095
}
10921096
}
@@ -1156,7 +1160,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
11561160
#define key_vals key_vec[key_row_idx]
11571161
#endif
11581162
unroll_for (uint i = 0; i < SUBGROUP_SIZE; i++) {
1159-
qk_acc[key_row_idx] = mad(sub_group_broadcast(key_vals, i), queries_vec[i], qk_acc[key_row_idx]);
1163+
qk_acc[key_row_idx] = mad(TO_QK_ACCUMULATOR_TYPE(sub_group_broadcast(key_vals, i)), TO_QK_ACCUMULATOR_TYPE(queries_vec[i]), qk_acc[key_row_idx]);
11601164
}
11611165
}
11621166
}
@@ -1183,10 +1187,10 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
11831187
qk_acc[i] += alibi_slopes[num_heads_dim] * alibi_val;
11841188
#endif
11851189

1186-
qk_acc[i] = INPUT0_MIN_FUNC(INPUT0_MAX_FUNC(qk_acc[i], INPUT0_VAL_MIN), INPUT0_VAL_MAX);
1190+
qk_acc[i] = QK_ACCUMULATOR_MIN_FUNC(QK_ACCUMULATOR_MAX_FUNC(qk_acc[i], QK_ACCUMULATOR_VAL_MIN), QK_ACCUMULATOR_VAL_MAX);
11871191
#if IS_CAUSAL
11881192
} else {
1189-
qk_acc[i] = INPUT0_VAL_MIN;
1193+
qk_acc[i] = QK_ACCUMULATOR_VAL_MIN;
11901194
}
11911195
#endif // IS_CAUSAL
11921196
qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max, TO_SOFTMAX_ACCUMULATOR_TYPE(qk_acc[i]));
@@ -1226,7 +1230,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
12261230
SOFTMAX_ACCUMULATOR_TYPE exp_sum_new = SOFTMAX_ACCUMULATOR_VAL_ZERO;
12271231
for (uint k = sglid; k < partition_seq_len; k += SUBGROUP_SIZE) {
12281232
SOFTMAX_ACCUMULATOR_TYPE a = native_exp(TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[m][k]) - qk_max_new);
1229-
slm_qk_vals[m][k] = TO_OUTPUT_TYPE(a);
1233+
slm_qk_vals[m][k] = TO_QK_ACCUMULATOR_TYPE(a);
12301234
exp_sum_new += a;
12311235
}
12321236
exp_sum_new = sub_group_reduce_add(exp_sum_new);
@@ -1281,7 +1285,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
12811285

12821286
{
12831287
// QK*V calculation
1284-
MAKE_VECTOR_TYPE(OUTPUT_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) acc_output_res = OUTPUT_VAL_ZERO;
1288+
MAKE_VECTOR_TYPE(SV_ACCUMULATOR_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) acc_output_res = OUTPUT_VAL_ZERO;
12851289
#if IS_PAGED_ATTENTION
12861290
const uint value_pitch = (HEAD_SIZE * NUM_KV_HEADS + INPUT2_PAD_BEFORE_FEATURE_NUM + INPUT2_PAD_AFTER_FEATURE_NUM);
12871291
#else
@@ -1322,7 +1326,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
13221326
#endif
13231327
#endif
13241328

1325-
MAKE_VECTOR_TYPE(OUTPUT_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_val;
1329+
MAKE_VECTOR_TYPE(SV_ACCUMULATOR_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_val;
13261330
unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) {
13271331
qk_val[seq_idx] = slm_qk_vals[seq_idx][seq_len + sglid];
13281332
}
@@ -1350,7 +1354,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
13501354
#endif
13511355

13521356
unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) {
1353-
acc_output_res[seq_idx] = mad(sub_group_broadcast(qk_val[seq_idx], i), value_val, acc_output_res[seq_idx]);
1357+
acc_output_res[seq_idx] = mad(TO_SV_ACCUMULATOR_TYPE(sub_group_broadcast(qk_val[seq_idx], i)), TO_SV_ACCUMULATOR_TYPE(value_val), acc_output_res[seq_idx]);
13541358
}
13551359

13561360
#ifndef BEAM_TABLE_TYPE
@@ -1398,7 +1402,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
13981402
#endif
13991403
#endif
14001404

1401-
MAKE_VECTOR_TYPE(OUTPUT_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_val;
1405+
MAKE_VECTOR_TYPE(SV_ACCUMULATOR_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_val;
14021406
unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) {
14031407
qk_val[seq_idx] = slm_qk_vals[seq_idx][seq_len * SUBGROUP_SIZE + sglid];
14041408
}
@@ -1418,7 +1422,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
14181422
INPUT2_TYPE value_val = value_packed;
14191423
#endif
14201424
unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) {
1421-
acc_output_res[seq_idx] = mad(sub_group_broadcast(qk_val[seq_idx], i), value_val, acc_output_res[seq_idx]);
1425+
acc_output_res[seq_idx] = mad(TO_SV_ACCUMULATOR_TYPE(sub_group_broadcast(qk_val[seq_idx], i)), TO_SV_ACCUMULATOR_TYPE(value_val), acc_output_res[seq_idx]);
14221426
}
14231427

14241428
#ifndef BEAM_TABLE_TYPE
@@ -1430,7 +1434,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
14301434
// QK*V leftovers processing
14311435
const uint seq_len_leftovers_start = ((seq_len_end / SUBGROUP_SIZE) * SUBGROUP_SIZE);
14321436
if (seq_len_leftovers_start != seq_len_end) {
1433-
MAKE_VECTOR_TYPE(OUTPUT_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_val;
1437+
MAKE_VECTOR_TYPE(SV_ACCUMULATOR_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_val;
14341438
unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) {
14351439
qk_val[seq_idx] = slm_qk_vals[seq_idx][seq_len_leftovers_start+sglid];
14361440
}
@@ -1484,7 +1488,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
14841488
#endif
14851489

14861490
for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) {
1487-
acc_output_res[seq_idx] = mad(sub_group_broadcast(qk_val[seq_idx], seq_len_idx), value_val, acc_output_res[seq_idx]);
1491+
acc_output_res[seq_idx] = mad(TO_SV_ACCUMULATOR_TYPE(sub_group_broadcast(qk_val[seq_idx], seq_len_idx)), TO_SV_ACCUMULATOR_TYPE(value_val), acc_output_res[seq_idx]);
14881492
}
14891493

14901494
#ifndef BEAM_TABLE_TYPE
@@ -1502,7 +1506,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
15021506
// Rescale acc_output_res values and save current iter results to global accumulator
15031507
for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) {
15041508
if (start_partition_idx > 0) {
1505-
OUTPUT_TYPE updated_prev_res = TO_SOFTMAX_ACCUMULATOR_TYPE(output_acc[seq_idx]) * slm_update_factor[seq_idx];
1509+
SV_ACCUMULATOR_TYPE updated_prev_res = TO_SOFTMAX_ACCUMULATOR_TYPE(output_acc[seq_idx]) * slm_update_factor[seq_idx];
15061510
acc_output_res[seq_idx] += updated_prev_res;
15071511
}
15081512
output_acc[seq_idx] = acc_output_res[seq_idx];
@@ -1539,13 +1543,13 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
15391543
if (TARGET_SEQ_LEN_BLOCK_SIZE > seq_idx_end) {
15401544
for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) {
15411545
output_acc[seq_idx] /= slm_exp_sum_prev[seq_idx];
1542-
OUTPUT_BLOCK_WRITE(output, output_offset, output_acc[seq_idx]);
1546+
OUTPUT_BLOCK_WRITE(output, output_offset, TO_OUTPUT_TYPE(output_acc[seq_idx]));
15431547
output_offset += output_pitch;
15441548
}
15451549
} else {
15461550
unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) {
15471551
output_acc[seq_idx] /= slm_exp_sum_prev[seq_idx];
1548-
OUTPUT_BLOCK_WRITE(output, output_offset, output_acc[seq_idx]);
1552+
OUTPUT_BLOCK_WRITE(output, output_offset, TO_OUTPUT_TYPE(output_acc[seq_idx]));
15491553
output_offset += output_pitch;
15501554
}
15511555
}

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp

+41
Original file line number | Diff line number | Diff line change
@@ -170,8 +170,49 @@ bool SDPAKernelOpt::Validate(const Params& p) const {
170170
JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params, size_t kernel_idx) const {
171171
auto jit = SDPAKernelBase::GetJitConstants(params);
172172

173+
int USE_FP32_QK = 0;
174+
if (const auto env_var = std::getenv("USE_FP32_QK")) {
175+
std::istringstream ss(env_var);
176+
ss >> USE_FP32_QK;
177+
static bool printed = false;
178+
if (!printed) {
179+
std::cout << "Set USE_FP32_QK=" << USE_FP32_QK << "\n";
180+
printed = true;
181+
}
182+
}
183+
184+
int USE_FP32_QKV = 0;
185+
if (const auto env_var = std::getenv("USE_FP32_QKV")) {
186+
std::istringstream ss(env_var);
187+
ss >> USE_FP32_QKV;
188+
static bool printed = false;
189+
if (!printed) {
190+
std::cout << "Set USE_FP32_QKV=" << USE_FP32_QKV << "\n";
191+
printed = true;
192+
}
193+
}
194+
195+
int FORCE_SCALE_TO_QUERY = 0;
196+
if (const auto env_var = std::getenv("FORCE_SCALE_TO_QUERY")) {
197+
std::istringstream ss(env_var);
198+
ss >> FORCE_SCALE_TO_QUERY;
199+
static bool printed = false;
200+
if (!printed) {
201+
std::cout << "Set FORCE_SCALE_TO_QUERY=" << FORCE_SCALE_TO_QUERY << "\n";
202+
printed = true;
203+
}
204+
}
205+
206+
if (FORCE_SCALE_TO_QUERY) {
207+
jit.AddConstant(MakeJitConstant("FORCE_SCALE_TO_QUERY", 1));
208+
}
209+
173210
const auto softmax_acc_dt = get_softmax_acc_type();
211+
const auto qk_acc_dt = USE_FP32_QK ? Datatype::F32 : params.outputs[0].GetDType();
212+
const auto sv_acc_dt = USE_FP32_QKV ? Datatype::F32 : params.outputs[0].GetDType();
174213
jit.Merge(MakeTypeJitConstants(softmax_acc_dt, "SOFTMAX_ACCUMULATOR"));
214+
jit.Merge(MakeTypeJitConstants(qk_acc_dt, "QK_ACCUMULATOR"));
215+
jit.Merge(MakeTypeJitConstants(sv_acc_dt, "SV_ACCUMULATOR"));
175216

176217
const auto& config = params.conf;
177218
jit.AddConstant(MakeJitConstant("SUBGROUP_SIZE", subgroup_size));

0 commit comments

Comments (0)