@@ -730,7 +730,11 @@ KERNEL(sdpa_opt)(
730
730
#define APPLY_SCALES_TO_QUERY 1
731
731
#endif
732
732
733
- #define MASK_VECTOR_TYPE MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE)
733
+ #if FORCE_SCALE_TO_QUERY
734
+ #define APPLY_SCALES_TO_QUERY 1
735
+ #endif
736
+
737
+ #define MASK_VECTOR_TYPE MAKE_VECTOR_TYPE(QK_ACCUMULATOR_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE)
734
738
735
739
inline MASK_VECTOR_TYPE FUNC (load_attn_mask )(OPTIONAL_SHAPE_INFO_ARG
736
740
uint b0_idx ,
@@ -880,7 +884,7 @@ KERNEL(sdpa_opt)(
880
884
__local INPUT0_TYPE slm_query [HEAD_SIZE * TARGET_SEQ_LEN_BLOCK_SIZE ];
881
885
882
886
// SLM buffer for intermediate QK results
883
- __local OUTPUT_TYPE slm_qk_vals [TARGET_SEQ_LEN_BLOCK_SIZE ][SEQ_LEN_PARTITION_SIZE ];
887
+ __local QK_ACCUMULATOR_TYPE slm_qk_vals [TARGET_SEQ_LEN_BLOCK_SIZE ][SEQ_LEN_PARTITION_SIZE ];
884
888
885
889
// SLM buffers for SoftMax calculation and qk_max/qk_sums results aggregation across all WGs
886
890
__local SOFTMAX_ACCUMULATOR_TYPE slm_qk_max_vals [TARGET_SEQ_LEN_BLOCK_SIZE ][SUBGROUPS_PER_WG ];
@@ -993,7 +997,7 @@ KERNEL(sdpa_opt)(
993
997
}
994
998
995
999
// Q*K calculation loop
996
- MAKE_VECTOR_TYPE (OUTPUT_TYPE , TARGET_SEQ_LEN_BLOCK_SIZE ) output_acc = OUTPUT_VAL_ZERO ;
1000
+ MAKE_VECTOR_TYPE (SV_ACCUMULATOR_TYPE , TARGET_SEQ_LEN_BLOCK_SIZE ) output_acc = OUTPUT_VAL_ZERO ;
997
1001
998
1002
__attribute__((opencl_unroll_hint (1 )))
999
1003
for (uint start_partition_idx = 0 ; start_partition_idx < SOURCE_SEQ_LEN ; start_partition_idx += SEQ_LEN_PARTITION_SIZE ) {
@@ -1004,7 +1008,7 @@ KERNEL(sdpa_opt)(
1004
1008
const uint partition_seq_len = min ((uint )SOURCE_SEQ_LEN - start_partition_idx , (uint )SEQ_LEN_PARTITION_SIZE );
1005
1009
#endif
1006
1010
1007
- MAKE_VECTOR_TYPE (INPUT0_TYPE , TARGET_SEQ_LEN_BLOCK_SIZE ) qk_acc = INPUT0_VAL_ZERO ;
1011
+ MAKE_VECTOR_TYPE (QK_ACCUMULATOR_TYPE , TARGET_SEQ_LEN_BLOCK_SIZE ) qk_acc = INPUT0_VAL_ZERO ;
1008
1012
#if IS_CAUSAL
1009
1013
if (seq_len <= target_seq_idx ) { // keep tril i.e. m >= n
1010
1014
#endif
@@ -1086,7 +1090,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
1086
1090
#endif
1087
1091
1088
1092
unroll_for (uint i = 0 ; i < SUBGROUP_SIZE ; i ++ ) {
1089
- qk_acc [key_row_idx ] = mad (sub_group_broadcast (key_vals , i ), queries_vec [i ], qk_acc [key_row_idx ]);
1093
+ qk_acc [key_row_idx ] = mad (TO_QK_ACCUMULATOR_TYPE ( sub_group_broadcast (key_vals , i )), TO_QK_ACCUMULATOR_TYPE ( queries_vec [i ]) , qk_acc [key_row_idx ]);
1090
1094
}
1091
1095
}
1092
1096
}
@@ -1156,7 +1160,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
1156
1160
#define key_vals key_vec[key_row_idx]
1157
1161
#endif
1158
1162
unroll_for (uint i = 0 ; i < SUBGROUP_SIZE ; i ++ ) {
1159
- qk_acc [key_row_idx ] = mad (sub_group_broadcast (key_vals , i ), queries_vec [i ], qk_acc [key_row_idx ]);
1163
+ qk_acc [key_row_idx ] = mad (TO_QK_ACCUMULATOR_TYPE ( sub_group_broadcast (key_vals , i )), TO_QK_ACCUMULATOR_TYPE ( queries_vec [i ]) , qk_acc [key_row_idx ]);
1160
1164
}
1161
1165
}
1162
1166
}
@@ -1183,10 +1187,10 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
1183
1187
qk_acc [i ] += alibi_slopes [num_heads_dim ] * alibi_val ;
1184
1188
#endif
1185
1189
1186
- qk_acc [i ] = INPUT0_MIN_FUNC ( INPUT0_MAX_FUNC (qk_acc [i ], INPUT0_VAL_MIN ), INPUT0_VAL_MAX );
1190
+ qk_acc [i ] = QK_ACCUMULATOR_MIN_FUNC ( QK_ACCUMULATOR_MAX_FUNC (qk_acc [i ], QK_ACCUMULATOR_VAL_MIN ), QK_ACCUMULATOR_VAL_MAX );
1187
1191
#if IS_CAUSAL
1188
1192
} else {
1189
- qk_acc [i ] = INPUT0_VAL_MIN ;
1193
+ qk_acc [i ] = QK_ACCUMULATOR_VAL_MIN ;
1190
1194
}
1191
1195
#endif // IS_CAUSAL
1192
1196
qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC (qk_max , TO_SOFTMAX_ACCUMULATOR_TYPE (qk_acc [i ]));
@@ -1226,7 +1230,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
1226
1230
SOFTMAX_ACCUMULATOR_TYPE exp_sum_new = SOFTMAX_ACCUMULATOR_VAL_ZERO ;
1227
1231
for (uint k = sglid ; k < partition_seq_len ; k += SUBGROUP_SIZE ) {
1228
1232
SOFTMAX_ACCUMULATOR_TYPE a = native_exp (TO_SOFTMAX_ACCUMULATOR_TYPE (slm_qk_vals [m ][k ]) - qk_max_new );
1229
- slm_qk_vals [m ][k ] = TO_OUTPUT_TYPE (a );
1233
+ slm_qk_vals [m ][k ] = TO_QK_ACCUMULATOR_TYPE (a );
1230
1234
exp_sum_new += a ;
1231
1235
}
1232
1236
exp_sum_new = sub_group_reduce_add (exp_sum_new );
@@ -1281,7 +1285,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
1281
1285
1282
1286
{
1283
1287
// QK*V calculation
1284
- MAKE_VECTOR_TYPE (OUTPUT_TYPE , TARGET_SEQ_LEN_BLOCK_SIZE ) acc_output_res = OUTPUT_VAL_ZERO ;
1288
+ MAKE_VECTOR_TYPE (SV_ACCUMULATOR_TYPE , TARGET_SEQ_LEN_BLOCK_SIZE ) acc_output_res = OUTPUT_VAL_ZERO ;
1285
1289
#if IS_PAGED_ATTENTION
1286
1290
const uint value_pitch = (HEAD_SIZE * NUM_KV_HEADS + INPUT2_PAD_BEFORE_FEATURE_NUM + INPUT2_PAD_AFTER_FEATURE_NUM );
1287
1291
#else
@@ -1322,7 +1326,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
1322
1326
#endif
1323
1327
#endif
1324
1328
1325
- MAKE_VECTOR_TYPE (OUTPUT_TYPE , TARGET_SEQ_LEN_BLOCK_SIZE ) qk_val ;
1329
+ MAKE_VECTOR_TYPE (SV_ACCUMULATOR_TYPE , TARGET_SEQ_LEN_BLOCK_SIZE ) qk_val ;
1326
1330
unroll_for (uint seq_idx = 0 ; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE ; seq_idx ++ ) {
1327
1331
qk_val [seq_idx ] = slm_qk_vals [seq_idx ][seq_len + sglid ];
1328
1332
}
@@ -1350,7 +1354,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
1350
1354
#endif
1351
1355
1352
1356
unroll_for (uint seq_idx = 0 ; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE ; seq_idx ++ ) {
1353
- acc_output_res [seq_idx ] = mad (sub_group_broadcast (qk_val [seq_idx ], i ), value_val , acc_output_res [seq_idx ]);
1357
+ acc_output_res [seq_idx ] = mad (TO_SV_ACCUMULATOR_TYPE ( sub_group_broadcast (qk_val [seq_idx ], i )), TO_SV_ACCUMULATOR_TYPE ( value_val ) , acc_output_res [seq_idx ]);
1354
1358
}
1355
1359
1356
1360
#ifndef BEAM_TABLE_TYPE
@@ -1398,7 +1402,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
1398
1402
#endif
1399
1403
#endif
1400
1404
1401
- MAKE_VECTOR_TYPE (OUTPUT_TYPE , TARGET_SEQ_LEN_BLOCK_SIZE ) qk_val ;
1405
+ MAKE_VECTOR_TYPE (SV_ACCUMULATOR_TYPE , TARGET_SEQ_LEN_BLOCK_SIZE ) qk_val ;
1402
1406
unroll_for (uint seq_idx = 0 ; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE ; seq_idx ++ ) {
1403
1407
qk_val [seq_idx ] = slm_qk_vals [seq_idx ][seq_len * SUBGROUP_SIZE + sglid ];
1404
1408
}
@@ -1418,7 +1422,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
1418
1422
INPUT2_TYPE value_val = value_packed ;
1419
1423
#endif
1420
1424
unroll_for (uint seq_idx = 0 ; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE ; seq_idx ++ ) {
1421
- acc_output_res [seq_idx ] = mad (sub_group_broadcast (qk_val [seq_idx ], i ), value_val , acc_output_res [seq_idx ]);
1425
+ acc_output_res [seq_idx ] = mad (TO_SV_ACCUMULATOR_TYPE ( sub_group_broadcast (qk_val [seq_idx ], i )), TO_SV_ACCUMULATOR_TYPE ( value_val ) , acc_output_res [seq_idx ]);
1422
1426
}
1423
1427
1424
1428
#ifndef BEAM_TABLE_TYPE
@@ -1430,7 +1434,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
1430
1434
// QK*V leftovers processing
1431
1435
const uint seq_len_leftovers_start = ((seq_len_end / SUBGROUP_SIZE ) * SUBGROUP_SIZE );
1432
1436
if (seq_len_leftovers_start != seq_len_end ) {
1433
- MAKE_VECTOR_TYPE (OUTPUT_TYPE , TARGET_SEQ_LEN_BLOCK_SIZE ) qk_val ;
1437
+ MAKE_VECTOR_TYPE (SV_ACCUMULATOR_TYPE , TARGET_SEQ_LEN_BLOCK_SIZE ) qk_val ;
1434
1438
unroll_for (uint seq_idx = 0 ; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE ; seq_idx ++ ) {
1435
1439
qk_val [seq_idx ] = slm_qk_vals [seq_idx ][seq_len_leftovers_start + sglid ];
1436
1440
}
@@ -1484,7 +1488,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
1484
1488
#endif
1485
1489
1486
1490
for (uint seq_idx = 0 ; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE ; seq_idx ++ ) {
1487
- acc_output_res [seq_idx ] = mad (sub_group_broadcast (qk_val [seq_idx ], seq_len_idx ), value_val , acc_output_res [seq_idx ]);
1491
+ acc_output_res [seq_idx ] = mad (TO_SV_ACCUMULATOR_TYPE ( sub_group_broadcast (qk_val [seq_idx ], seq_len_idx )), TO_SV_ACCUMULATOR_TYPE ( value_val ) , acc_output_res [seq_idx ]);
1488
1492
}
1489
1493
1490
1494
#ifndef BEAM_TABLE_TYPE
@@ -1502,7 +1506,7 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
1502
1506
// Rescale acc_output_res values and save current iter results to global accumulator
1503
1507
for (uint seq_idx = 0 ; seq_idx < seq_idx_end ; seq_idx ++ ) {
1504
1508
if (start_partition_idx > 0 ) {
1505
- OUTPUT_TYPE updated_prev_res = TO_SOFTMAX_ACCUMULATOR_TYPE (output_acc [seq_idx ]) * slm_update_factor [seq_idx ];
1509
+ SV_ACCUMULATOR_TYPE updated_prev_res = TO_SOFTMAX_ACCUMULATOR_TYPE (output_acc [seq_idx ]) * slm_update_factor [seq_idx ];
1506
1510
acc_output_res [seq_idx ] += updated_prev_res ;
1507
1511
}
1508
1512
output_acc [seq_idx ] = acc_output_res [seq_idx ];
@@ -1539,13 +1543,13 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
1539
1543
if (TARGET_SEQ_LEN_BLOCK_SIZE > seq_idx_end ) {
1540
1544
for (uint seq_idx = 0 ; seq_idx < seq_idx_end ; seq_idx ++ ) {
1541
1545
output_acc [seq_idx ] /= slm_exp_sum_prev [seq_idx ];
1542
- OUTPUT_BLOCK_WRITE (output , output_offset , output_acc [seq_idx ]);
1546
+ OUTPUT_BLOCK_WRITE (output , output_offset , TO_OUTPUT_TYPE ( output_acc [seq_idx ]) );
1543
1547
output_offset += output_pitch ;
1544
1548
}
1545
1549
} else {
1546
1550
unroll_for (uint seq_idx = 0 ; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE ; seq_idx ++ ) {
1547
1551
output_acc [seq_idx ] /= slm_exp_sum_prev [seq_idx ];
1548
- OUTPUT_BLOCK_WRITE (output , output_offset , output_acc [seq_idx ]);
1552
+ OUTPUT_BLOCK_WRITE (output , output_offset , TO_OUTPUT_TYPE ( output_acc [seq_idx ]) );
1549
1553
output_offset += output_pitch ;
1550
1554
}
1551
1555
}
0 commit comments