15
15
#include " sdpa/pa_kv_cache_rotate_kernel_ref.h"
16
16
#include " sdpa/pa_kv_cache_update_kernel_ref.h"
17
17
#include " sdpa/pa_sdpa_kernel_opt.h"
18
+ #include " sdpa/sdpa_kernel_micro.h"
18
19
19
20
namespace cldnn {
20
21
namespace ocl {
@@ -66,10 +67,31 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
66
67
return stage == PagedAttentionStage::MIXED;
67
68
}
68
69
70
+ void update_inst_params (primitive_inst& inst) const override {
71
+ OPENVINO_ASSERT (inst.type () == paged_attention::type_id ());
72
+ OPENVINO_ASSERT (inst.get_impl () == this );
73
+
74
+ auto & pa_inst = reinterpret_cast <paged_attention_inst&>(inst);
75
+ pa_inst.query_block_size = get_query_block_size (PagedAttentionStage::PREFILL);
76
+ pa_inst.use_micro_sdpa = use_micro_sdpa;
77
+ }
78
+
79
+ size_t get_query_block_size (const PagedAttentionStage& stage) const {
80
+ const auto default_block_size = 16 ;
81
+
82
+ if (stage == PagedAttentionStage::PREFILL) {
83
+ return use_micro_sdpa ? kernel_selector::SDPAKernelMicro::GetTileQSize (_kernels_data[Stage::SDPA])
84
+ : default_block_size;
85
+ } else {
86
+ return default_block_size;
87
+ }
88
+ }
89
+
69
90
void load (BinaryInputBuffer& ib) override {
70
91
parent::load (ib);
71
92
ib >> make_data (&has_scores_output, sizeof (bool ));
72
93
ib >> make_data (&has_rotated_blocks, sizeof (bool ));
94
+ ib >> make_data (&use_micro_sdpa, sizeof (bool ));
73
95
if (is_dynamic ()) {
74
96
auto & kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance ();
75
97
auto kv_cache_update_kernel_impl = kv_cache_update_kernel_selector.GetImplementation (_kernels_data[Stage::KV_CACHE_UPDATE].kernelName );
@@ -95,9 +117,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
95
117
parent::save (ob);
96
118
ob << make_data (&has_scores_output, sizeof (bool ));
97
119
ob << make_data (&has_rotated_blocks, sizeof (bool ));
120
+ ob << make_data (&use_micro_sdpa, sizeof (bool ));
98
121
}
99
122
100
- std::vector<layout> get_internal_buffer_layouts_impl () const override {
123
+ std::vector<kernel_selector::InternalBuffer> get_internal_buffers_desc () const {
101
124
/*
102
125
* Internal buffers allocation owners and users:
103
126
* +--------------------------------------+--------------------+--------------------+
@@ -117,6 +140,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
117
140
* +--------------------------------------+--------------------+--------------------+
118
141
* | PA_SDPA (mixed mode) + scores output | [3, 4, 5, 6, 7, 8] | |
119
142
* +--------------------------------------+--------------------+--------------------+
143
+ * | SDPA (1st token, micro-kernel) | [last (8/9)] | |
144
+ * +--------------------------------------+--------------------+--------------------+
120
145
*
121
146
* Description:
122
147
* 0, 1, 2 - Buffers used for proper blocks distribution for kv_cache_update and
@@ -129,24 +154,32 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
129
154
* Filled in PA/SDPA kernels.
130
155
* 8 - Optional buffer used for mixed PA execution mode, mapping gws idx to subsequence id.
131
156
* Filled in paged_attention_inst::on_execute() call.
157
+ * last - Used for defining query block index for the currently processing subsequence and mapping
158
+ * gws index to subsequence idx. Values stored in pairs like:
159
+ * [block_idx0, subsequence_idx0, block_idx1, subsequence_idx0, ..., block_idx0, subsequence_idx1].
160
+ * Filled in paged_attention_inst::on_execute() call for sdpa-micro kernel only.
132
161
*/
133
162
134
- auto add_internal_buffers = [](std::vector<layout>& layouts, const kernel_selector::KernelData& kd) {
135
- if (kd.internalBufferSizes .empty ())
136
- return ;
137
-
138
- auto dtype = from_data_type (kd.internalBufferDataType );
139
- const auto bpp = data_type_traits::size_of (dtype);
140
- for (auto size : kd.internalBufferSizes ) {
141
- layout inbuf_layout = {dtype, format::bfyx, // simple linear format (flattern to x channel)
142
- {1 , 1 , 1 , (tensor::value_type)(size / bpp)}};
143
- layouts.push_back (inbuf_layout);
144
- }
163
+ auto add_internal_buffers = [](std::vector<kernel_selector::InternalBuffer>& internal_buffers,
164
+ const kernel_selector::KernelData& kd) {
165
+ internal_buffers.insert (internal_buffers.end (), kd.internalBuffers .begin (), kd.internalBuffers .end ());
145
166
};
146
167
168
+ std::vector<kernel_selector::InternalBuffer> internal_buffers;
169
+ add_internal_buffers (internal_buffers, _kernels_data[Stage::KV_CACHE_UPDATE]);
170
+ add_internal_buffers (internal_buffers, _kernels_data[Stage::PA_SDPA]);
171
+
172
+ if (use_micro_sdpa)
173
+ add_internal_buffers (internal_buffers, _kernels_data[Stage::SDPA]);
174
+
175
+ return internal_buffers;
176
+ }
177
+
178
+ std::vector<layout> get_internal_buffer_layouts_impl () const override {
147
179
std::vector<layout> layouts;
148
- add_internal_buffers (layouts, _kernels_data[Stage::KV_CACHE_UPDATE]);
149
- add_internal_buffers (layouts, _kernels_data[Stage::PA_SDPA]);
180
+
181
+ for (const auto & buffer : get_internal_buffers_desc ())
182
+ layouts.emplace_back (ov::PartialShape{static_cast <int64_t >(buffer.byte_count )}, ov::element::u8, format::bfyx);
150
183
151
184
return layouts;
152
185
}
@@ -245,12 +278,13 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
245
278
}
246
279
247
280
std::set<size_t > get_lockable_internal_buffers () const override {
248
- size_t mixed_mode_buffer = has_scores_output ? 8 : 6 ;
249
-
250
- std::set<size_t > lockable_ids = { 0 , 1 , 2 , /* SDPA and KV_CACHE_UPDATE indexes configuration */
251
- mixed_mode_buffer /* PA_SDPA multiple tokens mode */ };
252
- if (has_scores_output)
253
- lockable_ids.insert (4 /* Precalculated accumulated sequence length offsets for each subsequence */ );
281
+ std::set<size_t > lockable_ids;
282
+ const auto & internal_buffers = get_internal_buffers_desc ();
283
+ for (size_t i = 0 ; i < internal_buffers.size (); i++) {
284
+ if (internal_buffers[i].lockable ) {
285
+ lockable_ids.insert (i);
286
+ }
287
+ }
254
288
255
289
return lockable_ids;
256
290
};
@@ -271,12 +305,12 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
271
305
size_t internal_buffers_offset = 0 ;
272
306
size_t internal_buffers_count = 0 ;
273
307
if (stage == Stage::PA_SDPA) {
274
- internal_buffers_offset = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes .size ();
275
- internal_buffers_count = _kernels_data[Stage::PA_SDPA].internalBufferSizes .size ();
308
+ internal_buffers_offset = _kernels_data[Stage::KV_CACHE_UPDATE].internalBuffers .size ();
309
+ internal_buffers_count = _kernels_data[Stage::PA_SDPA].internalBuffers .size ();
276
310
} else if (stage == Stage::KV_CACHE_UPDATE) {
277
- internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes .size ();
311
+ internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBuffers .size ();
278
312
} else if (stage == Stage::SDPA) {
279
- internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes .size ();
313
+ internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBuffers .size ();
280
314
281
315
const auto desc = instance.get_node ().as <paged_attention>().get_primitive ();
282
316
if (desc->has_scores_output ()) {
@@ -304,6 +338,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
304
338
intermediate_memories.begin () + internal_buffers_offset,
305
339
intermediate_memories.begin () + internal_buffers_offset + internal_buffers_count);
306
340
341
+ if (use_micro_sdpa && stage == Stage::SDPA) {
342
+ args.intermediates .push_back (intermediate_memories.back ());
343
+ }
344
+
307
345
GPU_DEBUG_TRACE_DETAIL << " Execute stage=" << stage << " kernel=" << kd_idx << " " << _kernels_data[stage].kernelName << " start_offset="
308
346
<< internal_buffers_offset << " count=" << internal_buffers_count << " \n " ;
309
347
@@ -581,7 +619,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
581
619
static sdpa_kernel_params_t get_sdpa_kernel_params (const kernel_impl_params& impl_param,
582
620
const PagedAttentionStage& stage,
583
621
const kernel_selector::MultiDataTensor& input_tensors,
584
- bool is_dynamic = false ) {
622
+ int64_t query_block_size,
623
+ bool is_dynamic) {
585
624
const auto desc = impl_param.typed_desc <paged_attention>();
586
625
auto params = get_default_params<sdpa_kernel_params_t >(impl_param, is_dynamic);
587
626
@@ -623,6 +662,12 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
623
662
624
663
params.conf = get_sdpa_configuration (impl_param, is_dynamic);
625
664
665
+ const std::vector<int64_t > default_order = {0 , 1 , 2 , 3 };
666
+ params.input0_order = default_order;
667
+ params.input1_order = default_order;
668
+ params.input2_order = default_order;
669
+ params.output_order = default_order;
670
+
626
671
const auto & in_offsets_map = impl_param.in_port_to_shape_info_offset ;
627
672
const auto & out_offsets_map = impl_param.out_port_to_shape_info_offset ;
628
673
std::map<size_t , size_t > in_tensor_to_offset_map = {
@@ -643,7 +688,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
643
688
in_tensor_to_offset_map.insert ({input_idx++, in_offsets_map.at (11 )});
644
689
645
690
if ((stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED) && !is_dynamic)
646
- params.conf .paged_attention_aligned_seq_len = get_aligned_seq_len (impl_param, stage);
691
+ params.conf .paged_attention_aligned_seq_len = get_aligned_seq_len (impl_param, stage, query_block_size );
647
692
648
693
if (has_scores_output)
649
694
out_tensor_to_offset_map.insert ({1 , out_offsets_map.at (1 )});
@@ -760,7 +805,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
760
805
(_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func )(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]);
761
806
762
807
if (stage == PagedAttentionStage::PREFILL) {
763
- auto sdpa_kernel_params = get_sdpa_kernel_params (impl_param, stage, input_tensors, impl_param.is_dynamic ());
808
+ auto sdpa_kernel_params = get_sdpa_kernel_params (impl_param, stage, input_tensors, get_query_block_size (stage), impl_param.is_dynamic ());
764
809
(_kernels_data[Stage::SDPA].update_dispatch_data_func )(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
765
810
}
766
811
@@ -782,8 +827,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
782
827
auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params (impl_param, stage, input_tensors, impl_param.is_dynamic ());
783
828
auto & kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance ();
784
829
kernels_data.push_back (kv_cache_update_kernel_selector.get_best_kernel (kv_cache_update_kernel_params));
785
-
786
- auto sdpa_kernel_params = get_sdpa_kernel_params (impl_param, stage, input_tensors, impl_param.is_dynamic ());
830
+ auto sdpa_kernel_params = get_sdpa_kernel_params (impl_param, stage, input_tensors, 0 , impl_param.is_dynamic ());
787
831
auto & sdpa_kernel_selector = sdpa_kernel_selector_t::Instance ();
788
832
kernels_data.push_back (sdpa_kernel_selector.get_best_kernel (sdpa_kernel_params));
789
833
@@ -801,12 +845,18 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
801
845
impl->has_scores_output = desc->has_scores_output ();
802
846
impl->has_rotated_blocks = desc->has_rotated_blocks ;
803
847
848
+ if (!kernels_data[Stage::SDPA].kernels [0 ].micro_kernels .empty ()) {
849
+ std::cout << " Micro SDPA is chosen! tile_q_size = " << kernel_selector::SDPAKernelMicro::GetTileQSize (kernels_data[Stage::SDPA]) << " \n " ;
850
+ impl->use_micro_sdpa = true ;
851
+ }
852
+
804
853
return impl;
805
854
}
806
855
807
856
private:
808
857
bool has_scores_output = false ;
809
858
bool has_rotated_blocks = false ;
859
+ bool use_micro_sdpa = false ;
810
860
};
811
861
812
862
namespace detail {
0 commit comments