
Commit d364bb1

Apply review fixes

1 parent 8638537 commit d364bb1

4 files changed (+176 -178 lines)


src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp (+170 -9)
@@ -58,6 +58,33 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         KV_CACHE_ROTATE,
     };
 
+    PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param) const {
+        const auto& query_shape = impl_param.get_input_layout(0).get_partial_shape();
+        const auto& past_lens_shape = impl_param.get_input_layout(5).get_partial_shape();
+
+        if (query_shape.is_static() && past_lens_shape.is_static()) {
+            if (query_shape[0].get_length() == past_lens_shape[0].get_length()) {
+                return PagedAttentionStage::GENERATE;
+            }
+
+            const auto past_lens_idx = 5;
+            const auto& memory_deps = impl_param.memory_deps;
+            const auto past_lens_mem = memory_deps.at(past_lens_idx);
+            mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, *impl_param.strm);
+
+            const auto past_lens_size = past_lens_mem_lock.size();
+            for (size_t i = 0; i < past_lens_size; i++) {
+                if (past_lens_mem_lock[i] != 0) {
+                    return PagedAttentionStage::MIXED;
+                }
+            }
+
+            return PagedAttentionStage::PREFILL;
+        }
+
+        return PagedAttentionStage::UNKNOWN;
+    }
+
     bool requires_update(primitive_inst& inst, const kernel_impl_params& impl_params) const override {
         const auto stage = get_paged_attention_stage(impl_params);
 
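Note: the get_paged_attention_stage helper added above classifies a batch by comparing the number of query tokens against the number of past_lens entries and by scanning past_lens for non-zero history. The following standalone sketch (a hypothetical classify_stage helper over plain host vectors, not the plugin code; the UNKNOWN case for dynamic shapes is omitted) illustrates the same rule:

// Simplified, host-only sketch of the stage-selection rule above.
#include <cstdint>
#include <iostream>
#include <vector>

enum class PagedAttentionStage { GENERATE, PREFILL, MIXED, UNKNOWN };

// query_tokens: total number of query tokens in the batch (shape[0] of the query input).
// past_lens:    per-subsequence lengths of the already-cached context.
PagedAttentionStage classify_stage(size_t query_tokens, const std::vector<int32_t>& past_lens) {
    // One new token per subsequence means pure generation (second-token phase).
    if (query_tokens == past_lens.size())
        return PagedAttentionStage::GENERATE;

    // Multi-token queries combined with non-empty KV-cache history make a mixed batch.
    for (auto len : past_lens)
        if (len != 0)
            return PagedAttentionStage::MIXED;

    // Multi-token queries with no history: first-token (prefill) phase.
    return PagedAttentionStage::PREFILL;
}

int main() {
    // Two subsequences, one query token each -> GENERATE.
    std::cout << (classify_stage(2, {10, 7}) == PagedAttentionStage::GENERATE) << "\n";
    // 12 query tokens, no history -> PREFILL.
    std::cout << (classify_stage(12, {0, 0}) == PagedAttentionStage::PREFILL) << "\n";
    // 12 query tokens, some history -> MIXED.
    std::cout << (classify_stage(12, {0, 4}) == PagedAttentionStage::MIXED) << "\n";
}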
@@ -67,15 +94,6 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         return stage == PagedAttentionStage::MIXED;
     }
 
-    void update_inst_params(primitive_inst& inst) const override {
-        OPENVINO_ASSERT(inst.type() == paged_attention::type_id());
-        OPENVINO_ASSERT(inst.get_impl() == this);
-
-        auto& pa_inst = reinterpret_cast<paged_attention_inst&>(inst);
-        pa_inst.query_block_size = get_query_block_size(PagedAttentionStage::PREFILL);
-        pa_inst.use_micro_sdpa = use_micro_sdpa;
-    }
-
     size_t get_query_block_size(const PagedAttentionStage& stage) const {
         const auto default_block_size = 16;
 
@@ -292,6 +310,147 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         return lockable_ids;
     };
 
+    void prepare_internal_buffers(paged_attention_inst& instance, const PagedAttentionStage& stage) {
+        const auto& desc = instance.get_impl_params()->typed_desc<paged_attention>();
+        const bool has_scores_output = desc->has_scores_output();
+
+        if ((stage == PagedAttentionStage::UNKNOWN) ||
+            (stage == PagedAttentionStage::GENERATE && !has_scores_output))
+            return;
+
+        auto& stream = instance.get_network().get_stream();
+        const auto past_lens_mem = instance.past_lens_memory_ptr();
+        const auto subsequence_begins_mem = instance.subsequence_begins_memory_ptr();
+        auto intermediates_memories = instance.get_intermediates_memories();
+        mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, stream);
+        mem_lock<int32_t, mem_lock_type::read> subsequence_begins_mem_lock(subsequence_begins_mem, stream);
+        std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> subsequence_offsets_lock = nullptr;
+
+        if (has_scores_output) {
+            const size_t subsequence_offsets_idx = 4;
+
+            OPENVINO_ASSERT(intermediates_memories.size() > subsequence_offsets_idx,
+                            "[GPU] Unexpected number of intermediates buffers for Paged Attention for scores output calculation");
+
+            auto subsequence_offsets_mem = intermediates_memories[subsequence_offsets_idx];
+            subsequence_offsets_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(subsequence_offsets_mem, stream));
+        }
+
+        if (stage == PagedAttentionStage::GENERATE) {
+            // For the generate stage it's not necessary to configure any other intermediate
+            // buffers. Simply calculate the offsets and exit
+            size_t subsequence_offsets_acc = 0;
+            for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
+                const auto past_len = past_lens_mem_lock[i];
+                const auto seq_start = subsequence_begins_mem_lock[i];
+                const auto seq_end = subsequence_begins_mem_lock[i + 1];
+                const auto seq_length = seq_end - seq_start;
+
+                if (subsequence_offsets_lock) {
+                    subsequence_offsets_lock->operator[](i) = static_cast<int32_t>(subsequence_offsets_acc);
+                    subsequence_offsets_acc += seq_length + past_len;
+                }
+            }
+
+            return;
+        }
+
+        OPENVINO_ASSERT(intermediates_memories.size() >= 3, "Unexpected number of intermediates buffers for Paged Attention at prefill stage");
+
+        const auto blocks_indexes_start_idx = 0;
+        const auto blocks_indexes_end_idx = 1;
+        const auto blocked_gws_subseq_mapping_idx = 2;
+
+        auto blocks_indexes_start_mem = intermediates_memories[blocks_indexes_start_idx];
+        auto blocks_indexes_end_mem = intermediates_memories[blocks_indexes_end_idx];
+        auto blocked_gws_subseq_mapping_mem = intermediates_memories[blocked_gws_subseq_mapping_idx];
+
+        OPENVINO_ASSERT(subsequence_begins_mem->get_layout().data_type == data_types::i32);
+
+        mem_lock<int32_t, mem_lock_type::write> blocks_indexes_start_lock(blocks_indexes_start_mem, stream);
+        mem_lock<int32_t, mem_lock_type::write> blocks_indexes_end_lock(blocks_indexes_end_mem, stream);
+        mem_lock<int32_t, mem_lock_type::write> blocked_gws_subseq_mapping_mem_lock(blocked_gws_subseq_mapping_mem, stream);
+        std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> sequential_gws_subseq_mapping_lock = nullptr;
+        std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> micro_sdpa_block_starts_and_gws_mapping_lock = nullptr;
+
+        if (stage == PagedAttentionStage::MIXED) {
+            const size_t sequential_gws_subseq_mapping_idx = has_scores_output ? 8 : 6;
+
+            OPENVINO_ASSERT(intermediates_memories.size() > sequential_gws_subseq_mapping_idx,
+                            "[GPU] Unexpected number of intermediates buffers for Paged Attention for mixed stage");
+
+            auto sequential_gws_subseq_mapping_mem = intermediates_memories[sequential_gws_subseq_mapping_idx];
+            sequential_gws_subseq_mapping_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(sequential_gws_subseq_mapping_mem, stream));
+        }
+
+        if (stage == PagedAttentionStage::PREFILL && use_micro_sdpa) {
+            const auto memory_idx = intermediates_memories.size() - 1;
+
+            auto memory = intermediates_memories[memory_idx];
+            micro_sdpa_block_starts_and_gws_mapping_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(memory, stream));
+        }
+
+        size_t index = 0;
+        size_t micro_sdpa_index = 0;
+        size_t subsequence_offsets_acc = 0;
+        size_t query_block_size = get_query_block_size(stage);
+        const auto pa_block_size = static_cast<int>(paged_attention::block_size);
+        for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
+            const auto past_len = past_lens_mem_lock[i];
+            const auto seq_start = subsequence_begins_mem_lock[i];
+            const auto seq_end = subsequence_begins_mem_lock[i + 1];
+            const auto seq_length = seq_end - seq_start;
+
+            int32_t j = 0;
+            if (past_len != 0) {
+                auto block_start_pos = seq_start;
+                auto empty_slots = pa_block_size - (past_len % pa_block_size);
+                auto block_end_pos = seq_start + std::min(empty_slots, seq_length);
+
+                blocks_indexes_start_lock[index] = block_start_pos;
+                blocks_indexes_end_lock[index] = block_end_pos;
+                blocked_gws_subseq_mapping_mem_lock[index] = static_cast<int32_t>(i);
+
+                index++;
+
+                auto added_slots = block_end_pos - block_start_pos;
+                j += added_slots;
+            }
+
+            for (; j < seq_length; j += pa_block_size) {
+                auto block_start_pos = subsequence_begins_mem_lock[i] + j;
+                auto block_end_pos = std::min(block_start_pos + pa_block_size, seq_end);
+
+                blocks_indexes_start_lock[index] = block_start_pos;
+                blocks_indexes_end_lock[index] = block_end_pos;
+                blocked_gws_subseq_mapping_mem_lock[index] = static_cast<int32_t>(i);
+
+                index++;
+            }
+
+            if (micro_sdpa_block_starts_and_gws_mapping_lock) {
+                const auto block_size = static_cast<int>(query_block_size);
+                for (int32_t j = 0; j < seq_length; j += block_size) {
+                    auto block_start_pos = subsequence_begins_mem_lock[i] + j;
+
+                    micro_sdpa_block_starts_and_gws_mapping_lock->operator[](micro_sdpa_index++) = block_start_pos;
+                    micro_sdpa_block_starts_and_gws_mapping_lock->operator[](micro_sdpa_index++) = static_cast<int32_t>(i);
+                }
+            }
+
+            if (stage == PagedAttentionStage::MIXED) {
+                for (int32_t idx = seq_start; idx < seq_end; idx++) {
+                    sequential_gws_subseq_mapping_lock->operator[](idx) = static_cast<int32_t>(i);
+                }
+            }
+
+            if (subsequence_offsets_lock) {
+                subsequence_offsets_lock->operator[](i) = static_cast<int32_t>(subsequence_offsets_acc);
+                subsequence_offsets_acc += seq_length + past_len;
+            }
+        }
+    }
+
     void execute_stage(const std::vector<event::ptr>& events,
                        paged_attention_inst& instance,
                        std::vector<event::ptr>& all_events,
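Note: the prefill/mixed path above splits every subsequence into block-sized chunks: when past_len is not a multiple of the paged-attention block size, the first chunk only covers the free slots of the partially filled cache block, and the remainder is tiled with full blocks. The standalone sketch below (a hypothetical split_into_blocks helper; illustrative only, not the plugin code) reproduces the same [block_start, block_end) ranges that are written into blocks_indexes_start/end:

// Host-only sketch of the per-subsequence block partitioning used above.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

std::vector<std::pair<int32_t, int32_t>> split_into_blocks(int32_t seq_start,
                                                           int32_t seq_end,
                                                           int32_t past_len,
                                                           int32_t block_size) {
    std::vector<std::pair<int32_t, int32_t>> blocks;
    const int32_t seq_length = seq_end - seq_start;

    int32_t j = 0;
    if (past_len != 0) {
        // The last cached block of this subsequence may be partially filled;
        // the first new block only covers the remaining empty slots in it.
        const int32_t empty_slots = block_size - (past_len % block_size);
        const int32_t block_end = seq_start + std::min(empty_slots, seq_length);
        blocks.emplace_back(seq_start, block_end);
        j += block_end - seq_start;
    }

    // The rest of the subsequence is tiled with full blocks (the final one may be short).
    for (; j < seq_length; j += block_size) {
        const int32_t block_start = seq_start + j;
        blocks.emplace_back(block_start, std::min(block_start + block_size, seq_end));
    }
    return blocks;
}

int main() {
    // 20 new tokens at positions [0, 20), 5 tokens already cached, block size 16:
    // the first block tops up the partially filled cache block (11 free slots),
    // then one more block covers the remaining 9 tokens.
    for (auto [s, e] : split_into_blocks(0, 20, 5, 16))
        std::cout << "[" << s << ", " << e << ")\n";  // prints [0, 11) and [11, 20)
}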
@@ -385,6 +544,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         const auto stage = get_paged_attention_stage(*instance.get_impl_params());
         const auto is_mixed_mode = stage == PagedAttentionStage::MIXED;
 
+        prepare_internal_buffers(instance, stage);
+
         std::vector<event::ptr> res_events;
         std::vector<event::ptr> dep_events = events;
         if (has_rotated_blocks && !_kernels_data[Stage::KV_CACHE_ROTATE].kernels[0].skip_execution) {
src/plugins/intel_gpu/src/graph/include/primitive_inst.h (-1)
@@ -59,7 +59,6 @@ struct primitive_impl {
     virtual std::set<size_t> get_lockable_internal_buffers() const { return {}; }
     virtual void set_node_params(const program_node&) {}
     virtual const std::string& get_type_info() const = 0;
-    virtual void update_inst_params(primitive_inst& instance) const {}
     virtual void set_arguments(primitive_inst& instance) = 0;
    virtual void set_arguments(primitive_inst& instance, kernel_arguments_data& args) = 0;
     virtual event::ptr execute(const std::vector<event::ptr>& events, primitive_inst& instance) = 0;

src/plugins/intel_gpu/src/graph/paged_attention.cpp (-166)
@@ -13,33 +13,6 @@ GPU_DEFINE_PRIMITIVE_TYPE_ID(paged_attention)
 
 constexpr size_t paged_attention::block_size;
 
-PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param) {
-    const auto& query_shape = impl_param.get_input_layout(0).get_partial_shape();
-    const auto& past_lens_shape = impl_param.get_input_layout(5).get_partial_shape();
-
-    if (query_shape.is_static() && past_lens_shape.is_static()) {
-        if (query_shape[0].get_length() == past_lens_shape[0].get_length()) {
-            return PagedAttentionStage::GENERATE;
-        }
-
-        const auto past_lens_idx = 5;
-        const auto& memory_deps = impl_param.memory_deps;
-        const auto past_lens_mem = memory_deps.at(past_lens_idx);
-        mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, *impl_param.strm);
-
-        const auto past_lens_size = past_lens_mem_lock.size();
-        for (size_t i = 0; i < past_lens_size; i++) {
-            if (past_lens_mem_lock[i] != 0) {
-                return PagedAttentionStage::MIXED;
-            }
-        }
-
-        return PagedAttentionStage::PREFILL;
-    }
-
-    return PagedAttentionStage::UNKNOWN;
-}
-
 layout paged_attention_inst::calc_output_layout(const paged_attention_node& /*node*/, kernel_impl_params const& impl_param) {
     auto out_layout = impl_param.get_input_layout(0);
 
@@ -107,146 +80,7 @@ std::string paged_attention_inst::to_string(const paged_attention_node& node) {
 }
 
 void paged_attention_inst::on_execute() {
-    const auto& desc = _impl_params->typed_desc<paged_attention>();
-    const bool has_scores_output = desc->has_scores_output();
-    const auto stage = get_paged_attention_stage(*_impl_params);
-
-    if ((stage == PagedAttentionStage::UNKNOWN) ||
-        (stage == PagedAttentionStage::GENERATE && !has_scores_output))
-        return;
-
-    OPENVINO_ASSERT(_impl != nullptr, "[GPU] PagedAttention impl shouldn't be nullptr");
-    _impl->update_inst_params(*this);
-
-    auto& stream = get_network().get_stream();
-    const auto past_lens_mem = past_lens_memory_ptr();
-    const auto subsequence_begins_mem = subsequence_begins_memory_ptr();
-    mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, stream);
-    mem_lock<int32_t, mem_lock_type::read> subsequence_begins_mem_lock(subsequence_begins_mem, stream);
-    std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> subsequence_offsets_lock = nullptr;
-
-    if (has_scores_output) {
-        const size_t subsequence_offsets_idx = 4;
-
-        OPENVINO_ASSERT(_intermediates_memory.size() > subsequence_offsets_idx,
-                        "[GPU] Unexpected number of intermediates buffers for Paged Attention for scores output calculation");
-
-        auto subsequence_offsets_mem = _intermediates_memory[subsequence_offsets_idx];
-        subsequence_offsets_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(subsequence_offsets_mem, stream));
-    }
-
-    if (stage == PagedAttentionStage::GENERATE) {
-        // For the generate stage it's not necessary to configure any other intermediate
-        // buffers. Simply calculate the offsets and exit
-        size_t subsequence_offsets_acc = 0;
-        for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
-            const auto past_len = past_lens_mem_lock[i];
-            const auto seq_start = subsequence_begins_mem_lock[i];
-            const auto seq_end = subsequence_begins_mem_lock[i + 1];
-            const auto seq_length = seq_end - seq_start;
-
-            if (subsequence_offsets_lock) {
-                subsequence_offsets_lock->operator[](i) = static_cast<int32_t>(subsequence_offsets_acc);
-                subsequence_offsets_acc += seq_length + past_len;
-            }
-        }
-
-        return;
-    }
-
-    OPENVINO_ASSERT(_intermediates_memory.size() >= 3, "Unexpected number of intermediates buffers for Paged Attention at prefill stage");
-
-    const auto blocks_indexes_start_idx = 0;
-    const auto blocks_indexes_end_idx = 1;
-    const auto blocked_gws_subseq_mapping_idx = 2;
-
-    auto blocks_indexes_start_mem = _intermediates_memory[blocks_indexes_start_idx];
-    auto blocks_indexes_end_mem = _intermediates_memory[blocks_indexes_end_idx];
-    auto blocked_gws_subseq_mapping_mem = _intermediates_memory[blocked_gws_subseq_mapping_idx];
 
-    OPENVINO_ASSERT(subsequence_begins_mem->get_layout().data_type == data_types::i32);
-
-    mem_lock<int32_t, mem_lock_type::write> blocks_indexes_start_lock(blocks_indexes_start_mem, stream);
-    mem_lock<int32_t, mem_lock_type::write> blocks_indexes_end_lock(blocks_indexes_end_mem, stream);
-    mem_lock<int32_t, mem_lock_type::write> blocked_gws_subseq_mapping_mem_lock(blocked_gws_subseq_mapping_mem, stream);
-    std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> sequential_gws_subseq_mapping_lock = nullptr;
-    std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> micro_sdpa_block_starts_and_gws_mapping_lock = nullptr;
-
-    if (stage == PagedAttentionStage::MIXED) {
-        const size_t sequential_gws_subseq_mapping_idx = has_scores_output ? 8 : 6;
-
-        OPENVINO_ASSERT(_intermediates_memory.size() > sequential_gws_subseq_mapping_idx,
-                        "[GPU] Unexpected number of intermediates buffers for Paged Attention for mixed stage");
-
-        auto sequential_gws_subseq_mapping_mem = _intermediates_memory[sequential_gws_subseq_mapping_idx];
-        sequential_gws_subseq_mapping_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(sequential_gws_subseq_mapping_mem, stream));
-    }
-
-    if (stage == PagedAttentionStage::PREFILL && use_micro_sdpa) {
-        const auto sequential_gws_subseq_mapping_idx = _intermediates_memory.size() - 1;
-
-        auto micro_sdpa_block_starts_and_gws_mapping_mem = _intermediates_memory[sequential_gws_subseq_mapping_idx];
-        micro_sdpa_block_starts_and_gws_mapping_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(micro_sdpa_block_starts_and_gws_mapping_mem, stream));
-    }
-
-    size_t index = 0;
-    size_t micro_sdpa_index = 0;
-    size_t subsequence_offsets_acc = 0;
-    const auto pa_block_size = static_cast<int>(paged_attention::block_size);
-    for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
-        const auto past_len = past_lens_mem_lock[i];
-        const auto seq_start = subsequence_begins_mem_lock[i];
-        const auto seq_end = subsequence_begins_mem_lock[i + 1];
-        const auto seq_length = seq_end - seq_start;
-
-        int32_t j = 0;
-        if (past_len != 0) {
-            auto block_start_pos = seq_start;
-            auto empty_slots = pa_block_size - (past_len % pa_block_size);
-            auto block_end_pos = seq_start + std::min(empty_slots, seq_length);
-
-            blocks_indexes_start_lock[index] = block_start_pos;
-            blocks_indexes_end_lock[index] = block_end_pos;
-            blocked_gws_subseq_mapping_mem_lock[index] = static_cast<int32_t>(i);
-
-            index++;
-
-            auto added_slots = block_end_pos - block_start_pos;
-            j += added_slots;
-        }
-
-        for (; j < seq_length; j += pa_block_size) {
-            auto block_start_pos = subsequence_begins_mem_lock[i] + j;
-            auto block_end_pos = std::min(block_start_pos + pa_block_size, seq_end);
-
-            blocks_indexes_start_lock[index] = block_start_pos;
-            blocks_indexes_end_lock[index] = block_end_pos;
-            blocked_gws_subseq_mapping_mem_lock[index] = static_cast<int32_t>(i);
-
-            index++;
-        }
-
-        if (micro_sdpa_block_starts_and_gws_mapping_lock) {
-            const auto block_size = static_cast<int>(query_block_size);
-            for (int32_t j = 0; j < seq_length; j += block_size) {
-                auto block_start_pos = subsequence_begins_mem_lock[i] + j;
-
-                micro_sdpa_block_starts_and_gws_mapping_lock->operator[](micro_sdpa_index++) = block_start_pos;
-                micro_sdpa_block_starts_and_gws_mapping_lock->operator[](micro_sdpa_index++) = static_cast<int32_t>(i);
-            }
-        }
-
-        if (stage == PagedAttentionStage::MIXED) {
-            for (int32_t idx = seq_start; idx < seq_end; idx++) {
-                sequential_gws_subseq_mapping_lock->operator[](idx) = static_cast<int32_t>(i);
-            }
-        }
-
-        if (subsequence_offsets_lock) {
-            subsequence_offsets_lock->operator[](i) = static_cast<int32_t>(subsequence_offsets_acc);
-            subsequence_offsets_acc += seq_length + past_len;
-        }
-    }
 }
 
 paged_attention_inst::typed_primitive_inst(network& network, const paged_attention_node& node)
