Commit 37e4939

[GPU] Use sdpa-micro kernel for prefill processing in PagedAttention
1 parent ec9dfae commit 37e4939

File tree: 43 files changed, +435 -206 lines changed

src/plugins/intel_gpu/src/graph/impls/ocl/multi_stage_primitive.hpp (+5 -5)

@@ -76,7 +76,7 @@ struct multi_stage_primitive : public typed_primitive_impl<PType> {
         ob << _kernels_data.size();
         for (auto& kd : _kernels_data) {
             ob << make_data(&kd.internalBufferDataType, sizeof(kernel_selector::Datatype));
-            ob << kd.internalBufferSizes;
+            ob << kd.internalBuffers;
             ob << kd.kernels;
             ob << kd.kernelName;
         }
@@ -90,7 +90,7 @@ struct multi_stage_primitive : public typed_primitive_impl<PType> {
         for (size_t i = 0; i < kernels_size; i++) {
             kernel_selector::kernel_data kd;
             ib >> make_data(&kd.internalBufferDataType, sizeof(kernel_selector::Datatype));
-            ib >> kd.internalBufferSizes;
+            ib >> kd.internalBuffers;
             ib >> kd.kernels;
             ib >> kd.kernelName;
             _kernels_data[i] = kd;
@@ -160,14 +160,14 @@ struct multi_stage_primitive : public typed_primitive_impl<PType> {
     std::vector<layout> get_internal_buffer_layouts_impl() const override {
         std::vector<layout> layouts;
         for (auto& kd : _kernels_data) {
-            if (kd.internalBufferSizes.empty())
+            if (kd.internalBuffers.empty())
                 continue;

             auto dtype = from_data_type(kd.internalBufferDataType);
             const auto bpp = data_type_traits::size_of(dtype);
-            for (auto size : kd.internalBufferSizes) {
+            for (const auto& buffer : kd.internalBuffers) {
                 layout inbuf_layout = {dtype, format::bfyx, // simple linear format (flattern to x channel)
-                                       {1, 1, 1, (tensor::value_type)(size / bpp)}};
+                                       {1, 1, 1, (tensor::value_type)(buffer.byte_count / bpp)}};
                 layouts.push_back(inbuf_layout);
             }
         }
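
Note: the kernel_selector::InternalBuffer type that replaces the plain size vector is defined elsewhere in this commit. Inferring only from its usage here (buffer.byte_count) and in paged_attention.cpp below (internal_buffers[i].lockable), a minimal sketch of the descriptor might look like:

namespace kernel_selector {
// Sketch inferred from usage in this diff; the real definition may carry more fields.
struct InternalBuffer {
    size_t byte_count = 0;  // allocation size in bytes
    bool lockable = false;  // whether the host needs to map and fill the buffer
};
}  // namespace kernel_selector

Under that assumption, a kernel that fills a buffer on the host would describe it as, e.g., kd.internalBuffers.push_back({size_in_bytes, /*lockable=*/true});.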

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp (+79 -29)

@@ -15,6 +15,7 @@
 #include "sdpa/pa_kv_cache_rotate_kernel_ref.h"
 #include "sdpa/pa_kv_cache_update_kernel_ref.h"
 #include "sdpa/pa_sdpa_kernel_opt.h"
+#include "sdpa/sdpa_kernel_micro.h"

 namespace cldnn {
 namespace ocl {
@@ -66,10 +67,31 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         return stage == PagedAttentionStage::MIXED;
     }

+    void update_inst_params(primitive_inst& inst) const override {
+        OPENVINO_ASSERT(inst.type() == paged_attention::type_id());
+        OPENVINO_ASSERT(inst.get_impl() == this);
+
+        auto& pa_inst = reinterpret_cast<paged_attention_inst&>(inst);
+        pa_inst.query_block_size = get_query_block_size(PagedAttentionStage::PREFILL);
+        pa_inst.use_micro_sdpa = use_micro_sdpa;
+    }
+
+    size_t get_query_block_size(const PagedAttentionStage& stage) const {
+        const auto default_block_size = 16;
+
+        if (stage == PagedAttentionStage::PREFILL) {
+            return use_micro_sdpa ? kernel_selector::SDPAKernelMicro::GetTileQSize(_kernels_data[Stage::SDPA])
+                                  : default_block_size;
+        } else {
+            return default_block_size;
+        }
+    }
+
     void load(BinaryInputBuffer& ib) override {
         parent::load(ib);
         ib >> make_data(&has_scores_output, sizeof(bool));
         ib >> make_data(&has_rotated_blocks, sizeof(bool));
+        ib >> make_data(&use_micro_sdpa, sizeof(bool));
         if (is_dynamic()) {
             auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance();
             auto kv_cache_update_kernel_impl = kv_cache_update_kernel_selector.GetImplementation(_kernels_data[Stage::KV_CACHE_UPDATE].kernelName);
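
get_query_block_size() is the key host-side knob introduced here: prefill work is partitioned into query blocks, and with the micro kernel the block size must match the kernel's query tile (SDPAKernelMicro::GetTileQSize) rather than the reference block size of 16. A minimal sketch of the rounding this implies for get_aligned_seq_len (helper name hypothetical, not from this commit):

// Hypothetical helper: round a subsequence length up to a whole number of
// query blocks, so the prefill dispatch covers complete query tiles.
inline size_t align_to_query_blocks(size_t seq_len, size_t query_block_size) {
    return (seq_len + query_block_size - 1) / query_block_size * query_block_size;
}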
@@ -95,9 +117,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         parent::save(ob);
         ob << make_data(&has_scores_output, sizeof(bool));
         ob << make_data(&has_rotated_blocks, sizeof(bool));
+        ob << make_data(&use_micro_sdpa, sizeof(bool));
     }

-    std::vector<layout> get_internal_buffer_layouts_impl() const override {
+    std::vector<kernel_selector::InternalBuffer> get_internal_buffers_desc() const {
         /*
          * Internal buffers allocation owners and users:
          * +--------------------------------------+--------------------+--------------------+
@@ -117,6 +140,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
          * +--------------------------------------+--------------------+--------------------+
          * | PA_SDPA (mixed mode) + scores output | [3, 4, 5, 6, 7, 8] |                    |
          * +--------------------------------------+--------------------+--------------------+
+         * | SDPA (1st token, micro-kernel)       | [last (8/9)]       |                    |
+         * +--------------------------------------+--------------------+--------------------+
          *
          * Description:
          * 0, 1, 2 - Buffers used for proper blocks distribution for kv_cache_update and
@@ -129,24 +154,32 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
          *         Filled in PA/SDPA kernels.
          * 8 - Optional buffer used for mixed PA execution mode, mapping gws idx to subsequence id.
          *     Filled in paged_attention_inst::on_execute() call.
+         * last - Used for defining query block index for the currently processing subsequence and mapping
+         *        gws index to subsequence idx. Values stored in pairs like:
+         *        [block_idx0, subsequence_idx0, block_idx1, subsequence_idx0, ..., block_idx0, subsequence_idx1].
+         *        Filled in paged_attention_inst::on_execute() call for sdpa-micro kernel only.
          */

-        auto add_internal_buffers = [](std::vector<layout>& layouts, const kernel_selector::KernelData& kd) {
-            if (kd.internalBufferSizes.empty())
-                return;
-
-            auto dtype = from_data_type(kd.internalBufferDataType);
-            const auto bpp = data_type_traits::size_of(dtype);
-            for (auto size : kd.internalBufferSizes) {
-                layout inbuf_layout = {dtype, format::bfyx, // simple linear format (flattern to x channel)
-                                       {1, 1, 1, (tensor::value_type)(size / bpp)}};
-                layouts.push_back(inbuf_layout);
-            }
+        auto add_internal_buffers = [](std::vector<kernel_selector::InternalBuffer>& internal_buffers,
+                                       const kernel_selector::KernelData& kd) {
+            internal_buffers.insert(internal_buffers.end(), kd.internalBuffers.begin(), kd.internalBuffers.end());
         };

+        std::vector<kernel_selector::InternalBuffer> internal_buffers;
+        add_internal_buffers(internal_buffers, _kernels_data[Stage::KV_CACHE_UPDATE]);
+        add_internal_buffers(internal_buffers, _kernels_data[Stage::PA_SDPA]);
+
+        if (use_micro_sdpa)
+            add_internal_buffers(internal_buffers, _kernels_data[Stage::SDPA]);
+
+        return internal_buffers;
+    }
+
+    std::vector<layout> get_internal_buffer_layouts_impl() const override {
         std::vector<layout> layouts;
-        add_internal_buffers(layouts, _kernels_data[Stage::KV_CACHE_UPDATE]);
-        add_internal_buffers(layouts, _kernels_data[Stage::PA_SDPA]);
+
+        for (const auto& buffer : get_internal_buffers_desc())
+            layouts.emplace_back(ov::PartialShape{static_cast<int64_t>(buffer.byte_count)}, ov::element::u8, format::bfyx);

         return layouts;
     }
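
The comment above fully specifies the trailing micro-SDPA buffer: one (block_idx, subsequence_idx) pair per query block, per subsequence. The commit fills it inside paged_attention_inst::on_execute() (not shown in this excerpt); a self-contained sketch of that filling logic, with a hypothetical function name and an assumed int32_t element type, could be:

#include <cstdint>
#include <vector>

// Hypothetical sketch: build the [block_idx, subsequence_idx, ...] pairs that
// map each query block of each subsequence back to its owning subsequence.
std::vector<int32_t> make_block_to_subsequence_map(const std::vector<size_t>& seq_lens,
                                                   size_t query_block_size) {
    std::vector<int32_t> pairs;
    for (size_t subseq = 0; subseq < seq_lens.size(); subseq++) {
        const size_t blocks = (seq_lens[subseq] + query_block_size - 1) / query_block_size;
        for (size_t block = 0; block < blocks; block++) {
            pairs.push_back(static_cast<int32_t>(block));   // query block index within the subsequence
            pairs.push_back(static_cast<int32_t>(subseq));  // owning subsequence index
        }
    }
    return pairs;
}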
@@ -245,12 +278,13 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
     }

     std::set<size_t> get_lockable_internal_buffers() const override {
-        size_t mixed_mode_buffer = has_scores_output ? 8 : 6;
-
-        std::set<size_t> lockable_ids = { 0, 1, 2, /* SDPA and KV_CACHE_UPDATE indexes configuration */
-                                          mixed_mode_buffer /* PA_SDPA multiple tokens mode */ };
-        if (has_scores_output)
-            lockable_ids.insert(4 /* Precalculated accumulated sequence length offsets for each subsequence */);
+        std::set<size_t> lockable_ids;
+        const auto& internal_buffers = get_internal_buffers_desc();
+        for (size_t i = 0; i < internal_buffers.size(); i++) {
+            if (internal_buffers[i].lockable) {
+                lockable_ids.insert(i);
+            }
+        }

         return lockable_ids;
     };
@@ -271,12 +305,12 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         size_t internal_buffers_offset = 0;
         size_t internal_buffers_count = 0;
         if (stage == Stage::PA_SDPA) {
-            internal_buffers_offset = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes.size();
-            internal_buffers_count = _kernels_data[Stage::PA_SDPA].internalBufferSizes.size();
+            internal_buffers_offset = _kernels_data[Stage::KV_CACHE_UPDATE].internalBuffers.size();
+            internal_buffers_count = _kernels_data[Stage::PA_SDPA].internalBuffers.size();
         } else if (stage == Stage::KV_CACHE_UPDATE) {
-            internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes.size();
+            internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBuffers.size();
         } else if (stage == Stage::SDPA) {
-            internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes.size();
+            internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBuffers.size();

             const auto desc = instance.get_node().as<paged_attention>().get_primitive();
             if (desc->has_scores_output()) {
@@ -304,6 +338,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
                                   intermediate_memories.begin() + internal_buffers_offset,
                                   intermediate_memories.begin() + internal_buffers_offset + internal_buffers_count);

+        if (use_micro_sdpa && stage == Stage::SDPA) {
+            args.intermediates.push_back(intermediate_memories.back());
+        }
+
         GPU_DEBUG_TRACE_DETAIL << "Execute stage=" << stage << " kernel=" << kd_idx << " " << _kernels_data[stage].kernelName << " start_offset="
                                << internal_buffers_offset << " count=" << internal_buffers_count << "\n";

@@ -581,7 +619,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
     static sdpa_kernel_params_t get_sdpa_kernel_params(const kernel_impl_params& impl_param,
                                                        const PagedAttentionStage& stage,
                                                        const kernel_selector::MultiDataTensor& input_tensors,
-                                                       bool is_dynamic = false) {
+                                                       int64_t query_block_size,
+                                                       bool is_dynamic) {
         const auto desc = impl_param.typed_desc<paged_attention>();
         auto params = get_default_params<sdpa_kernel_params_t>(impl_param, is_dynamic);

@@ -623,6 +662,12 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {

         params.conf = get_sdpa_configuration(impl_param, is_dynamic);

+        const std::vector<int64_t> default_order = {0, 1, 2, 3};
+        params.input0_order = default_order;
+        params.input1_order = default_order;
+        params.input2_order = default_order;
+        params.output_order = default_order;
+
         const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset;
         const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset;
         std::map<size_t, size_t> in_tensor_to_offset_map = {
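
The identity {0, 1, 2, 3} orders are presumably needed because the shared SDPA parameter structure (also consumed by sdpa_kernel_micro, which supports transposed inputs for fused-transpose cases) expects explicit tensor orders; PagedAttention's query/key/value already arrive in the default layout, so no permutation is requested.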
@@ -643,7 +688,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
             in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(11)});

         if ((stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED) && !is_dynamic)
-            params.conf.paged_attention_aligned_seq_len = get_aligned_seq_len(impl_param, stage);
+            params.conf.paged_attention_aligned_seq_len = get_aligned_seq_len(impl_param, stage, query_block_size);

         if (has_scores_output)
             out_tensor_to_offset_map.insert({1, out_offsets_map.at(1)});
@@ -760,7 +805,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         (_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func)(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]);

         if (stage == PagedAttentionStage::PREFILL) {
-            auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
+            auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, get_query_block_size(stage), impl_param.is_dynamic());
             (_kernels_data[Stage::SDPA].update_dispatch_data_func)(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
         }

@@ -782,8 +827,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
         auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance();
         kernels_data.push_back(kv_cache_update_kernel_selector.get_best_kernel(kv_cache_update_kernel_params));
-
-        auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
+        auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, 0, impl_param.is_dynamic());
         auto& sdpa_kernel_selector = sdpa_kernel_selector_t::Instance();
         kernels_data.push_back(sdpa_kernel_selector.get_best_kernel(sdpa_kernel_params));

@@ -801,12 +845,18 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         impl->has_scores_output = desc->has_scores_output();
         impl->has_rotated_blocks = desc->has_rotated_blocks;

+        if (!kernels_data[Stage::SDPA].kernels[0].micro_kernels.empty()) {
+            std::cout << "Micro SDPA is chosen! tile_q_size = " << kernel_selector::SDPAKernelMicro::GetTileQSize(kernels_data[Stage::SDPA]) << "\n";
+            impl->use_micro_sdpa = true;
+        }
+
         return impl;
     }

 private:
     bool has_scores_output = false;
     bool has_rotated_blocks = false;
+    bool use_micro_sdpa = false;
 };

 namespace detail {

src/plugins/intel_gpu/src/graph/impls/ocl/primitive_base.hpp (+5 -5)

@@ -74,15 +74,15 @@ struct typed_primitive_impl_ocl : public typed_primitive_impl<PType> {
     void save(BinaryOutputBuffer& ob) const override {
         primitive_impl::save(ob);
         ob << make_data(&_kernel_data.internalBufferDataType, sizeof(kernel_selector::Datatype));
-        ob << _kernel_data.internalBufferSizes;
+        ob << _kernel_data.internalBuffers;
         ob << _kernel_data.kernels;
         ob << _kernel_data.kernelName;
     }

     void load(BinaryInputBuffer& ib) override {
         primitive_impl::load(ib);
         ib >> make_data(&_kernel_data.internalBufferDataType, sizeof(kernel_selector::Datatype));
-        ib >> _kernel_data.internalBufferSizes;
+        ib >> _kernel_data.internalBuffers;
         ib >> _kernel_data.kernels;
         ib >> _kernel_data.kernelName;
     }
@@ -185,15 +185,15 @@ struct typed_primitive_impl_ocl : public typed_primitive_impl<PType> {
     }

     std::vector<layout> get_internal_buffer_layouts_impl() const override {
-        if (_kernel_data.internalBufferSizes.empty())
+        if (_kernel_data.internalBuffers.empty())
             return {};

         std::vector<layout> layouts;
         auto dtype = from_data_type(_kernel_data.internalBufferDataType);
         const auto bpp = data_type_traits::size_of(dtype);
-        for (auto size : _kernel_data.internalBufferSizes) {
-            layout inbuf_layout = {dtype, format::bfyx, // simple linear format (flattern to x channel)
-                                   {1, 1, 1, (tensor::value_type)(size / bpp)}};
+        for (const auto& buffer : _kernel_data.internalBuffers) {
+            layout inbuf_layout = {dtype, format::bfyx, // simple linear format (flattern to x channel)
+                                   {1, 1, 1, (tensor::value_type)(buffer.byte_count / bpp)}};
             layouts.push_back(inbuf_layout);
         }
         return layouts;

src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp (+4 -4)

@@ -71,19 +71,19 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
         // _kernels_data[1] - sdpa_opt (indirect)
         // => use internal buffers from [1] kernel
         size_t kernel_idx = _kernels_data.size();
-        if (_kernels_data.size() >= 1 && !_kernels_data[0].internalBufferSizes.empty()) {
+        if (_kernels_data.size() >= 1 && !_kernels_data[0].internalBuffers.empty()) {
             kernel_idx = 0;
-        } else if (_kernels_data.size() >= 2 && !_kernels_data[1].internalBufferSizes.empty()) {
+        } else if (_kernels_data.size() >= 2 && !_kernels_data[1].internalBuffers.empty()) {
             kernel_idx = 1;
         }

         std::vector<layout> layouts;
         if (kernel_idx < _kernels_data.size()) {
             auto dtype = from_data_type(_kernels_data[kernel_idx].internalBufferDataType);
             const auto bpp = data_type_traits::size_of(dtype);
-            for (auto size : _kernels_data[kernel_idx].internalBufferSizes) {
+            for (const auto& buffer : _kernels_data[kernel_idx].internalBuffers) {
                 layout inbuf_layout = {dtype, format::bfyx, // simple linear format (flattern to x channel)
-                                       {1, 1, 1, (tensor::value_type)(size / bpp)}};
+                                       {1, 1, 1, (tensor::value_type)(buffer.byte_count / bpp)}};
                 layouts.push_back(inbuf_layout);
             }
         }

src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h (+2 -1)

@@ -62,7 +62,8 @@ class typed_primitive_inst<paged_attention> : public typed_primitive_inst_base<p
     memory::ptr rotation_deltas_memory_ptr() const { return input_memory_ptr(14); }
     memory::ptr rotation_trig_lut_memory_ptr() const { return input_memory_ptr(15); }

-    std::shared_ptr<network> prefill_network;
+    bool use_micro_sdpa = false;
+    size_t query_block_size = 0;

 protected:
     void on_execute() override;

src/plugins/intel_gpu/src/graph/include/primitive_inst.h (+2 -1)

@@ -59,6 +59,7 @@ struct primitive_impl {
     virtual std::set<size_t> get_lockable_internal_buffers() const { return {}; }
     virtual void set_node_params(const program_node&) {}
     virtual const std::string& get_type_info() const = 0;
+    virtual void update_inst_params(primitive_inst& instance) const {}
     virtual void set_arguments(primitive_inst& instance) = 0;
     virtual void set_arguments(primitive_inst& instance, kernel_arguments_data& args) = 0;
     virtual event::ptr execute(const std::vector<event::ptr>& events, primitive_inst& instance) = 0;
@@ -413,7 +414,7 @@ class primitive_inst {
     std::vector<memory::ptr> allocate_outputs(kernel_impl_params* updated_params = nullptr,
                                               bool reset_mem = true,
                                               bool runtime_alloc = false);
-    memory::ptr allocate_internal_buffer(size_t idx, bool reset = true);
+    memory::ptr allocate_internal_buffer(const layout& layout, size_t idx, bool reset = true, bool lockable = false);
     void allocate_shape_info_memory();
     static std::vector<primitive_inst*> build_exec_deps(
         std::vector<std::pair<primitive_inst*, int32_t>> const& mem_deps);
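
The call site of the new update_inst_params() hook is outside this excerpt; a plausible sketch of how the runtime would use it after (re)selecting an implementation (the free function below is hypothetical):

// Hypothetical sketch: once an impl is assigned to an instance, let it push
// per-instance parameters (e.g. query_block_size / use_micro_sdpa) downstream.
void propagate_impl_params(primitive_impl& impl, primitive_inst& inst) {
    impl.update_inst_params(inst);  // base implementation is a no-op
}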
