15
15
#include " sdpa/pa_kv_cache_rotate_kernel_ref.h"
16
16
#include " sdpa/pa_kv_cache_update_kernel_ref.h"
17
17
#include " sdpa/pa_sdpa_kernel_opt.h"
18
+ #include " sdpa/sdpa_kernel_micro.h"
18
19
19
20
namespace cldnn {
20
21
namespace ocl {
@@ -66,10 +67,34 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
66
67
return stage == PagedAttentionStage::MIXED;
67
68
}
68
69
70
+ void update_inst_params (primitive_inst& inst) const override {
71
+ OPENVINO_ASSERT (inst.type () == paged_attention::type_id ());
72
+ OPENVINO_ASSERT (inst.get_impl () == this );
73
+
74
+ auto & pa_inst = reinterpret_cast <paged_attention_inst&>(inst);
75
+ pa_inst.query_block_size = get_query_block_size (PagedAttentionStage::PREFILL);
76
+ pa_inst.use_micro_sdpa = use_micro_sdpa;
77
+ }
78
+
79
+ size_t get_query_block_size (const PagedAttentionStage& stage) const {
80
+ const auto default_block_size = 16 ;
81
+
82
+ if (stage == PagedAttentionStage::PREFILL) {
83
+ #ifdef ENABLE_ONEDNN_FOR_GPU
84
+ if (use_micro_sdpa)
85
+ return kernel_selector::SDPAKernelMicro::GetTileQSize (_kernels_data[Stage::SDPA]);
86
+ #endif
87
+ return default_block_size;
88
+ } else {
89
+ return default_block_size;
90
+ }
91
+ }
92
+
69
93
void load (BinaryInputBuffer& ib) override {
70
94
parent::load (ib);
71
95
ib >> make_data (&has_scores_output, sizeof (bool ));
72
96
ib >> make_data (&has_rotated_blocks, sizeof (bool ));
97
+ ib >> make_data (&use_micro_sdpa, sizeof (bool ));
73
98
if (is_dynamic ()) {
74
99
auto & kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance ();
75
100
auto kv_cache_update_kernel_impl = kv_cache_update_kernel_selector.GetImplementation (_kernels_data[Stage::KV_CACHE_UPDATE].kernelName );
@@ -95,9 +120,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
95
120
parent::save (ob);
96
121
ob << make_data (&has_scores_output, sizeof (bool ));
97
122
ob << make_data (&has_rotated_blocks, sizeof (bool ));
123
+ ob << make_data (&use_micro_sdpa, sizeof (bool ));
98
124
}
99
125
100
- std::vector<layout> get_internal_buffer_layouts_impl () const override {
126
+ std::vector<kernel_selector::InternalBuffer> get_internal_buffers_desc () const {
101
127
/*
102
128
* Internal buffers allocation owners and users:
103
129
* +--------------------------------------+--------------------+--------------------+
@@ -117,6 +143,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
117
143
* +--------------------------------------+--------------------+--------------------+
118
144
* | PA_SDPA (mixed mode) + scores output | [3, 4, 5, 6, 7, 8] | |
119
145
* +--------------------------------------+--------------------+--------------------+
146
+ * | SDPA (1st token, micro-kernel) | [last (8/9)] | |
147
+ * +--------------------------------------+--------------------+--------------------+
120
148
*
121
149
* Description:
122
150
* 0, 1, 2 - Buffers used for proper blocks distribution for kv_cache_update and
@@ -129,24 +157,32 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
129
157
* Filled in PA/SDPA kernels.
130
158
* 8 - Optional buffer used for mixed PA execution mode, mapping gws idx to subsequence id.
131
159
* Filled in paged_attention_inst::on_execute() call.
160
+ * last - Defines the query block index for the subsequence currently being processed and maps
161
+ * gws index to subsequence idx. Values stored in pairs like:
162
+ * [block_idx0, subsequence_idx0, block_idx1, subsequence_idx0, ..., block_idx0, subsequence_idx1].
163
+ * Filled in paged_attention_inst::on_execute() call for sdpa-micro kernel only.
132
164
*/
133
165
134
- auto add_internal_buffers = [](std::vector<layout>& layouts, const kernel_selector::KernelData& kd) {
135
- if (kd.internalBufferSizes .empty ())
136
- return ;
137
-
138
- auto dtype = from_data_type (kd.internalBufferDataType );
139
- const auto bpp = data_type_traits::size_of (dtype);
140
- for (auto size : kd.internalBufferSizes ) {
141
- layout inbuf_layout = {dtype, format::bfyx, // simple linear format (flattern to x channel)
142
- {1 , 1 , 1 , (tensor::value_type)(size / bpp)}};
143
- layouts.push_back (inbuf_layout);
144
- }
166
+ auto add_internal_buffers = [](std::vector<kernel_selector::InternalBuffer>& internal_buffers,
167
+ const kernel_selector::KernelData& kd) {
168
+ internal_buffers.insert (internal_buffers.end (), kd.internalBuffers .begin (), kd.internalBuffers .end ());
145
169
};
146
170
171
+ std::vector<kernel_selector::InternalBuffer> internal_buffers;
172
+ add_internal_buffers (internal_buffers, _kernels_data[Stage::KV_CACHE_UPDATE]);
173
+ add_internal_buffers (internal_buffers, _kernels_data[Stage::PA_SDPA]);
174
+
175
+ if (use_micro_sdpa)
176
+ add_internal_buffers (internal_buffers, _kernels_data[Stage::SDPA]);
177
+
178
+ return internal_buffers;
179
+ }
180
+
181
+ std::vector<layout> get_internal_buffer_layouts_impl () const override {
147
182
std::vector<layout> layouts;
148
- add_internal_buffers (layouts, _kernels_data[Stage::KV_CACHE_UPDATE]);
149
- add_internal_buffers (layouts, _kernels_data[Stage::PA_SDPA]);
183
+
184
+ for (const auto & buffer : get_internal_buffers_desc ())
185
+ layouts.emplace_back (ov::PartialShape{static_cast <int64_t >(buffer.byte_count )}, ov::element::u8, format::bfyx);
150
186
151
187
return layouts;
152
188
}
@@ -245,12 +281,13 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
245
281
}
246
282
247
283
std::set<size_t > get_lockable_internal_buffers () const override {
248
- size_t mixed_mode_buffer = has_scores_output ? 8 : 6 ;
249
-
250
- std::set<size_t > lockable_ids = { 0 , 1 , 2 , /* SDPA and KV_CACHE_UPDATE indexes configuration */
251
- mixed_mode_buffer /* PA_SDPA multiple tokens mode */ };
252
- if (has_scores_output)
253
- lockable_ids.insert (4 /* Precalculated accumulated sequence length offsets for each subsequence */ );
284
+ std::set<size_t > lockable_ids;
285
+ const auto & internal_buffers = get_internal_buffers_desc ();
286
+ for (size_t i = 0 ; i < internal_buffers.size (); i++) {
287
+ if (internal_buffers[i].lockable ) {
288
+ lockable_ids.insert (i);
289
+ }
290
+ }
254
291
255
292
return lockable_ids;
256
293
};
@@ -271,12 +308,12 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
271
308
size_t internal_buffers_offset = 0 ;
272
309
size_t internal_buffers_count = 0 ;
273
310
if (stage == Stage::PA_SDPA) {
274
- internal_buffers_offset = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes .size ();
275
- internal_buffers_count = _kernels_data[Stage::PA_SDPA].internalBufferSizes .size ();
311
+ internal_buffers_offset = _kernels_data[Stage::KV_CACHE_UPDATE].internalBuffers .size ();
312
+ internal_buffers_count = _kernels_data[Stage::PA_SDPA].internalBuffers .size ();
276
313
} else if (stage == Stage::KV_CACHE_UPDATE) {
277
- internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes .size ();
314
+ internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBuffers .size ();
278
315
} else if (stage == Stage::SDPA) {
279
- internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes .size ();
316
+ internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBuffers .size ();
280
317
281
318
const auto desc = instance.get_node ().as <paged_attention>().get_primitive ();
282
319
if (desc->has_scores_output ()) {
@@ -304,6 +341,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
304
341
intermediate_memories.begin () + internal_buffers_offset,
305
342
intermediate_memories.begin () + internal_buffers_offset + internal_buffers_count);
306
343
344
+ if (use_micro_sdpa && stage == Stage::SDPA) {
345
+ args.intermediates .push_back (intermediate_memories.back ());
346
+ }
347
+
307
348
GPU_DEBUG_TRACE_DETAIL << " Execute stage=" << stage << " kernel=" << kd_idx << " " << _kernels_data[stage].kernelName << " start_offset="
308
349
<< internal_buffers_offset << " count=" << internal_buffers_count << " \n " ;
309
350
@@ -581,7 +622,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
581
622
static sdpa_kernel_params_t get_sdpa_kernel_params (const kernel_impl_params& impl_param,
582
623
const PagedAttentionStage& stage,
583
624
const kernel_selector::MultiDataTensor& input_tensors,
584
- bool is_dynamic = false ) {
625
+ int64_t query_block_size,
626
+ bool is_dynamic) {
585
627
const auto desc = impl_param.typed_desc <paged_attention>();
586
628
auto params = get_default_params<sdpa_kernel_params_t >(impl_param, is_dynamic);
587
629
@@ -623,6 +665,12 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
623
665
624
666
params.conf = get_sdpa_configuration (impl_param, is_dynamic);
625
667
668
+ const std::vector<int64_t > default_order = {0 , 1 , 2 , 3 };
669
+ params.input0_order = default_order;
670
+ params.input1_order = default_order;
671
+ params.input2_order = default_order;
672
+ params.output_order = default_order;
673
+
626
674
const auto & in_offsets_map = impl_param.in_port_to_shape_info_offset ;
627
675
const auto & out_offsets_map = impl_param.out_port_to_shape_info_offset ;
628
676
std::map<size_t , size_t > in_tensor_to_offset_map = {
@@ -643,7 +691,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
643
691
in_tensor_to_offset_map.insert ({input_idx++, in_offsets_map.at (11 )});
644
692
645
693
if ((stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED) && !is_dynamic)
646
- params.conf .paged_attention_aligned_seq_len = get_aligned_seq_len (impl_param, stage);
694
+ params.conf .paged_attention_aligned_seq_len = get_aligned_seq_len (impl_param, stage, query_block_size );
647
695
648
696
if (has_scores_output)
649
697
out_tensor_to_offset_map.insert ({1 , out_offsets_map.at (1 )});
@@ -760,7 +808,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
760
808
(_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func )(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]);
761
809
762
810
if (stage == PagedAttentionStage::PREFILL) {
763
- auto sdpa_kernel_params = get_sdpa_kernel_params (impl_param, stage, input_tensors, impl_param.is_dynamic ());
811
+ auto sdpa_kernel_params = get_sdpa_kernel_params (impl_param, stage, input_tensors, get_query_block_size (stage), impl_param.is_dynamic ());
764
812
(_kernels_data[Stage::SDPA].update_dispatch_data_func )(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
765
813
}
766
814
@@ -782,8 +830,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
782
830
auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params (impl_param, stage, input_tensors, impl_param.is_dynamic ());
783
831
auto & kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance ();
784
832
kernels_data.push_back (kv_cache_update_kernel_selector.get_best_kernel (kv_cache_update_kernel_params));
785
-
786
- auto sdpa_kernel_params = get_sdpa_kernel_params (impl_param, stage, input_tensors, impl_param.is_dynamic ());
833
+ auto sdpa_kernel_params = get_sdpa_kernel_params (impl_param, stage, input_tensors, 0 , impl_param.is_dynamic ());
787
834
auto & sdpa_kernel_selector = sdpa_kernel_selector_t::Instance ();
788
835
kernels_data.push_back (sdpa_kernel_selector.get_best_kernel (sdpa_kernel_params));
789
836
@@ -801,12 +848,17 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
801
848
impl->has_scores_output = desc->has_scores_output ();
802
849
impl->has_rotated_blocks = desc->has_rotated_blocks ;
803
850
851
+ if (!kernels_data[Stage::SDPA].kernels [0 ].micro_kernels .empty ()) {
852
+ impl->use_micro_sdpa = true ;
853
+ }
854
+
804
855
return impl;
805
856
}
806
857
807
858
private:
808
859
bool has_scores_output = false ;
809
860
bool has_rotated_blocks = false ;
861
+ bool use_micro_sdpa = false ;
810
862
};
811
863
812
864
namespace detail {
0 commit comments