@@ -107,7 +107,7 @@ KERNEL(pa_sdpa_opt)(
107
107
#endif
108
108
109
109
// SLM for intermediate QK results
110
- __local OUTPUT_TYPE slm_qk_vals [SEQ_LEN_PARTITION_SIZE ];
110
+ __local SOFTMAX_ACCUMULATOR_TYPE slm_qk_vals [SEQ_LEN_PARTITION_SIZE ];
111
111
112
112
// SLM buffers for SoftMax calculation and qk_max/qk_sums results aggregation across all subgroups of the WG
113
113
__local SOFTMAX_ACCUMULATOR_TYPE slm_qk_max_vals [SUBGROUPS_PER_WG ];
@@ -168,7 +168,7 @@ KERNEL(pa_sdpa_opt)(
168
168
#endif
169
169
const uint block_offset = block_indices [start_block_idx + block_num * SUBGROUPS_PER_WG ] * HEAD_SIZE * KV_HEADS_NUM * SUBGROUP_SIZE + head_idx * HEAD_SIZE * SUBGROUP_SIZE ;
170
170
171
- INPUT0_TYPE qk_acc = INPUT0_VAL_ZERO ;
171
+ SOFTMAX_ACCUMULATOR_TYPE qk_acc = SOFTMAX_ACCUMULATOR_VAL_ZERO ;
172
172
173
173
#define KEY_VEC_SIZE SUBGROUP_SIZE
174
174
unroll_for (uint qk_idx = 0 ; qk_idx < HEAD_SIZE / KEY_VEC_SIZE ; qk_idx ++ ) {
@@ -183,9 +183,9 @@ KERNEL(pa_sdpa_opt)(
183
183
184
184
unroll_for (uint i = 0 ; i < KEY_VEC_SIZE ; i ++ ) {
185
185
#if STORE_QUERY_TO_SLM
186
- qk_acc = mad (sub_group_broadcast (q_val , i ), k_vals [i ], qk_acc );
186
+ qk_acc = mad (TO_SOFTMAX_ACCUMULATOR_TYPE ( sub_group_broadcast (q_val , i )), TO_SOFTMAX_ACCUMULATOR_TYPE ( k_vals [i ]) , qk_acc );
187
187
#else
188
- qk_acc = mad (sub_group_broadcast (q_val [qk_idx ], i ), k_vals [i ], qk_acc );
188
+ qk_acc = mad (TO_SOFTMAX_ACCUMULATOR_TYPE ( sub_group_broadcast (q_val [qk_idx ], i )), TO_SOFTMAX_ACCUMULATOR_TYPE ( k_vals [i ]) , qk_acc );
189
189
#endif
190
190
}
191
191
}
@@ -198,7 +198,7 @@ KERNEL(pa_sdpa_opt)(
198
198
#endif
199
199
200
200
if (token_idx >= seq_len )
201
- qk_acc = INPUT0_VAL_MIN ;
201
+ qk_acc = SOFTMAX_ACCUMULATOR_VAL_MIN ;
202
202
203
203
qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC (qk_max , TO_SOFTMAX_ACCUMULATOR_TYPE (qk_acc ));
204
204
@@ -237,7 +237,7 @@ KERNEL(pa_sdpa_opt)(
237
237
if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE ) {
238
238
#endif
239
239
SOFTMAX_ACCUMULATOR_TYPE qk_new = native_exp (TO_SOFTMAX_ACCUMULATOR_TYPE (slm_qk_vals [local_data_idx ]) - qk_max );
240
- slm_qk_vals [local_data_idx ] = TO_OUTPUT_TYPE ( qk_new ) ;
240
+ slm_qk_vals [local_data_idx ] = qk_new ;
241
241
242
242
exp_sum += qk_new ;
243
243
}
@@ -268,7 +268,7 @@ KERNEL(pa_sdpa_opt)(
268
268
if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE ) {
269
269
#endif
270
270
SOFTMAX_ACCUMULATOR_TYPE qk_new = TO_SOFTMAX_ACCUMULATOR_TYPE (slm_qk_vals [local_data_idx ]) / exp_sum ;
271
- slm_qk_vals [local_data_idx ] = TO_OUTPUT_TYPE ( qk_new ) ;
271
+ slm_qk_vals [local_data_idx ] = qk_new ;
272
272
}
273
273
}
274
274
0 commit comments