@@ -1004,6 +1004,7 @@ KERNEL(sdpa_opt)(
1004
1004
const uint partition_seq_len = min ((uint )SOURCE_SEQ_LEN - start_partition_idx , (uint )SEQ_LEN_PARTITION_SIZE );
1005
1005
#endif
1006
1006
1007
+ MAKE_VECTOR_TYPE (INPUT0_TYPE , TARGET_SEQ_LEN_BLOCK_SIZE ) qk_acc = INPUT0_VAL_ZERO ;
1007
1008
#if IS_CAUSAL
1008
1009
if (seq_len <= target_seq_idx ) { // keep tril i.e. m >= n
1009
1010
#endif
@@ -1037,11 +1038,7 @@ KERNEL(sdpa_opt)(
1037
1038
#endif
1038
1039
1039
1040
int seq_len_calc_size = min ((int )(SOURCE_SEQ_LEN ) - (int )seq_len , (int )SUBGROUP_SIZE );
1040
- #if IS_CAUSAL
1041
- MAKE_VECTOR_TYPE (INPUT0_TYPE , TARGET_SEQ_LEN_BLOCK_SIZE ) qk_acc = INPUT0_VAL_ZERO ;
1042
- #else // !IS_CAUSAL
1043
- MAKE_VECTOR_TYPE (INPUT0_TYPE , TARGET_SEQ_LEN_BLOCK_SIZE ) qk_acc ;
1044
-
1041
+ #if !IS_CAUSAL
1045
1042
qk_acc = FUNC_CALL (load_attn_mask )(OPTIONAL_SHAPE_INFO_TENSOR
1046
1043
b0_idx ,
1047
1044
b1_idx ,
@@ -1279,39 +1276,6 @@ KERNEL(sdpa_opt)(
1279
1276
}
1280
1277
}
1281
1278
1282
- #if PAGED_ATTENTION_SCORES_OUTPUT
1283
- const uint subsequence_idx = gws_seq_indexes_correspondence [target_seq_dim ];
1284
- const uint subsequence_end_pos = subsequence_begins [subsequence_idx + 1 ];
1285
- const uint block_start_pos = blocked_indexes_start [target_seq_dim ];
1286
- const uint block_end_pos = blocked_indexes_end [target_seq_dim ];
1287
-
1288
- // PagedAttention is supposed to save only last "row" of the QK matrix multiplication,
1289
- // so save SEQ_LEN_PARTITION_SIZE elements for each partition
1290
- if (subsequence_end_pos == block_end_pos ) {
1291
- const uint last_row_idx = block_end_pos - block_start_pos - 1 ;
1292
- if (sglid == last_row_idx ) {
1293
- const uint partition_idx = start_partition_idx / SEQ_LEN_PARTITION_SIZE ;
1294
-
1295
- if (sgid == 0 ) {
1296
- const uint max_partitions_num = aligned_max_context_len / SEQ_LEN_PARTITION_SIZE ;
1297
- const uint exp_sums_output_offset = subsequence_idx * NUM_HEADS * max_partitions_num +
1298
- num_heads_dim * max_partitions_num +
1299
- partition_idx ;
1300
- exp_sums [exp_sums_output_offset ] = exp_sum_new ;
1301
- max_logits [exp_sums_output_offset ] = qk_max_new ;
1302
- }
1303
-
1304
- const uint output_offset = subsequence_idx * NUM_HEADS * aligned_max_context_len +
1305
- num_heads_dim * aligned_max_context_len +
1306
- partition_idx * SEQ_LEN_PARTITION_SIZE + sgid * TARGET_SEQ_LEN_BLOCK_SIZE ;
1307
- for (uint i = 0 ; i < TARGET_SEQ_LEN_BLOCK_SIZE ; i ++ ) {
1308
- softmax_results [output_offset + i ] = qk_acc [i ];
1309
- }
1310
-
1311
- }
1312
- }
1313
- #endif
1314
-
1315
1279
barrier (CLK_LOCAL_MEM_FENCE );
1316
1280
}
1317
1281
0 commit comments