
Commit 84de26d

p-durandin authored and vshampor committed
[GPU] PA, rotation minor fixes
1 parent c8ab4eb commit 84de26d


3 files changed: +28 -41 lines changed


src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp

+22
@@ -214,6 +214,12 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
             if (desc->has_alibi) {
                 args.inputs.push_back(instance.alibi_memory_ptr());
             }
+
+            if (desc->has_rotated_blocks) {
+                args.inputs.push_back(instance.rotated_block_indices_memory_ptr());
+                args.inputs.push_back(instance.rotation_deltas_memory_ptr());
+                args.inputs.push_back(instance.rotation_trig_lut_memory_ptr());
+            }
         } else if (kernel_idx == 2 || kernel_idx == 3) {
             // Finalization kernel or mixed stage finalization kernel
             args.inputs = { instance.past_lens_memory_ptr() };
@@ -681,6 +687,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         if (has_alibi)
             inputs_number++;
 
+        const auto has_rotation = impl_param.input_layouts.size() == 16;
+        if (has_rotation)
+            inputs_number += 3;
+
         auto input_idx = 0;
         params.inputs.resize(inputs_number);
         params.inputs[input_idx++] = query_tensor;
@@ -699,6 +709,12 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         if (has_alibi)
             params.inputs[input_idx++] = alibi_tensor;
 
+        if (has_rotation) {
+            params.inputs[input_idx++] = input_tensors[13];
+            params.inputs[input_idx++] = input_tensors[14];
+            params.inputs[input_idx++] = input_tensors[15];
+        }
+
         if (has_scores_output) {
             params.outputs.resize(2);
             params.outputs[1] = convert_data_tensor(impl_param.get_output_layout(1));
@@ -736,6 +752,12 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         if (has_alibi)
             in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(11)});
 
+        if (has_rotation) {
+            in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(13)});
+            in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(14)});
+            in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(15)});
+        }
+
         if (has_scores_output)
             out_tensor_to_offset_map.insert({1, out_offsets_map.at(1)});

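Note: the C++ changes above wire three extra inputs through the primitive when cache rotation is active. The runtime arguments gain the rotated-block-indices, rotation-deltas, and rotation-trig-LUT memory pointers, and the kernel parameters pick up the corresponding tensors at offsets 13, 14, and 15. The sketch below restates that mapping in plain C++; only the 16-input total and the 13/14/15 offsets come from the diff, and the helper names are illustrative, not the plugin's API.

// Minimal sketch (not the plugin code): how this patch detects the rotation
// case and which input slots the rotation tensors occupy.
#include <cstddef>

// Offsets of the rotation tensors within PagedAttention's input list (from the diff).
constexpr std::size_t kRotatedBlockIndicesIdx = 13;
constexpr std::size_t kRotationDeltasIdx      = 14;
constexpr std::size_t kRotationTrigLutIdx     = 15;

// Rotation is inferred purely from the input count: the primitive receives
// 16 input layouts when the three rotation tensors are appended.
inline bool has_rotation(std::size_t num_input_layouts) {
    return num_input_layouts == 16;
}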
src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl

+4 -3
@@ -43,10 +43,11 @@ KERNEL(pa_sdpa_opt)(
 #if HAS_ALIBI
     const __global ALIBI_INPUT_TYPE* alibi_slopes,
 #endif
+
 #if HAS_ROTATED_BLOCKS
-    const __global INPUT8_TYPE* rotated_block_indices,
-    const __global INPUT9_TYPE* rotation_deltas,
-    const __global INPUT10_TYPE* rotation_trig_lut,
+    const __global INPUT7_TYPE* rotated_block_indices,
+    const __global INPUT8_TYPE* rotation_deltas,
+    const __global INPUT9_TYPE* rotation_trig_lut,
 #endif
     __global OUTPUT_TYPE* output,
 #if PAGED_ATTENTION_SCORES_OUTPUT

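Note: in pa_sdpa_opt.cl the rotation arguments were declared with the INPUT8/INPUT9/INPUT10 type macros, while the renumbering to INPUT7/INPUT8/INPUT9 aligns them with the positions these tensors actually occupy in the kernel's input list. A rough C++ illustration of that correspondence follows; only indices 7-9 are grounded in the diff, the struct and names are hypothetical.

// Hypothetical sketch of the index-to-argument correspondence implied by the
// renumbering above (the rest of the kernel's input list is not shown).
#include <cstddef>
#include <string_view>

struct RotationArg {
    std::size_t input_idx;   // index n selecting the INPUTn_TYPE macro
    std::string_view name;   // kernel argument it describes
};

constexpr RotationArg kRotationArgs[] = {
    {7, "rotated_block_indices"},
    {8, "rotation_deltas"},
    {9, "rotation_trig_lut"},
};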
src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl

+2 -38
@@ -1004,6 +1004,7 @@ KERNEL(sdpa_opt)(
     const uint partition_seq_len = min((uint)SOURCE_SEQ_LEN - start_partition_idx, (uint)SEQ_LEN_PARTITION_SIZE);
 #endif
 
+    MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZERO;
 #if IS_CAUSAL
     if (seq_len <= target_seq_idx) { // keep tril i.e. m >= n
 #endif
@@ -1037,11 +1038,7 @@
 #endif
 
     int seq_len_calc_size = min((int)(SOURCE_SEQ_LEN) - (int)seq_len, (int)SUBGROUP_SIZE);
-#if IS_CAUSAL
-    MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZERO;
-#else // !IS_CAUSAL
-    MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc;
-
+#if !IS_CAUSAL
     qk_acc = FUNC_CALL(load_attn_mask)(OPTIONAL_SHAPE_INFO_TENSOR
         b0_idx,
         b1_idx,
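Note: the two hunks above hoist the qk_acc declaration out of the causal/non-causal branches. The accumulator is now declared once and zero-initialized before the IS_CAUSAL block, and the non-causal path simply overwrites it with the attention-mask values. A minimal C++ analogue of that pattern, with stand-in names for the kernel macros and helpers (illustrative only):

#include <array>

constexpr bool kIsCausal  = false;  // stand-in for IS_CAUSAL
constexpr int  kBlockSize = 16;     // stand-in for TARGET_SEQ_LEN_BLOCK_SIZE

// Stand-in for FUNC_CALL(load_attn_mask)(...); the real kernel reads the mask here.
std::array<float, kBlockSize> load_attn_mask() {
    return {};
}

std::array<float, kBlockSize> init_qk_acc() {
    std::array<float, kBlockSize> qk_acc{};  // always zero-initialized up front
    if constexpr (!kIsCausal) {
        qk_acc = load_attn_mask();           // non-causal path starts from the mask
    }
    return qk_acc;
}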
@@ -1279,39 +1276,6 @@
         }
     }
 
-#if PAGED_ATTENTION_SCORES_OUTPUT
-    const uint subsequence_idx = gws_seq_indexes_correspondence[target_seq_dim];
-    const uint subsequence_end_pos = subsequence_begins[subsequence_idx + 1];
-    const uint block_start_pos = blocked_indexes_start[target_seq_dim];
-    const uint block_end_pos = blocked_indexes_end[target_seq_dim];
-
-    // PagedAttention is supposed to save only last "row" of the QK matrix multiplication,
-    // so save SEQ_LEN_PARTITION_SIZE elements for each partition
-    if (subsequence_end_pos == block_end_pos) {
-        const uint last_row_idx = block_end_pos - block_start_pos - 1;
-        if (sglid == last_row_idx) {
-            const uint partition_idx = start_partition_idx / SEQ_LEN_PARTITION_SIZE;
-
-            if (sgid == 0) {
-                const uint max_partitions_num = aligned_max_context_len / SEQ_LEN_PARTITION_SIZE;
-                const uint exp_sums_output_offset = subsequence_idx * NUM_HEADS * max_partitions_num +
-                                                    num_heads_dim * max_partitions_num +
-                                                    partition_idx;
-                exp_sums[exp_sums_output_offset] = exp_sum_new;
-                max_logits[exp_sums_output_offset] = qk_max_new;
-            }
-
-            const uint output_offset = subsequence_idx * NUM_HEADS * aligned_max_context_len +
-                                       num_heads_dim * aligned_max_context_len +
-                                       partition_idx * SEQ_LEN_PARTITION_SIZE + sgid * TARGET_SEQ_LEN_BLOCK_SIZE;
-            for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) {
-                softmax_results[output_offset + i] = qk_acc[i];
-            }
-
-        }
-    }
-#endif
-
     barrier(CLK_LOCAL_MEM_FENCE);
 }

0 commit comments

Comments
 (0)