Skip to content

Commit aa6ff61

Browse files
committed
Fix kernel arguments and adjust min seq_len
1 parent 38107ba commit aa6ff61

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
238238

239239
args.outputs = { instance.output_memory_ptr(0) };
240240
} else if (stage == Stage::PA_SDPA) {
241-
if (kernel_idx == 0 || kernel_idx == 1) {
241+
if (kernel_idx == 0 || kernel_idx == 1 || kernel_idx == 2) {
242242
// 2nd+ token calculation or mixed stage tokens calculation
243243
args.shape_info = instance.shape_info_memory_ptr();
244244

@@ -262,7 +262,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
262262
if (desc->has_alibi) {
263263
args.inputs.push_back(instance.alibi_memory_ptr());
264264
}
265-
} else if (kernel_idx == 2 || kernel_idx == 3) {
265+
} else if (kernel_idx == 3 || kernel_idx == 4) {
266266
// Finalization kernel or mixed stage finalization kernel
267267
args.inputs = { instance.past_lens_memory_ptr() };
268268

@@ -276,15 +276,15 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
276276
args.inputs.push_back(instance.rotation_deltas_memory_ptr());
277277
args.inputs.push_back(instance.rotation_trig_lut_memory_ptr());
278278
}
279-
} else if (kernel_idx == 4) {
279+
} else if (kernel_idx == 5) {
280280
// Output scores calculation kernel
281281
args.inputs = { instance.past_lens_memory_ptr(),
282282
instance.subsequence_begins_memory_ptr() };
283283
}
284284

285285
args.outputs = { instance.output_memory_ptr(0) };
286286

287-
if (kernel_idx == 4) {
287+
if (kernel_idx == 5) {
288288
args.outputs.push_back(instance.output_memory_ptr(1));
289289
}
290290
}

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,8 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
369369
const auto scores_calc_only = prim_params.stage == PagedAttentionStage::PREFILL && has_scores_output;
370370
const auto multi_tokens_mode = prim_params.stage == PagedAttentionStage::MIXED;
371371

372-
// Apply GQA optimization starting from a certain sequence length value
373-
const auto min_gqa_sequence_len = 8 * seq_len_partition_size;
372+
// Apply GQA optimization starting from a certain sequence length value (4K tokens)
373+
const auto min_gqa_sequence_len = 16 * seq_len_partition_size;
374374
// Apply GQA only if there is a single subsequence in the request,
375375
// as multiple subsequences might have significantly different lengths
376376
const auto max_subsequences_num = 1;

0 commit comments

Comments
 (0)