namespace cldnn {
namespace ocl {

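+// Debug-only stream helper: prints a std::set<size_t> as "[ 0 1 2 ]".
+// Used by the commented-out buffer-index traces further down.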
+inline ::std::ostream& operator<<(::std::ostream& os, const std::set<size_t>& vals) {
+    os << "[ ";
+    for (const auto& val : vals) {
+        os << val << " ";
+    }
+    os << "]";
+
+    return os;
+}
+
struct paged_attention_impl : multi_stage_primitive<paged_attention> {
    using parent = multi_stage_primitive<paged_attention>;
    using parent::parent;
@@ -72,25 +82,29 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
        OPENVINO_ASSERT(inst.get_impl() == this);

        auto& pa_inst = reinterpret_cast<paged_attention_inst&>(inst);
-        if (is_micro_kernel_used) {
+        if (use_micro_sdpa) {
            auto tile_q_size = get_target_seq_len_block_size(PagedAttentionStage::PREFILL);
            pa_inst.tile_q_size = tile_q_size;
-            std::cout << "update_inst_params: from micro-sdpa tile_q_size = " << tile_q_size << "\n";
+            pa_inst.use_micro_sdpa = true;
+            // std::cout << "update_inst_params: from micro-sdpa tile_q_size = " << tile_q_size << "\n";
        } else {
            pa_inst.tile_q_size = get_target_seq_len_block_size(PagedAttentionStage::PREFILL);
-            std::cout << "update_inst_params: sdpa_opt tile_q_size = " << get_target_seq_len_block_size(PagedAttentionStage::PREFILL) << "\n";
+            pa_inst.use_micro_sdpa = false;
+            // std::cout << "update_inst_params: sdpa_opt tile_q_size = " << get_target_seq_len_block_size(PagedAttentionStage::PREFILL) << "\n";
        }
    }

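+    // Prefill tiles the target sequence length: with micro-SDPA the tile size
+    // comes from the compiled kernel; otherwise (and for every other stage)
+    // the fixed default block size of 16 is used.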
    size_t get_target_seq_len_block_size(const PagedAttentionStage& stage) const {
+        const auto default_block_size = 16;
+
        if (stage == PagedAttentionStage::PREFILL) {
-            if (is_micro_kernel_used) {
+            if (use_micro_sdpa) {
                return kernel_selector::SDPAKernelMicro::GetTileQSize(_kernels_data[Stage::SDPA]);
            } else {
-                return 16;
+                return default_block_size;
            }
        } else {
-            return 16;
+            return default_block_size;
        }
    }

@@ -125,7 +139,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
        ob << make_data(&has_rotated_blocks, sizeof(bool));
    }

-    std::vector<layout> get_internal_buffer_layouts_impl() const override {
+    std::vector<kernel_selector::InternalBuffer> get_internal_buffers_desc() const {
        /*
         * Internal buffers allocation owners and users:
         * +--------------------------------------+--------------------+--------------------+
@@ -145,6 +159,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         * +--------------------------------------+--------------------+--------------------+
         * | PA_SDPA (mixed mode) + scores output | [3, 4, 5, 6, 7, 8] |                    |
         * +--------------------------------------+--------------------+--------------------+
+        * | SDPA (1st token, micro-kernel)       | [last(8/9)]        | [0, 1, 2]          |
+        * +--------------------------------------+--------------------+--------------------+
         *
         * Description:
         * 0, 1, 2 - Buffers used for proper blocks distribution for kv_cache_update and
@@ -157,24 +173,36 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         *     Filled in PA/SDPA kernels.
         * 8 - Optional buffer used for mixed PA execution mode, mapping gws idx to subsequence id.
         *     Filled in paged_attention_inst::on_execute() call.
+        * last - Optional buffer owned by the micro-kernel SDPA (1st token) stage;
+        *        allocated only when micro-SDPA is used (index 8 or 9, depending
+        *        on the scores output configuration).
         */

-        auto add_internal_buffers = [](std::vector<layout>& layouts, const kernel_selector::KernelData& kd) {
-            if (kd.internalBufferSizes.empty())
-                return;
-
-            auto dtype = from_data_type(kd.internalBufferDataType);
-            const auto bpp = data_type_traits::size_of(dtype);
-            for (auto size : kd.internalBufferSizes) {
-                layout inbuf_layout = {dtype, format::bfyx, // simple linear format (flattern to x channel)
-                                       {1, 1, 1, (tensor::value_type)(size / bpp)}};
-                layouts.push_back(inbuf_layout);
-            }
+        auto add_internal_buffers = [](std::vector<kernel_selector::InternalBuffer>& internal_buffers,
+                                       const kernel_selector::KernelData& kd) {
+            internal_buffers.insert(internal_buffers.end(), kd.internalBuffers.begin(), kd.internalBuffers.end());
        };

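+        // Order matters here: KV_CACHE_UPDATE buffers come first, then PA_SDPA,
+        // and the micro-SDPA buffers (if any) are appended last, which is what
+        // lets execute_stage() fetch them via intermediate_memories.back().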
+        std::vector<kernel_selector::InternalBuffer> internal_buffers;
+        // size_t count = 0;
+        add_internal_buffers(internal_buffers, _kernels_data[Stage::KV_CACHE_UPDATE]);
+        // std::cout << "Stage::KV_CACHE_UPDATE added: " << internal_buffers.size() - count << "\n";
+        // count = internal_buffers.size();
+        add_internal_buffers(internal_buffers, _kernels_data[Stage::PA_SDPA]);
+        // std::cout << "Stage::PA_SDPA added: " << internal_buffers.size() - count << "\n";
+        // count = internal_buffers.size();
+
+        if (use_micro_sdpa) {
+            add_internal_buffers(internal_buffers, _kernels_data[Stage::SDPA]);
+            // std::cout << "Stage::SDPA added: " << internal_buffers.size() - count << "\n";
+        }
+
+        return internal_buffers;
+    }
+
+    std::vector<layout> get_internal_buffer_layouts_impl() const override {
        std::vector<layout> layouts;
-        add_internal_buffers(layouts, _kernels_data[Stage::KV_CACHE_UPDATE]);
-        add_internal_buffers(layouts, _kernels_data[Stage::PA_SDPA]);
+
+        for (const auto& buffer : get_internal_buffers_desc())
+            layouts.emplace_back(ov::PartialShape{static_cast<int64_t>(buffer.byte_count)}, ov::element::u8, format::bfyx);

        return layouts;
    }
@@ -273,12 +301,15 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
    }

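+    // Lockable buffers are now derived from the kernel-provided descriptors
+    // instead of hard-coded indices, so the set stays correct when optional
+    // buffers (scores output, micro-SDPA) change the overall buffer list.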
    std::set<size_t> get_lockable_internal_buffers() const override {
-        size_t mixed_mode_buffer = has_scores_output ? 8 : 6;
+        std::set<size_t> lockable_ids;
+        const auto& internal_buffers = get_internal_buffers_desc();
+        for (size_t i = 0; i < internal_buffers.size(); i++) {
+            if (internal_buffers[i].lockable) {
+                lockable_ids.insert(i);
+            }
+        }

-        std::set<size_t> lockable_ids = { 0, 1, 2, /* SDPA and KV_CACHE_UPDATE indexes configuration */
-                                          mixed_mode_buffer /* PA_SDPA multiple tokens mode */ };
-        if (has_scores_output)
-            lockable_ids.insert(4 /* Precalculated accumulated sequence length offsets for each subsequence */);
+        // std::cout << "Lockable indexes: " << lockable_ids << "\n";

        return lockable_ids;
    };
@@ -299,12 +330,12 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
        size_t internal_buffers_offset = 0;
        size_t internal_buffers_count = 0;
        if (stage == Stage::PA_SDPA) {
-            internal_buffers_offset = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes.size();
-            internal_buffers_count = _kernels_data[Stage::PA_SDPA].internalBufferSizes.size();
+            internal_buffers_offset = _kernels_data[Stage::KV_CACHE_UPDATE].internalBuffers.size();
+            internal_buffers_count = _kernels_data[Stage::PA_SDPA].internalBuffers.size();
        } else if (stage == Stage::KV_CACHE_UPDATE) {
-            internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes.size();
+            internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBuffers.size();
        } else if (stage == Stage::SDPA) {
-            internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes.size();
+            // SDPA reuses the KV_CACHE_UPDATE buffers ([0, 1, 2] in the table above)
+            internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBuffers.size();
        }
        const auto desc = instance.get_node().as<paged_attention>().get_primitive();
        if (desc->has_scores_output()) {
@@ -332,6 +363,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
                                 intermediate_memories.begin() + internal_buffers_offset,
                                 intermediate_memories.begin() + internal_buffers_offset + internal_buffers_count);

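+        // Micro-SDPA's extra buffer is the last internal buffer (appended in
+        // get_internal_buffers_desc()), so hand it to the SDPA stage explicitly.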
+        if (use_micro_sdpa && stage == Stage::SDPA) {
+            args.intermediates.push_back(intermediate_memories.back());
+        }
+
        GPU_DEBUG_TRACE_DETAIL << "Execute stage=" << stage << " kernel=" << kd_idx << " " << _kernels_data[stage].kernelName << " start_offset="
                               << internal_buffers_offset << " count=" << internal_buffers_count << "\n";
@@ -627,7 +662,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {

        new_layout.set_partial_shape(new_shape);

-        std::cout << "Convert layout: " << input_layout.to_short_string() << " -> " << new_layout.to_short_string() << "\n";
+        // std::cout << "Convert layout: " << input_layout.to_short_string() << " -> " << new_layout.to_short_string() << "\n";

        return convert_data_tensor(new_layout);
    };
@@ -808,7 +843,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
            (_kernels_data[Stage::KV_CACHE_ROTATE].update_dispatch_data_func)(kv_cache_rotate_kernel_params, _kernels_data[Stage::KV_CACHE_ROTATE]);
        }

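+        // Pass the fixed default block size (16) instead of the stage tile size:
+        // with micro-SDPA the prefill tile no longer matches the KV-cache block
+        // granularity (presumed motivation for this change).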
-        auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, get_target_seq_len_block_size(stage), impl_param.is_dynamic());
+        auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, 16 /* default_block_size */, impl_param.is_dynamic());
        (_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func)(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]);

        if (stage == PagedAttentionStage::PREFILL) {
@@ -854,9 +889,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
        impl->has_rotated_blocks = desc->has_rotated_blocks;

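+        // A non-empty micro_kernels list means SDPAKernelMicro was selected for
+        // the SDPA stage, so flag the impl to take the micro-SDPA path.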
        if (!kernels_data[Stage::SDPA].kernels[0].micro_kernels.empty()) {
-            std::cout << "Micro SDPA is choosen!\n";
-            std::cout << "tile_q_size = " << kernel_selector::SDPAKernelMicro::GetTileQSize(kernels_data[Stage::SDPA]) << "\n";
-            impl->is_micro_kernel_used = true;
+            std::cout << "Micro SDPA is chosen! tile_q_size = " << kernel_selector::SDPAKernelMicro::GetTileQSize(kernels_data[Stage::SDPA]) << "\n";
+            impl->use_micro_sdpa = true;
        }

        return impl;
@@ -865,7 +899,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
private:
    bool has_scores_output = false;
    bool has_rotated_blocks = false;
-    bool is_micro_kernel_used = false;
+    bool use_micro_sdpa = false;
};

namespace detail {