Commit 92e39e2

[GPU] Paged attention

1 parent a8cc74e commit 92e39e2

23 files changed, +1583 -2 lines

src/plugins/intel_gpu/include/intel_gpu/plugin/program_builder.hpp (+1)

@@ -180,6 +180,7 @@ class ProgramBuilder final {
     void CreateSingleLayerPrimitive(cldnn::topology& topology, const std::shared_ptr<ov::Node>& op);
 };
 
+void CreatePagedAttention(ProgramBuilder& p, const std::shared_ptr<ov::Node>& op);
 void CreateCustomOp(ProgramBuilder& p, const std::shared_ptr<ov::Node>& node, CustomLayerPtr customLayer);
 void CreateUnaryEltwiseOp(ProgramBuilder& p, const std::shared_ptr<ov::Node>& node,
                           cldnn::activation_func func, cldnn::activation_additional_params params);
src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp (new file, +53)

@@ -0,0 +1,53 @@
+// Copyright (C) 2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+#include "primitive.hpp"
+#include "intel_gpu/graph/program.hpp"
+
+#include <vector>
+
+namespace cldnn {
+
+struct paged_attention : public primitive_base<paged_attention> {
+    CLDNN_DECLARE_PRIMITIVE(paged_attention)
+
+    paged_attention() : primitive_base("", {}) {}
+
+    paged_attention(const primitive_id& id,
+                    const std::vector<input_info>& inputs,
+                    const padding& output_padding = padding())
+        : primitive_base(id, inputs, {output_padding}) {
+        OPENVINO_ASSERT(inputs.size() == 13, "[GPU] Unexpected inputs number for PagedAttention primitive: ", inputs.size());
+    }
+
+    bool operator==(const primitive& rhs) const override {
+        return compare_common_params(rhs);
+    }
+
+    void save(BinaryOutputBuffer& ob) const override {
+        primitive_base<paged_attention>::save(ob);
+        ob << head_size;
+        ob << heads_num;
+        ob << kv_heads_num;
+        ob << block_size;
+        ob << x_block_size;
+    }
+
+    void load(BinaryInputBuffer& ib) override {
+        primitive_base<paged_attention>::load(ib);
+        ib >> head_size;
+        ib >> heads_num;
+        ib >> kv_heads_num;
+        ib >> block_size;
+        ib >> x_block_size;
+    }
+
+    size_t head_size;
+    size_t heads_num;
+    size_t kv_heads_num;
+    size_t block_size;
+    size_t x_block_size;
+};
+} // namespace cldnn
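For context, the plugin-side factory declared above in program_builder.hpp (CreatePagedAttention) would construct this primitive along the following lines. This is a minimal sketch, not code from the commit: the `layer_name`, `inputs`, `p`, and `op` variables are assumed plugin context, and all numeric values are illustrative.

    // Hypothetical construction of the primitive from 13 collected inputs;
    // the attention geometry fields are public and filled in afterwards.
    cldnn::paged_attention prim(layer_name, inputs);  // asserts inputs.size() == 13
    prim.head_size    = 128;  // per-head width (example value)
    prim.heads_num    = 32;   // number of query heads (example value)
    prim.kv_heads_num = 32;   // KV heads; smaller than heads_num under GQA
    prim.block_size   = 16;   // tokens per KV-cache block (example value)
    prim.x_block_size = 8;    // inner tiling of the key cache (example value)
    p.add_primitive(*op, prim);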

src/plugins/intel_gpu/src/graph/impls/ocl/multi_stage_primitive.hpp (+1)

@@ -53,6 +53,7 @@ struct multi_stage_primitive : public typed_primitive_impl<PType> {
         }
         this->_kernel_name = other._kernel_name;
+        this->can_reuse_memory = other.can_reuse_memory;
         this->_is_dynamic = other._is_dynamic;
     }

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp (new file, +319)

@@ -0,0 +1,319 @@
+// Copyright (C) 2023-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "intel_gpu/plugin/multi_tensor_variable_state.hpp"
+#include "intel_gpu/plugin/variable_state.hpp"
+#include "intel_gpu/runtime/debug_configuration.hpp"
+#include "intel_gpu/runtime/memory.hpp"
+#include "multi_stage_primitive.hpp"
+
+#include "paged_attention_inst.h"
+#include "paged_attention/paged_attention_kernel_selector.hpp"
+#include "paged_attention/kv_cache_update_kernel_ref.hpp"
+#include "paged_attention/sdpa_kernel_ref.hpp"
+
+namespace cldnn {
+namespace ocl {
+
+struct paged_attention_impl : multi_stage_primitive<paged_attention> {
+    using parent = multi_stage_primitive<paged_attention>;
+    using parent::parent;
+
+    using kv_cache_update_kernel_selector_t = kernel_selector::kv_cache_update_kernel_selector;
+    using kv_cache_update_kernel_params_t = kernel_selector::kv_cache_update_params;
+
+    using sdpa_kernel_selector_t = kernel_selector::sdpa_kernel_selector;
+    using sdpa_kernel_params_t = kernel_selector::sdpa_params;
+
+    DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::paged_attention_impl)
+
+    std::unique_ptr<primitive_impl> clone() const override {
+        return make_unique<paged_attention_impl>(*this);
+    }
+
+    paged_attention_impl() = default;
+
+    paged_attention_impl(const std::vector<kernel_selector::kernel_data>& kd) : parent(kd) {
+        this->can_reuse_memory = true;
+    }
+
+    void set_arguments_impl(paged_attention_inst& instance) override {}
+    kernel_arguments_data get_arguments(const paged_attention_inst& instance, size_t stage) const override { return kernel_arguments_data(); }
+
+    enum Stage {
+        KV_CACHE_UPDATE,
+        SDPA
+    };
+
+    void load(BinaryInputBuffer& ib) override {
+        parent::load(ib);
+        if (is_dynamic()) {
+            auto& kernel_selector = kv_cache_update_kernel_selector_t::Instance();
+            auto kernel_impl = kernel_selector.GetImplementation(_kernels_data[Stage::KV_CACHE_UPDATE].kernelName);
+            kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[Stage::KV_CACHE_UPDATE]);
+
+            auto& sdpa_kernel_selector = sdpa_kernel_selector_t::Instance();
+            auto sdpa_kernel_impl = sdpa_kernel_selector.GetImplementation(_kernels_data[Stage::SDPA].kernelName);
+            sdpa_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[Stage::SDPA]);
+        }
+    }
+
+    kernel_arguments_data get_arguments(const paged_attention_inst& instance, size_t stage, size_t kernel_idx) const {
+        kernel_arguments_data args;
+        if (stage == Stage::KV_CACHE_UPDATE || (stage == Stage::SDPA && kernel_idx == 0))
+            args.shape_info = instance.shape_info_memory_ptr();
+
+        if (stage == Stage::KV_CACHE_UPDATE) {
+            args.inputs = { instance.input_memory_ptr(1), /* key */
+                            instance.input_memory_ptr(2), /* value */
+                            instance.input_memory_ptr(6)  /* slot_mapping */ };
+            args.outputs = { instance.input_memory_ptr(3), /* key_cache */
+                             instance.input_memory_ptr(4)  /* value_cache */ };
+        } else if (stage == Stage::SDPA) {
+            if (kernel_idx == 0) {
+                args.inputs = { instance.input_memory_ptr(0),  /* query */
+                                instance.input_memory_ptr(3),  /* key_cache */
+                                instance.input_memory_ptr(4),  /* value_cache */
+                                instance.input_memory_ptr(7),  /* max_context_len */
+                                instance.input_memory_ptr(8),  /* context_lens */
+                                instance.input_memory_ptr(9),  /* block_tables */
+                                instance.input_memory_ptr(10)  /* scale */ };
+            } else {
+                args.inputs = { instance.input_memory_ptr(8) /* context_lens */ };
+            }
+            args.outputs = { instance.output_memory_ptr(0) };
+        }
+
+        return args;
+    }
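Taken together with the kernel-parameter helpers further down, the wiring above implies a fixed port layout for the 13 inputs asserted in the primitive's constructor. As a reading aid (not code from the commit), the indices referenced in this file are summarized below, with names taken from its inline comments; ports 11 and 12 are never touched in this diff, so they are left unnamed.

    // Input port layout implied by get_arguments() and the helpers below.
    enum PagedAttentionInputIdx : size_t {
        QUERY = 0,
        KEY = 1,
        VALUE = 2,
        KEY_CACHE = 3,        // also bound as an "output" of the KV-cache update stage
        VALUE_CACHE = 4,      // likewise updated in place
        IS_PROMPT_STAGE = 5,  // read on the host in get_sdpa_kernel_params()
        SLOT_MAPPING = 6,
        MAX_CONTEXT_LEN = 7,
        CONTEXT_LENS = 8,
        BLOCK_TABLES = 9,
        SCALE = 10,
        // ports 11-12 complete the 13 inputs but are not referenced in this file
    };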
+
+    void execute_stage(const std::vector<event::ptr>& events, paged_attention_inst& instance, std::vector<event::ptr>& all_events, size_t stage) {
+        stream& stream = instance.get_network().get_stream();
+        std::vector<event::ptr> tmp_events(events);
+        size_t kernel_offset = 0;
+        for (size_t s = 0; s < stage; s++) {
+            kernel_offset += _kernels_data[s].kernels.size();
+        }
+        for (size_t kd_idx = 0; kd_idx < _kernels_data[stage].kernels.size(); ++kd_idx) {
+            auto time0 = std::chrono::high_resolution_clock::now();
+            if (_kernels_data[stage].kernels[kd_idx].skip_execution)
+                continue;
+
+            size_t idx_final = kernel_offset + kd_idx;
+            // If any of the primitive's users is a CPU implementation or a network output,
+            // request a completion event (the returned event won't be nullptr)
+            bool needs_completion_event = instance.needs_completion_event();
+
+            auto& params = _kernels_data[stage].kernels[kd_idx].params;
+
+            auto args = get_arguments(instance, stage, kd_idx);
+            args.scalars = &params.scalars;
+
+            for (const auto& m : instance.get_intermediates_memories()) {
+                args.intermediates.push_back(m);
+            }
+
+            auto time1 = std::chrono::high_resolution_clock::now();
+            stream.set_arguments(*_kernels[idx_final], _kernels_data[stage].kernels[kd_idx].params, args);
+            auto time2 = std::chrono::high_resolution_clock::now();
+
+            const auto& gws = params.workGroups.global;
+            const auto& lws = params.workGroups.local;
+
+            GPU_DEBUG_TRACE_DETAIL << "Enqueue stage " << stage << " kernel " << idx_final << ": gws=[" << gws[0] << ", " << gws[1] << ", " << gws[2] << "] "
+                                   << "lws=[" << lws[0] << ", " << lws[1] << ", " << lws[2] << "]"
+                                   << (needs_completion_event ? " has_completion_event=true" : "") << std::endl;
+
+            auto ev = stream.enqueue_kernel(*_kernels[idx_final], params, args, tmp_events, needs_completion_event);
+            auto time3 = std::chrono::high_resolution_clock::now();
+            if (_kernels_data[stage].needs_sub_kernels_sync) {
+                tmp_events = {ev};
+            }
+
+            auto time_res0 = std::chrono::duration_cast<std::chrono::microseconds>(time1 - time0).count();
+            auto time_res1 = std::chrono::duration_cast<std::chrono::microseconds>(time2 - time1).count();
+            auto time_res2 = std::chrono::duration_cast<std::chrono::microseconds>(time3 - time2).count();
+            GPU_DEBUG_TRACE_DETAIL << "Time execute_stage inside = " << time_res0 << " " << time_res1 << " " << time_res2 << "\n";
+
+            all_events.push_back(ev);
+        }
+    }
+
+    event::ptr execute_impl(const std::vector<event::ptr>& events, paged_attention_inst& instance) override {
+        std::vector<event::ptr> res_events;
+        execute_stage(events, instance, res_events, Stage::KV_CACHE_UPDATE);
+
+        std::vector<event::ptr> dep_events(res_events.begin(), res_events.end());
+        execute_stage(dep_events, instance, res_events, Stage::SDPA);
+
+        return aggregate_events(res_events, instance.get_network().get_stream(), res_events.size() > 1);
+    }
+
+    static kernel_selector::sdpa_configuration get_sdpa_configuration(const kernel_impl_params& impl_param) {
+        kernel_selector::sdpa_configuration config;
+
+        const auto query_layout = impl_param.get_input_layout(0);
+        const auto key_cache_layout = impl_param.get_input_layout(3);
+        const auto value_cache_layout = impl_param.get_input_layout(4);
+
+        const auto desc = impl_param.typed_desc<paged_attention>();
+        config.head_size = desc->head_size;
+        config.heads_num = desc->heads_num;
+        config.kv_heads_num = desc->kv_heads_num;
+        config.block_size = desc->block_size;
+        config.x_block_size = desc->x_block_size;
+        config.max_context_len = 1;
+
+        const size_t simd_size = 16;
+        OPENVINO_ASSERT(config.head_size % simd_size == 0, "[GPU] Head size is expected to be divisible by 16");
+
+        return config;
+    }
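To make the configuration concrete: for a hypothetical model with 32 query heads of width 128 and 8 KV heads (grouped-query attention), the fields would end up as sketched below. All values are illustrative assumptions, not from the commit; note that head_size = 128 satisfies the assert above since 128 = 8 × 16.

    // Illustrative values only (assumed model geometry):
    // config.head_size    = 128;  // 128 % 16 == 0, passes the SIMD-divisibility assert
    // config.heads_num    = 32;   // query heads
    // config.kv_heads_num = 8;    // 4 query heads share each KV head under GQA
    // config.block_size   = 16;   // tokens per KV-cache block
    // config.x_block_size = 8;    // inner tiling of the key cache
    // config.max_context_len = 1; // placeholder; refined in get_sdpa_kernel_params()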
+
+    static kv_cache_update_kernel_params_t get_kv_cache_update_kernel_params(const kernel_impl_params& impl_param, bool is_dynamic = false) {
+        kv_cache_update_kernel_params_t params;
+        set_params(impl_param, params);
+
+        auto query = impl_param.get_input_layout(0);
+        auto key = impl_param.get_input_layout(1);
+        auto value = impl_param.get_input_layout(2);
+        auto key_cache = impl_param.get_input_layout(3);
+        auto value_cache = impl_param.get_input_layout(4);
+        auto slot_mapping = impl_param.get_input_layout(6);
+
+        params.is_shape_agnostic = is_dynamic;
+        params.stage_id = 0;
+        params.inputs.resize(3);
+        params.outputs.resize(2);
+        params.inputs[0] = convert_data_tensor(key);
+        params.inputs[1] = convert_data_tensor(value);
+        params.inputs[2] = convert_data_tensor(slot_mapping);
+        params.outputs[0] = convert_data_tensor(key_cache);
+        params.outputs[1] = convert_data_tensor(value_cache);
+        params.layerID = impl_param.desc->id;
+
+        params.configuration = get_sdpa_configuration(impl_param);
+
+        const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset;
+        std::map<size_t, size_t> in_tensor_to_offset_map = {
+            {0, in_offsets_map.at(1)},
+            {1, in_offsets_map.at(2)},
+            {2, in_offsets_map.at(6)},
+        };
+        std::map<size_t, size_t> out_tensor_to_offset_map = {
+            {0, in_offsets_map.at(3)},
+            {1, in_offsets_map.at(4)},
+        };
+
+        params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map);
+
+        return params;
+    }
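For orientation on the data flow: the update stage scatters each new token's key/value into the cache at the flat slot given by slot_mapping, while the SDPA stage later resolves a token's position through the per-sequence block_tables. A minimal sketch of that addressing, assuming the vLLM-style paged layout that the block_size/x_block_size parameters suggest (names, types, and shapes here are illustrative, not from the commit):

    #include <cstddef>
    #include <cstdint>

    // Resolve the physical cache slot of a token, given one sequence's row of
    // block_tables. Logical block index = position / block_size; the table maps
    // it to a physical block, and the remainder is the offset inside the block.
    inline size_t physical_slot(const int32_t* block_table,  // one row of block_tables
                                size_t token_pos,            // position within the context
                                size_t block_size) {
        const size_t logical_block = token_pos / block_size;
        const size_t block_offset  = token_pos % block_size;
        return static_cast<size_t>(block_table[logical_block]) * block_size + block_offset;
    }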
+
+    static sdpa_kernel_params_t get_sdpa_kernel_params(const kernel_impl_params& impl_param, bool is_dynamic = false) {
+        auto params = get_default_params<kernel_selector::sdpa_params>(impl_param, is_dynamic);
+
+        const auto inputs_count = 7;
+        const auto query_layout = impl_param.get_input_layout(0);
+        const auto key_cache_layout = impl_param.get_input_layout(3);
+        const auto value_cache_layout = impl_param.get_input_layout(4);
+        const auto max_context_len_layout = impl_param.get_input_layout(7);
+        const auto context_lens_layout = impl_param.get_input_layout(8);
+        const auto block_tables_layout = impl_param.get_input_layout(9);
+        const auto scale_layout = impl_param.get_input_layout(10);
+
+        params.inputs.resize(inputs_count);
+        params.inputs[0] = convert_data_tensor(query_layout);
+        params.inputs[1] = convert_data_tensor(key_cache_layout);
+        params.inputs[2] = convert_data_tensor(value_cache_layout);
+        params.inputs[3] = convert_data_tensor(max_context_len_layout);
+        params.inputs[4] = convert_data_tensor(context_lens_layout);
+        params.inputs[5] = convert_data_tensor(block_tables_layout);
+        params.inputs[6] = convert_data_tensor(scale_layout);
+
+        params.configuration = get_sdpa_configuration(impl_param);
+        if (!is_dynamic) {
+            auto& constant_mem = impl_param.memory_deps;
+
+            const auto is_prompt_stage_mem = constant_mem.at(5);
+            mem_lock<uint8_t, mem_lock_type::read> is_prompt_stage_mem_lock(is_prompt_stage_mem, impl_param.get_stream());
+            bool is_prompt_stage = is_prompt_stage_mem_lock[0];
+
+            if (is_prompt_stage) {
+                // Use the number of KV-cache slots as the maximum context length for the first iteration
+                auto slot_mapping = impl_param.get_input_layout(6);
+                params.configuration.max_context_len = slot_mapping.get_shape()[1];
+            } else {
+                const auto max_context_len_mem = constant_mem.at(7);
+                mem_lock<int32_t, mem_lock_type::read> max_context_len_mem_lock(max_context_len_mem, impl_param.get_stream());
+                params.configuration.max_context_len = max_context_len_mem_lock[0];
+            }
+        }
+
+        const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset;
+        const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset;
+        std::map<size_t, size_t> in_tensor_to_offset_map = {
+            {0, in_offsets_map.at(0)},
+            {1, in_offsets_map.at(3)},
+            {2, in_offsets_map.at(4)},
+            {3, in_offsets_map.at(7)},
+            {4, in_offsets_map.at(8)},
+            {5, in_offsets_map.at(9)},
+            {6, in_offsets_map.at(10)},
+        };
+        std::map<size_t, size_t> out_tensor_to_offset_map = {
+            {0, out_offsets_map.at(0)},
+        };
+
+        params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map);
+
+        return params;
+    }
+
+    static std::unique_ptr<primitive_impl> create(const typed_program_node<paged_attention>& arg, const kernel_impl_params& impl_param) {
+        std::vector<kernel_selector::kernel_data> kernels_data;
+        auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, impl_param.is_dynamic());
+        auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance();
+        kernels_data.push_back(kv_cache_update_kernel_selector.get_best_kernel(kv_cache_update_kernel_params));
+
+        auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, impl_param.is_dynamic());
+        auto& sdpa_kernel_selector = sdpa_kernel_selector_t::Instance();
+        kernels_data.push_back(sdpa_kernel_selector.get_best_kernel(sdpa_kernel_params));
+
+        return cldnn::make_unique<paged_attention_impl>(kernels_data);
+    }
+
+    void update_dispatch_data(const kernel_impl_params& impl_param) override {
+        auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, impl_param.is_dynamic());
+        (_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func)(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]);
+
+        auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, impl_param.is_dynamic());
+        (_kernels_data[Stage::SDPA].update_dispatch_data_func)(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
+    }
+};
+
+namespace detail {
+
+attach_paged_attention_impl::attach_paged_attention_impl() {
+    auto types = { data_types::f16, data_types::f32 };
+    auto formats = { format::bfyx };
+    implementation_map<paged_attention>::add(impl_types::ocl,
+                                             shape_types::dynamic_shape,
+                                             paged_attention_impl::create,
+                                             types,
+                                             formats);
+
+    implementation_map<paged_attention>::add(impl_types::ocl,
+                                             shape_types::static_shape,
+                                             paged_attention_impl::create,
+                                             types,
+                                             formats);
+}
+
+} // namespace detail
+} // namespace ocl
+} // namespace cldnn
+
+BIND_BINARY_BUFFER_WITH_TYPE(cldnn::ocl::paged_attention_impl)
+BIND_BINARY_BUFFER_WITH_TYPE(cldnn::paged_attention)

src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp (+1)

@@ -41,6 +41,7 @@ void register_implementations() {
     REGISTER_OCL(grid_sample);
     REGISTER_OCL(group_normalization);
    REGISTER_OCL(kv_cache);
+    REGISTER_OCL(paged_attention);
    REGISTER_OCL(lrn);
    REGISTER_OCL(lstm_elt);
    REGISTER_OCL(multiclass_nms);
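For orientation: REGISTER_OCL(paged_attention) presumably expands to a call that constructs detail::attach_paged_attention_impl, whose constructor (defined at the bottom of the new paged_attention.cpp above) performs the implementation_map registration. A rough sketch of the assumed mechanism, not verbatim from the repo:

    // Assumed macro form in register.cpp:
    // #define REGISTER_OCL(prim) attach_##prim##_impl()
    //
    // So the added line constructs a temporary detail::attach_paged_attention_impl;
    // its constructor calls implementation_map<paged_attention>::add(...) for both
    // dynamic_shape and static_shape, making paged_attention_impl::create
    // discoverable for f16/f32 bfyx inputs.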
