Commit fc48ca3

[GPU] Use micro-sdpa kernel for PA prefill stage

1 parent 9740544 commit fc48ca3

39 files changed: +323 -241 lines changed

src/plugins/intel_gpu/src/graph/impls/ocl/multi_stage_primitive.hpp (+5 -5)
@@ -76,7 +76,7 @@ struct multi_stage_primitive : public typed_primitive_impl<PType> {
         ob << _kernels_data.size();
         for (auto& kd : _kernels_data) {
             ob << make_data(&kd.internalBufferDataType, sizeof(kernel_selector::Datatype));
-            ob << kd.internalBufferSizes;
+            ob << kd.internalBuffers;
             ob << kd.kernels;
             ob << kd.kernelName;
         }
@@ -90,7 +90,7 @@ struct multi_stage_primitive : public typed_primitive_impl<PType> {
         for (size_t i = 0; i < kernels_size; i++) {
             kernel_selector::kernel_data kd;
             ib >> make_data(&kd.internalBufferDataType, sizeof(kernel_selector::Datatype));
-            ib >> kd.internalBufferSizes;
+            ib >> kd.internalBuffers;
             ib >> kd.kernels;
             ib >> kd.kernelName;
             _kernels_data[i] = kd;
@@ -160,14 +160,14 @@ struct multi_stage_primitive : public typed_primitive_impl<PType> {
     std::vector<layout> get_internal_buffer_layouts_impl() const override {
         std::vector<layout> layouts;
         for (auto& kd : _kernels_data) {
-            if (kd.internalBufferSizes.empty())
+            if (kd.internalBuffers.empty())
                 continue;

             auto dtype = from_data_type(kd.internalBufferDataType);
             const auto bpp = data_type_traits::size_of(dtype);
-            for (auto size : kd.internalBufferSizes) {
+            for (const auto& buffer : kd.internalBuffers) {
                 layout inbuf_layout = {dtype, format::bfyx, // simple linear format (flattern to x channel)
-                                       {1, 1, 1, (tensor::value_type)(size / bpp)}};
+                                       {1, 1, 1, (tensor::value_type)(buffer.byte_count / bpp)}};
                 layouts.push_back(inbuf_layout);
             }
         }
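
The core of this commit is the switch from internalBufferSizes (a list of raw byte counts) to internalBuffers (structured descriptors). The real kernel_selector::InternalBuffer definition is not shown in this diff; a minimal hypothetical sketch, inferred only from the two fields the diff reads (byte_count and lockable):

    #include <cstddef>

    namespace kernel_selector {
    // Hypothetical reconstruction: only the two members this commit reads
    // are shown; the real struct may carry more state.
    struct InternalBuffer {
        size_t byte_count = 0;  // allocation size in bytes (replaces an internalBufferSizes entry)
        bool lockable = false;  // true if the host needs to map/lock this buffer
    };
    }  // namespace kernel_selector

Under such a descriptor the unchanged layout arithmetic above still works out: a buffer with byte_count = 4096 and an f32 internalBufferDataType (bpp = 4) becomes a {1, 1, 1, 1024} bfyx layout.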

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp (+69 -35)
@@ -20,6 +20,16 @@
 namespace cldnn {
 namespace ocl {

+inline ::std::ostream& operator<<(::std::ostream& os, const std::set<size_t>& vals) {
+    os << "[ ";
+    for (const auto& val : vals) {
+        os << val << " ";
+    }
+    os << "]";
+
+    return os;
+}
+
 struct paged_attention_impl : multi_stage_primitive<paged_attention> {
     using parent = multi_stage_primitive<paged_attention>;
     using parent::parent;
@@ -72,25 +82,29 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         OPENVINO_ASSERT(inst.get_impl() == this);

         auto& pa_inst = reinterpret_cast<paged_attention_inst&>(inst);
-        if (is_micro_kernel_used) {
+        if (use_micro_sdpa) {
             auto tile_q_size = get_target_seq_len_block_size(PagedAttentionStage::PREFILL);
             pa_inst.tile_q_size = tile_q_size;
-            std::cout << "update_inst_params: from micro-sdpa tile_q_size = " << tile_q_size << "\n";
+            pa_inst.use_micro_sdpa = true;
+            // std::cout << "update_inst_params: from micro-sdpa tile_q_size = " << tile_q_size << "\n";
         } else {
             pa_inst.tile_q_size = get_target_seq_len_block_size(PagedAttentionStage::PREFILL);
-            std::cout << "update_inst_params: sdpa_opt tile_q_size = " << get_target_seq_len_block_size(PagedAttentionStage::PREFILL) << "\n";
+            pa_inst.use_micro_sdpa = false;
+            // std::cout << "update_inst_params: sdpa_opt tile_q_size = " << get_target_seq_len_block_size(PagedAttentionStage::PREFILL) << "\n";
         }
     }

     size_t get_target_seq_len_block_size(const PagedAttentionStage& stage) const {
+        const auto default_block_size = 16;
+
         if (stage == PagedAttentionStage::PREFILL) {
-            if (is_micro_kernel_used) {
+            if (use_micro_sdpa) {
                 return kernel_selector::SDPAKernelMicro::GetTileQSize(_kernels_data[Stage::SDPA]);
             } else {
-                return 16;
+                return default_block_size;
             }
         } else {
-            return 16;
+            return default_block_size;
         }
     }

@@ -125,7 +139,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         ob << make_data(&has_rotated_blocks, sizeof(bool));
     }

-    std::vector<layout> get_internal_buffer_layouts_impl() const override {
+    std::vector<kernel_selector::InternalBuffer> get_internal_buffers_desc() const {
         /*
          * Internal buffers allocation owners and users:
          * +--------------------------------------+--------------------+--------------------+
@@ -145,6 +159,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
          * +--------------------------------------+--------------------+--------------------+
          * | PA_SDPA (mixed mode) + scores output | [3, 4, 5, 6, 7, 8] |                    |
          * +--------------------------------------+--------------------+--------------------+
+         * | SDPA (1st token, micro-kernel)       | [last(8/9)]        | [0, 1, 2]          |
+         * +--------------------------------------+--------------------+--------------------+
          *
          * Description:
          * 0, 1, 2 - Buffers used for proper blocks distribution for kv_cache_update and
@@ -157,24 +173,36 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
          *          Filled in PA/SDPA kernels.
          * 8 - Optional buffer used for mixed PA execution mode, mapping gws idx to subsequence id.
          *     Filled in paged_attention_inst::on_execute() call.
+         * last -
          */

-        auto add_internal_buffers = [](std::vector<layout>& layouts, const kernel_selector::KernelData& kd) {
-            if (kd.internalBufferSizes.empty())
-                return;
-
-            auto dtype = from_data_type(kd.internalBufferDataType);
-            const auto bpp = data_type_traits::size_of(dtype);
-            for (auto size : kd.internalBufferSizes) {
-                layout inbuf_layout = {dtype, format::bfyx, // simple linear format (flattern to x channel)
-                                       {1, 1, 1, (tensor::value_type)(size / bpp)}};
-                layouts.push_back(inbuf_layout);
-            }
+        auto add_internal_buffers = [](std::vector<kernel_selector::InternalBuffer>& internal_buffers,
+                                       const kernel_selector::KernelData& kd) {
+            internal_buffers.insert(internal_buffers.end(), kd.internalBuffers.begin(), kd.internalBuffers.end());
         };

+        std::vector<kernel_selector::InternalBuffer> internal_buffers;
+        // size_t count = 0;
+        add_internal_buffers(internal_buffers, _kernels_data[Stage::KV_CACHE_UPDATE]);
+        // std::cout << "Stage::KV_CACHE_UPDATE added: " << internal_buffers.size() - count << "\n";
+        // count = internal_buffers.size();
+        add_internal_buffers(internal_buffers, _kernels_data[Stage::PA_SDPA]);
+        // std::cout << "Stage::PA_SDPA added: " << internal_buffers.size() - count << "\n";
+        // count = internal_buffers.size();
+
+        if (use_micro_sdpa) {
+            add_internal_buffers(internal_buffers, _kernels_data[Stage::SDPA]);
+            // std::cout << "Stage::SDPA added: " << internal_buffers.size() - count << "\n";
+        }
+
+        return internal_buffers;
+    }
+
+    std::vector<layout> get_internal_buffer_layouts_impl() const override {
         std::vector<layout> layouts;
-        add_internal_buffers(layouts, _kernels_data[Stage::KV_CACHE_UPDATE]);
-        add_internal_buffers(layouts, _kernels_data[Stage::PA_SDPA]);
+
+        for (const auto& buffer : get_internal_buffers_desc())
+            layouts.emplace_back(ov::PartialShape{static_cast<int64_t>(buffer.byte_count)}, ov::element::u8, format::bfyx);

         return layouts;
     }
@@ -273,12 +301,15 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
     }

     std::set<size_t> get_lockable_internal_buffers() const override {
-        size_t mixed_mode_buffer = has_scores_output ? 8 : 6;
+        std::set<size_t> lockable_ids;
+        const auto& internal_buffers = get_internal_buffers_desc();
+        for (size_t i = 0; i < internal_buffers.size(); i++) {
+            if (internal_buffers[i].lockable) {
+                lockable_ids.insert(i);
+            }
+        }

-        std::set<size_t> lockable_ids = { 0, 1, 2, /* SDPA and KV_CACHE_UPDATE indexes configuration */
-                                          mixed_mode_buffer /* PA_SDPA multiple tokens mode */ };
-        if (has_scores_output)
-            lockable_ids.insert(4 /* Precalculated accumulated sequence length offsets for each subsequence */);
+        // std::cout << "Lockable indexes: " << lockable_ids << "\n";

         return lockable_ids;
     };
@@ -299,12 +330,12 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         size_t internal_buffers_offset = 0;
         size_t internal_buffers_count = 0;
         if (stage == Stage::PA_SDPA) {
-            internal_buffers_offset = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes.size();
-            internal_buffers_count = _kernels_data[Stage::PA_SDPA].internalBufferSizes.size();
+            internal_buffers_offset = _kernels_data[Stage::KV_CACHE_UPDATE].internalBuffers.size();
+            internal_buffers_count = _kernels_data[Stage::PA_SDPA].internalBuffers.size();
         } else if (stage == Stage::KV_CACHE_UPDATE) {
-            internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes.size();
+            internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBuffers.size();
         } else if (stage == Stage::SDPA) {
-            internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes.size();
+            internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBuffers.size();

             const auto desc = instance.get_node().as<paged_attention>().get_primitive();
             if (desc->has_scores_output()) {
@@ -332,6 +363,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
                                     intermediate_memories.begin() + internal_buffers_offset,
                                     intermediate_memories.begin() + internal_buffers_offset + internal_buffers_count);

+        if (use_micro_sdpa && stage == Stage::SDPA) {
+            args.intermediates.push_back(intermediate_memories.back());
+        }
+
         GPU_DEBUG_TRACE_DETAIL << "Execute stage=" << stage << " kernel=" << kd_idx << " " << _kernels_data[stage].kernelName << " start_offset="
                                << internal_buffers_offset << " count=" << internal_buffers_count << "\n";

@@ -627,7 +662,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {

             new_layout.set_partial_shape(new_shape);

-            std::cout << "Convert layout: " << input_layout.to_short_string() << " -> " << new_layout.to_short_string() << "\n";
+            // std::cout << "Convert layout: " << input_layout.to_short_string() << " -> " << new_layout.to_short_string() << "\n";

             return convert_data_tensor(new_layout);
         };
@@ -808,7 +843,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
             (_kernels_data[Stage::KV_CACHE_ROTATE].update_dispatch_data_func)(kv_cache_rotate_kernel_params, _kernels_data[Stage::KV_CACHE_ROTATE]);
         }

-        auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, get_target_seq_len_block_size(stage), impl_param.is_dynamic());
+        auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, 16 /* default_block_size */, impl_param.is_dynamic());
         (_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func)(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]);

         if (stage == PagedAttentionStage::PREFILL) {
@@ -854,9 +889,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         impl->has_rotated_blocks = desc->has_rotated_blocks;

         if (!kernels_data[Stage::SDPA].kernels[0].micro_kernels.empty()) {
-            std::cout << "Micro SDPA is choosen!\n";
-            std::cout << "tile_q_size = " << kernel_selector::SDPAKernelMicro::GetTileQSize(kernels_data[Stage::SDPA]) << "\n";
-            impl->is_micro_kernel_used = true;
+            std::cout << "Micro SDPA is choosen! tile_q_size = " << kernel_selector::SDPAKernelMicro::GetTileQSize(kernels_data[Stage::SDPA]) << "\n";
+            impl->use_micro_sdpa = true;
         }

         return impl;
@@ -865,7 +899,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
 private:
     bool has_scores_output = false;
     bool has_rotated_blocks = false;
-    bool is_micro_kernel_used = false;
+    bool use_micro_sdpa = false;
 };

 namespace detail {
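
A detail worth calling out in the get_lockable_internal_buffers() hunk above: lockable buffer indices are no longer hard-coded ({0, 1, 2} plus a mode-dependent index) but derived from the per-buffer flag, so appending the micro-sdpa prefill buffers cannot desynchronize the ids. A self-contained sketch of that pattern, reusing the hypothetical InternalBuffer from the earlier sketch:

    #include <cstddef>
    #include <set>
    #include <vector>

    // Hypothetical stand-in for kernel_selector::InternalBuffer (see earlier sketch).
    struct InternalBuffer {
        size_t byte_count = 0;
        bool lockable = false;
    };

    // Mirror of the new get_lockable_internal_buffers() logic: collect the
    // indices of every buffer whose descriptor marks it as lockable.
    std::set<size_t> lockable_ids_from(const std::vector<InternalBuffer>& buffers) {
        std::set<size_t> ids;
        for (size_t i = 0; i < buffers.size(); ++i) {
            if (buffers[i].lockable)
                ids.insert(i);
        }
        return ids;
    }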

src/plugins/intel_gpu/src/graph/impls/ocl/primitive_base.hpp (+5 -5)
@@ -74,15 +74,15 @@ struct typed_primitive_impl_ocl : public typed_primitive_impl<PType> {
     void save(BinaryOutputBuffer& ob) const override {
         primitive_impl::save(ob);
         ob << make_data(&_kernel_data.internalBufferDataType, sizeof(kernel_selector::Datatype));
-        ob << _kernel_data.internalBufferSizes;
+        ob << _kernel_data.internalBuffers;
         ob << _kernel_data.kernels;
         ob << _kernel_data.kernelName;
     }

     void load(BinaryInputBuffer& ib) override {
         primitive_impl::load(ib);
         ib >> make_data(&_kernel_data.internalBufferDataType, sizeof(kernel_selector::Datatype));
-        ib >> _kernel_data.internalBufferSizes;
+        ib >> _kernel_data.internalBuffers;
         ib >> _kernel_data.kernels;
         ib >> _kernel_data.kernelName;
     }
@@ -185,15 +185,15 @@ struct typed_primitive_impl_ocl : public typed_primitive_impl<PType> {
     }

     std::vector<layout> get_internal_buffer_layouts_impl() const override {
-        if (_kernel_data.internalBufferSizes.empty())
+        if (_kernel_data.internalBuffers.empty())
             return {};

         std::vector<layout> layouts;
         auto dtype = from_data_type(_kernel_data.internalBufferDataType);
         const auto bpp = data_type_traits::size_of(dtype);
-        for (auto size : _kernel_data.internalBufferSizes) {
+        for (const auto& buffer : _kernel_data.internalBuffers) {
             layout inbuf_layout = {dtype, format::bfyx, // simple linear format (flattern to x channel)
-                                   {1, 1, 1, (tensor::value_type)(size / bpp)}};
+                                   {1, 1, 1, (tensor::value_type)(buffer.byte_count / bpp)}};
             layouts.push_back(inbuf_layout);
         }
         return layouts;

src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp (+4 -4)
@@ -71,19 +71,19 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_product_attention> {
         // _kernels_data[1] - sdpa_opt (indirect)
         // => use internal buffers from [1] kernel
         size_t kernel_idx = _kernels_data.size();
-        if (_kernels_data.size() >= 1 && !_kernels_data[0].internalBufferSizes.empty()) {
+        if (_kernels_data.size() >= 1 && !_kernels_data[0].internalBuffers.empty()) {
            kernel_idx = 0;
-        } else if (_kernels_data.size() >= 2 && !_kernels_data[1].internalBufferSizes.empty()) {
+        } else if (_kernels_data.size() >= 2 && !_kernels_data[1].internalBuffers.empty()) {
            kernel_idx = 1;
         }

         std::vector<layout> layouts;
         if (kernel_idx < _kernels_data.size()) {
             auto dtype = from_data_type(_kernels_data[kernel_idx].internalBufferDataType);
             const auto bpp = data_type_traits::size_of(dtype);
-            for (auto size : _kernels_data[kernel_idx].internalBufferSizes) {
+            for (const auto& buffer : _kernels_data[kernel_idx].internalBuffers) {
                 layout inbuf_layout = {dtype, format::bfyx, // simple linear format (flattern to x channel)
-                                       {1, 1, 1, (tensor::value_type)(size / bpp)}};
+                                       {1, 1, 1, (tensor::value_type)(buffer.byte_count / bpp)}};
                 layouts.push_back(inbuf_layout);
             }
         }

src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h (+1)
@@ -62,6 +62,7 @@ class typed_primitive_inst<paged_attention> : public typed_primitive_inst_base<paged_attention> {
     memory::ptr rotation_deltas_memory_ptr() const { return input_memory_ptr(14); }
     memory::ptr rotation_trig_lut_memory_ptr() const { return input_memory_ptr(15); }

+    bool use_micro_sdpa = false;
     size_t tile_q_size = 0;

 protected:

src/plugins/intel_gpu/src/graph/include/primitive_inst.h (+1 -1)
@@ -412,7 +412,7 @@ class primitive_inst {
     std::vector<memory::ptr> allocate_outputs(kernel_impl_params* updated_params = nullptr,
                                               bool reset_mem = true,
                                               bool runtime_alloc = false);
-    memory::ptr allocate_internal_buffer(size_t idx, bool reset = true);
+    memory::ptr allocate_internal_buffer(const layout& layout, size_t idx, bool reset = true, bool lockable = false);
     void allocate_shape_info_memory();
     static std::vector<primitive_inst*> build_exec_deps(
         std::vector<std::pair<primitive_inst*, int32_t>> const& mem_deps);
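
The signature change above means allocate_internal_buffer() now receives the buffer's layout and a lockable flag explicitly rather than deriving both from the index. The call site lives in primitive_inst.cpp and is not part of this excerpt, so the driving loop below is only a hedged sketch; layout, memory, and allocate_all are stand-ins:

    #include <cstddef>
    #include <memory>
    #include <set>
    #include <vector>

    struct layout {};                                        // stand-in for cldnn::layout
    struct memory { using ptr = std::shared_ptr<memory>; };  // stand-in for cldnn::memory

    // Matches the shape of the new signature; the real implementation
    // dispatches to the engine's allocator and is not shown here.
    memory::ptr allocate_internal_buffer(const layout& l, size_t idx, bool reset = true, bool lockable = false);

    // Presumed pattern: one allocation per layout reported by the impl, with
    // lockable (host-mappable) placement for the indices the impl flagged.
    std::vector<memory::ptr> allocate_all(const std::vector<layout>& layouts,
                                          const std::set<size_t>& lockable_ids) {
        std::vector<memory::ptr> buffers;
        for (size_t idx = 0; idx < layouts.size(); ++idx) {
            buffers.push_back(allocate_internal_buffer(layouts[idx], idx, /*reset=*/true, lockable_ids.count(idx) > 0));
        }
        return buffers;
    }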
