@@ -107,7 +107,7 @@ KERNEL(pa_sdpa_opt)(
107
107
#endif
108
108
109
109
// SLM for intermediate QK results
110
- __local OUTPUT_TYPE slm_qk_vals [SEQ_LEN_PARTITION_SIZE ];
110
+ __local SOFTMAX_ACCUMULATOR_TYPE slm_qk_vals [SEQ_LEN_PARTITION_SIZE ];
111
111
112
112
// SLM buffers for SoftMax calculation and qk_max/qk_sums results aggregation across all WGs
113
113
__local SOFTMAX_ACCUMULATOR_TYPE slm_qk_max_vals [SUBGROUPS_PER_WG ];
@@ -166,7 +166,7 @@ KERNEL(pa_sdpa_opt)(
166
166
#endif
167
167
const uint block_offset = block_indices [start_block_idx + block_num * SUBGROUPS_PER_WG ] * HEAD_SIZE * KV_HEADS_NUM * SUBGROUP_SIZE + head_idx * HEAD_SIZE * SUBGROUP_SIZE ;
168
168
169
- INPUT0_TYPE qk_acc = INPUT0_VAL_ZERO ;
169
+ SOFTMAX_ACCUMULATOR_TYPE qk_acc = SOFTMAX_ACCUMULATOR_VAL_ZERO ;
170
170
171
171
#define KEY_VEC_SIZE SUBGROUP_SIZE
172
172
unroll_for (uint qk_idx = 0 ; qk_idx < HEAD_SIZE / KEY_VEC_SIZE ; qk_idx ++ ) {
@@ -181,9 +181,9 @@ KERNEL(pa_sdpa_opt)(
181
181
182
182
unroll_for (uint i = 0 ; i < KEY_VEC_SIZE ; i ++ ) {
183
183
#if STORE_QUERY_TO_SLM
184
- qk_acc = mad (sub_group_broadcast (q_val , i ), k_vals [i ], qk_acc );
184
+ qk_acc = mad (TO_SOFTMAX_ACCUMULATOR_TYPE ( sub_group_broadcast (q_val , i )), TO_SOFTMAX_ACCUMULATOR_TYPE ( k_vals [i ]) , qk_acc );
185
185
#else
186
- qk_acc = mad (sub_group_broadcast (q_val [qk_idx ], i ), k_vals [i ], qk_acc );
186
+ qk_acc = mad (TO_SOFTMAX_ACCUMULATOR_TYPE ( sub_group_broadcast (q_val [qk_idx ], i )), TO_SOFTMAX_ACCUMULATOR_TYPE ( k_vals [i ]) , qk_acc );
187
187
#endif
188
188
}
189
189
}
@@ -196,7 +196,7 @@ KERNEL(pa_sdpa_opt)(
196
196
#endif
197
197
198
198
if (token_idx >= seq_len )
199
- qk_acc = INPUT0_VAL_MIN ;
199
+ qk_acc = SOFTMAX_ACCUMULATOR_VAL_MIN ;
200
200
201
201
qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC (qk_max , TO_SOFTMAX_ACCUMULATOR_TYPE (qk_acc ));
202
202
@@ -235,7 +235,7 @@ KERNEL(pa_sdpa_opt)(
235
235
if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE ) {
236
236
#endif
237
237
SOFTMAX_ACCUMULATOR_TYPE qk_new = native_exp (TO_SOFTMAX_ACCUMULATOR_TYPE (slm_qk_vals [local_data_idx ]) - qk_max );
238
- slm_qk_vals [local_data_idx ] = TO_OUTPUT_TYPE ( qk_new ) ;
238
+ slm_qk_vals [local_data_idx ] = qk_new ;
239
239
240
240
exp_sum += qk_new ;
241
241
}
@@ -266,7 +266,7 @@ KERNEL(pa_sdpa_opt)(
266
266
if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE ) {
267
267
#endif
268
268
SOFTMAX_ACCUMULATOR_TYPE qk_new = TO_SOFTMAX_ACCUMULATOR_TYPE (slm_qk_vals [local_data_idx ]) / exp_sum ;
269
- slm_qk_vals [local_data_idx ] = TO_OUTPUT_TYPE ( qk_new ) ;
269
+ slm_qk_vals [local_data_idx ] = qk_new ;
270
270
}
271
271
}
272
272
0 commit comments