@@ -58,6 +58,33 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         KV_CACHE_ROTATE,
     };
 
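+    // Classify the current batch of subsequences: GENERATE when every
+    // subsequence contributes exactly one new token (query rows == past_lens
+    // entries), PREFILL when no subsequence has cached tokens, MIXED when
+    // some do, and UNKNOWN while the shapes are still dynamic.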
+    PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param) const {
+        const auto& query_shape = impl_param.get_input_layout(0).get_partial_shape();
+        const auto& past_lens_shape = impl_param.get_input_layout(5).get_partial_shape();
+
+        if (query_shape.is_static() && past_lens_shape.is_static()) {
+            if (query_shape[0].get_length() == past_lens_shape[0].get_length()) {
+                return PagedAttentionStage::GENERATE;
+            }
+
+            const auto past_lens_idx = 5;
+            const auto& memory_deps = impl_param.memory_deps;
+            const auto past_lens_mem = memory_deps.at(past_lens_idx);
+            mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, *impl_param.strm);
+
+            const auto past_lens_size = past_lens_mem_lock.size();
+            for (size_t i = 0; i < past_lens_size; i++) {
+                if (past_lens_mem_lock[i] != 0) {
+                    return PagedAttentionStage::MIXED;
+                }
+            }
+
+            return PagedAttentionStage::PREFILL;
+        }
+
+        return PagedAttentionStage::UNKNOWN;
+    }
+
     bool requires_update(primitive_inst& inst, const kernel_impl_params& impl_params) const override {
         const auto stage = get_paged_attention_stage(impl_params);
@@ -67,15 +94,6 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         return stage == PagedAttentionStage::MIXED;
     }
 
-    void update_inst_params(primitive_inst& inst) const override {
-        OPENVINO_ASSERT(inst.type() == paged_attention::type_id());
-        OPENVINO_ASSERT(inst.get_impl() == this);
-
-        auto& pa_inst = reinterpret_cast<paged_attention_inst&>(inst);
-        pa_inst.query_block_size = get_query_block_size(PagedAttentionStage::PREFILL);
-        pa_inst.use_micro_sdpa = use_micro_sdpa;
-    }
-
     size_t get_query_block_size(const PagedAttentionStage& stage) const {
         const auto default_block_size = 16;
@@ -292,6 +310,147 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         return lockable_ids;
     };
 
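+    // Host-side setup of the intermediate buffers consumed by the stage
+    // kernels: per-chunk block boundaries, chunk/token-to-subsequence
+    // mappings, and (optionally) per-subsequence offsets for scores output.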
+    void prepare_internal_buffers(paged_attention_inst& instance, const PagedAttentionStage& stage) {
+        const auto& desc = instance.get_impl_params()->typed_desc<paged_attention>();
+        const bool has_scores_output = desc->has_scores_output();
+
+        if ((stage == PagedAttentionStage::UNKNOWN) ||
+            (stage == PagedAttentionStage::GENERATE && !has_scores_output))
+            return;
+
+        auto& stream = instance.get_network().get_stream();
+        const auto past_lens_mem = instance.past_lens_memory_ptr();
+        const auto subsequence_begins_mem = instance.subsequence_begins_memory_ptr();
+        auto intermediates_memories = instance.get_intermediates_memories();
+        mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, stream);
+        mem_lock<int32_t, mem_lock_type::read> subsequence_begins_mem_lock(subsequence_begins_mem, stream);
+        std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> subsequence_offsets_lock = nullptr;
+
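+        // The scores output needs the offset of each subsequence within the
+        // concatenated scores buffer; lock that intermediate buffer for writing.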
+        if (has_scores_output) {
+            const size_t subsequence_offsets_idx = 4;
+
+            OPENVINO_ASSERT(intermediates_memories.size() > subsequence_offsets_idx,
+                            "[GPU] Unexpected number of intermediates buffers for Paged Attention for scores output calculation");
+
+            auto subsequence_offsets_mem = intermediates_memories[subsequence_offsets_idx];
+            subsequence_offsets_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(subsequence_offsets_mem, stream));
+        }
+
+        if (stage == PagedAttentionStage::GENERATE) {
+            // For the generate stage it's not necessary to configure any other intermediate
+            // buffers. Simply calculate the offsets and exit
+            size_t subsequence_offsets_acc = 0;
+            for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
+                const auto past_len = past_lens_mem_lock[i];
+                const auto seq_start = subsequence_begins_mem_lock[i];
+                const auto seq_end = subsequence_begins_mem_lock[i + 1];
+                const auto seq_length = seq_end - seq_start;
+
+                if (subsequence_offsets_lock) {
+                    subsequence_offsets_lock->operator[](i) = static_cast<int32_t>(subsequence_offsets_acc);
+                    subsequence_offsets_acc += seq_length + past_len;
+                }
+            }
+
+            return;
+        }
+
+        OPENVINO_ASSERT(intermediates_memories.size() >= 3, "Unexpected number of intermediates buffers for Paged Attention at prefill stage");
+
+        const auto blocks_indexes_start_idx = 0;
+        const auto blocks_indexes_end_idx = 1;
+        const auto blocked_gws_subseq_mapping_idx = 2;
+
+        auto blocks_indexes_start_mem = intermediates_memories[blocks_indexes_start_idx];
+        auto blocks_indexes_end_mem = intermediates_memories[blocks_indexes_end_idx];
+        auto blocked_gws_subseq_mapping_mem = intermediates_memories[blocked_gws_subseq_mapping_idx];
+
+        OPENVINO_ASSERT(subsequence_begins_mem->get_layout().data_type == data_types::i32);
+
+        mem_lock<int32_t, mem_lock_type::write> blocks_indexes_start_lock(blocks_indexes_start_mem, stream);
+        mem_lock<int32_t, mem_lock_type::write> blocks_indexes_end_lock(blocks_indexes_end_mem, stream);
+        mem_lock<int32_t, mem_lock_type::write> blocked_gws_subseq_mapping_mem_lock(blocked_gws_subseq_mapping_mem, stream);
+        std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> sequential_gws_subseq_mapping_lock = nullptr;
+        std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> micro_sdpa_block_starts_and_gws_mapping_lock = nullptr;
+
+        if (stage == PagedAttentionStage::MIXED) {
+            const size_t sequential_gws_subseq_mapping_idx = has_scores_output ? 8 : 6;
+
+            OPENVINO_ASSERT(intermediates_memories.size() > sequential_gws_subseq_mapping_idx,
+                            "[GPU] Unexpected number of intermediates buffers for Paged Attention for mixed stage");
+
+            auto sequential_gws_subseq_mapping_mem = intermediates_memories[sequential_gws_subseq_mapping_idx];
+            sequential_gws_subseq_mapping_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(sequential_gws_subseq_mapping_mem, stream));
+        }
+
+        if (stage == PagedAttentionStage::PREFILL && use_micro_sdpa) {
+            const auto memory_idx = intermediates_memories.size() - 1;
+
+            auto memory = intermediates_memories[memory_idx];
+            micro_sdpa_block_starts_and_gws_mapping_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(memory, stream));
+        }
+
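+        // Split every subsequence into chunks of up to paged_attention::block_size
+        // tokens and record each chunk's start/end positions and owning
+        // subsequence; the "gws" naming suggests one work-group per chunk.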
+        size_t index = 0;
+        size_t micro_sdpa_index = 0;
+        size_t subsequence_offsets_acc = 0;
+        size_t query_block_size = get_query_block_size(stage);
+        const auto pa_block_size = static_cast<int>(paged_attention::block_size);
+        for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
+            const auto past_len = past_lens_mem_lock[i];
+            const auto seq_start = subsequence_begins_mem_lock[i];
+            const auto seq_end = subsequence_begins_mem_lock[i + 1];
+            const auto seq_length = seq_end - seq_start;
+
+            int32_t j = 0;
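+            // A subsequence with cached tokens may end in a partially filled
+            // KV-cache block; emit a first chunk that only tops up the empty
+            // slots of that block so later chunks stay block-aligned.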
+            if (past_len != 0) {
+                auto block_start_pos = seq_start;
+                auto empty_slots = pa_block_size - (past_len % pa_block_size);
+                auto block_end_pos = seq_start + std::min(empty_slots, seq_length);
+
+                blocks_indexes_start_lock[index] = block_start_pos;
+                blocks_indexes_end_lock[index] = block_end_pos;
+                blocked_gws_subseq_mapping_mem_lock[index] = static_cast<int32_t>(i);
+
+                index++;
+
+                auto added_slots = block_end_pos - block_start_pos;
+                j += added_slots;
+            }
+
+            for (; j < seq_length; j += pa_block_size) {
+                auto block_start_pos = subsequence_begins_mem_lock[i] + j;
+                auto block_end_pos = std::min(block_start_pos + pa_block_size, seq_end);
+
+                blocks_indexes_start_lock[index] = block_start_pos;
+                blocks_indexes_end_lock[index] = block_end_pos;
+                blocked_gws_subseq_mapping_mem_lock[index] = static_cast<int32_t>(i);
+
+                index++;
+            }
+
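+            // The micro-SDPA path walks queries at query_block_size granularity
+            // instead, so it gets its own (block_start, subsequence) pairs in
+            // the last intermediate buffer.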
+            if (micro_sdpa_block_starts_and_gws_mapping_lock) {
+                const auto block_size = static_cast<int>(query_block_size);
+                for (int32_t j = 0; j < seq_length; j += block_size) {
+                    auto block_start_pos = subsequence_begins_mem_lock[i] + j;
+
+                    micro_sdpa_block_starts_and_gws_mapping_lock->operator[](micro_sdpa_index++) = block_start_pos;
+                    micro_sdpa_block_starts_and_gws_mapping_lock->operator[](micro_sdpa_index++) = static_cast<int32_t>(i);
+                }
+            }
+
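+            // Mixed mode also needs a per-token map from global query position
+            // back to its subsequence index.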
+            if (stage == PagedAttentionStage::MIXED) {
+                for (int32_t idx = seq_start; idx < seq_end; idx++) {
+                    sequential_gws_subseq_mapping_lock->operator[](idx) = static_cast<int32_t>(i);
+                }
+            }
+
+            if (subsequence_offsets_lock) {
+                subsequence_offsets_lock->operator[](i) = static_cast<int32_t>(subsequence_offsets_acc);
+                subsequence_offsets_acc += seq_length + past_len;
+            }
+        }
+    }
+
     void execute_stage(const std::vector<event::ptr>& events,
                        paged_attention_inst& instance,
                        std::vector<event::ptr>& all_events,
@@ -385,6 +544,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         const auto stage = get_paged_attention_stage(*instance.get_impl_params());
         const auto is_mixed_mode = stage == PagedAttentionStage::MIXED;
 
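+        // Fill the host-visible intermediate buffers before enqueuing this
+        // iteration's stage kernels.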
+        prepare_internal_buffers(instance, stage);
+
         std::vector<event::ptr> res_events;
         std::vector<event::ptr> dep_events = events;
         if (has_rotated_blocks && !_kernels_data[Stage::KV_CACHE_ROTATE].kernels[0].skip_execution) {