Commit 15e3064
Message: WIP
Parent: 757caa1

22 files changed: +830, -15 lines

src/plugins/intel_gpu/include/intel_gpu/plugin/program_builder.hpp (+1)

@@ -172,6 +172,7 @@ class ProgramBuilder final {
     void CreateSingleLayerPrimitive(cldnn::topology& topology, const std::shared_ptr<ov::Node>& op);
 };
 
+void CreatePagedAttention(ProgramBuilder& p, const std::shared_ptr<ov::Node>& op);
 void CreateCustomOp(ProgramBuilder& p, const std::shared_ptr<ov::Node>& node, CustomLayerPtr customLayer);
 void CreateUnaryEltwiseOp(ProgramBuilder& p, const std::shared_ptr<ov::Node>& node,
                           cldnn::activation_func func, cldnn::activation_additional_params params);
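
The commit only declares CreatePagedAttention here; its definition lives in one of the changed files not shown in this capture. As a sketch of how such a factory usually looks in this plugin (the body below is an assumption that it follows the standard CreateXxx pattern, not the commit's actual code):

    // Hypothetical sketch, not the commit's definition: maps the ov::Node
    // onto the new cldnn::paged_attention primitive.
    void CreatePagedAttention(ProgramBuilder& p, const std::shared_ptr<ov::Node>& op) {
        auto inputs = p.GetInputInfo(op);  // the primitive asserts exactly 13 inputs
        const cldnn::paged_attention prim(layer_type_name_ID(op), inputs);
        p.add_primitive(*op, prim);
    }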
New file (+36)

@@ -0,0 +1,36 @@
+// Copyright (C) 2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+#include "primitive.hpp"
+
+#include <vector>
+
+namespace cldnn {
+
+struct paged_attention : public primitive_base<paged_attention> {
+    CLDNN_DECLARE_PRIMITIVE(paged_attention)
+
+    paged_attention() : primitive_base("", {}) {}
+
+    paged_attention(const primitive_id& id,
+                    const std::vector<input_info>& inputs,
+                    const padding& output_padding = padding())
+        : primitive_base(id, inputs, {output_padding}) {
+        OPENVINO_ASSERT(inputs.size() == 13, "[GPU] Unexpected inputs number for PagedAttention primitive: ", inputs.size());
+    }
+
+    bool operator==(const primitive& rhs) const override {
+        return compare_common_params(rhs);
+    }
+
+    void save(BinaryOutputBuffer& ob) const override {
+        primitive_base<paged_attention>::save(ob);
+    }
+
+    void load(BinaryInputBuffer& ib) override {
+        primitive_base<paged_attention>::load(ib);
+    }
+};
+} // namespace cldnn
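
For orientation, a minimal usage sketch under stated assumptions: the port roles are inferred from the OCL implementation later in this commit (ports 1, 2 and 6 feed the KV-cache update stage; ports 3 and 4 are the caches it writes), the remaining port names are placeholders, and the constructor validates only the input count.

    // Hypothetical usage sketch; names for ports 0, 5 and 7..12 are assumptions.
    std::vector<cldnn::input_info> inputs = {
        cldnn::input_info("query"),         // 0
        cldnn::input_info("key"),           // 1: read by the kv_cache_update stage
        cldnn::input_info("value"),         // 2: read by the kv_cache_update stage
        cldnn::input_info("key_cache"),     // 3: written in place
        cldnn::input_info("value_cache"),   // 4: written in place
        cldnn::input_info("input_5"),
        cldnn::input_info("slot_mapping"),  // 6: read by the kv_cache_update stage
        cldnn::input_info("input_7"),  cldnn::input_info("input_8"),
        cldnn::input_info("input_9"),  cldnn::input_info("input_10"),
        cldnn::input_info("input_11"), cldnn::input_info("input_12"),
    };
    topology.add(cldnn::paged_attention("paged_attention", inputs));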

src/plugins/intel_gpu/src/graph/graph_optimizer/compile_graph.cpp (+10, -12)

@@ -34,10 +34,7 @@ void compile_graph::run(program& p) {
         }
     }
 
-    auto task_executor = p.get_task_executor();
     auto& proc_order = p.get_processing_order();
-    std::vector<ov::threading::Task> tasks;
-    std::exception_ptr exception;
     for (size_t idx = 0; idx < proc_order.size(); idx++) {
         auto& node = *(std::next(proc_order.begin(), idx));
         const bool use_shape_agnostic_impl = !p.get_config().get_property(ov::intel_gpu::use_only_static_kernels_for_dynamic_shape);
@@ -79,29 +76,30 @@
             can_select_impl = true;
 
         if (can_select_impl) {
-            tasks.push_back([node, &exception, change_initial_impl, original_impl_type] {
             try {
+                std::exception_ptr curr_excp;
                 node->selected_impl = node->type()->choose_impl(*node);
                 if (change_initial_impl) {
                     GPU_DEBUG_TRACE_DETAIL << node->id() << ": use " << node->get_preferred_impl_type()
                                            << " as initial impl instead of " << original_impl_type << std::endl;
                     node->set_preferred_impl_type(original_impl_type);
                 }
             } catch(...) {
-                exception = std::current_exception();
+                try {
+                    std::exception_ptr curr_excp;
+                    if (curr_excp = std::current_exception())
+                    {
+                        std::rethrow_exception(curr_excp);
+                    }
+                } catch (const std::exception& e) {
+                    std::cerr << "Can't compile " << node->id() << ", error " << e.what() << "\n";
+                }
             }
-            });
         } else {
             if (change_initial_impl) {
                 node->set_preferred_impl_type(original_impl_type);
            }
         }
     }
 
-    task_executor->run_and_wait(tasks);
-    tasks.clear();
-
-    if (exception) {
-        std::rethrow_exception(exception);
-    }
 }
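
The new catch block relies on the standard capture-and-rethrow idiom: std::current_exception() is only meaningful inside a handler, and rethrowing it in a nested try lets the code extract what() from an otherwise opaque exception. A self-contained sketch of the same pattern:

    #include <exception>
    #include <iostream>
    #include <string>

    // Must be called from inside a catch handler, where
    // std::current_exception() returns the in-flight exception.
    void log_current_exception(const std::string& node_id) {
        try {
            if (auto eptr = std::current_exception()) {
                std::rethrow_exception(eptr);
            }
        } catch (const std::exception& e) {
            std::cerr << "Can't compile " << node_id << ", error " << e.what() << "\n";
        }
    }

Note that the diff also drops the task_executor->run_and_wait(tasks) path, so with this change implementations are compiled serially on the calling thread and compilation errors are logged rather than rethrown.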

src/plugins/intel_gpu/src/graph/graph_optimizer/mark_runtime_skippable_nodes.cpp (+1)

@@ -40,6 +40,7 @@ void mark_runtime_skippable_nodes::run(program& p) {
             }
         });
         program_helpers::do_for_types<permute>(*node, [](permute_node& node){
+            return;
             auto impl_params = node.get_kernel_impl_params();
             if (node.is_output() ||
                 node.has_fused_primitives() ||
New file (+284)

@@ -0,0 +1,284 @@
+// Copyright (C) 2023-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "intel_gpu/plugin/multi_tensor_variable_state.hpp"
+#include "intel_gpu/plugin/variable_state.hpp"
+#include "intel_gpu/runtime/debug_configuration.hpp"
+#include "intel_gpu/runtime/memory.hpp"
+#include "multi_stage_primitive.hpp"
+
+#include "paged_attention_inst.h"
+#include "paged_attention/paged_attention_kernel_selector.hpp"
+#include "paged_attention/kv_cache_update_kernel_ref.hpp"
+
+namespace cldnn {
+namespace ocl {
+
+struct paged_attention_impl : multi_stage_primitive<paged_attention> {
+    using parent = multi_stage_primitive<paged_attention>;
+    using parent::parent;
+    using kv_cache_update_kernel_selector_t = kernel_selector::kv_cache_update_kernel_selector;
+    using kv_cache_update_kernel_params_t = kernel_selector::kv_cache_update_update_params;
+
+    using sdpa_kernel_selector_t = kernel_selector::sdpa_kernel_selector;
+    using sdpa_kernel_params_t = kernel_selector::kv_cache_update_kernel_selector;
+
+    DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::paged_attention_impl)
+
+    std::unique_ptr<primitive_impl> clone() const override {
+        return make_unique<paged_attention_impl>(*this);
+    }
+
+    enum Stage {
+        concat,
+        sdpa
+    };
+
+    cldnn::memory::ptr beam_table_prev = nullptr;
+    cldnn::memory::ptr beam_table_new = nullptr;
+
+    void load(BinaryInputBuffer& ib) override {
+        parent::load(ib);
+        if (is_dynamic()) {
+            OPENVINO_THROW("[GPU] Unimplemented load func");
+            // auto& kernel_selector = kv_cache_update_kernel_selector_t::Instance();
+            // auto kernel_impl = kernel_selector.GetImplementation(_kernels_data[concat_stage].kernelName);
+            // kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[concat_stage]);
+            // if (_kernels_data.size() == 2) {
+            //     auto& bt_kernel_selector = sdpa_kernel_selector_t::Instance();
+            //     auto bt_kernel_impl = bt_kernel_selector.GetImplementation(_kernels_data[beam_table_stage].kernelName);
+            //     bt_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[beam_table_stage]);
+            // }
+        }
+    }
+    void set_arguments_impl(paged_attention_inst& instance) override {}
+
+    kernel_arguments_data get_arguments(const paged_attention_inst& instance, size_t stage) const override {
+        kernel_arguments_data args;
+        args.shape_info = instance.shape_info_memory_ptr();
+        if (stage == Stage::concat) {
+            args.inputs = { instance.input_memory_ptr(1),
+                            instance.input_memory_ptr(2),
+                            instance.input_memory_ptr(6) };
+            args.outputs = { instance.input_memory_ptr(3), instance.input_memory_ptr(4) };
+        } else if (stage == Stage::sdpa) {
+            args.inputs = { beam_table_prev, instance.input_memory_ptr(2) };
+            args.outputs = { beam_table_new };
+        }
+
+        return args;
+    }
+
+    void execute_stage(const std::vector<event::ptr>& events, paged_attention_inst& instance, std::vector<event::ptr>& all_events, size_t stage) {
+        stream& stream = instance.get_network().get_stream();
+        std::vector<event::ptr> tmp_events(events);
+        size_t kernel_offset = 0;
+        for (size_t s = 0; s < stage; s++) {
+            kernel_offset += _kernels_data[s].kernels.size();
+        }
+        for (size_t kd_idx = 0; kd_idx < _kernels_data[stage].kernels.size(); ++kd_idx) {
+            if (_kernels_data[stage].kernels[kd_idx].skip_execution)
+                continue;
+
+            size_t idx_final = kernel_offset + kd_idx;
+            // If any user of the prim's users is CPU implementation or network's output, set prim as a output event (event won't be nullptr)
+            bool needs_completion_event = instance.needs_completion_event();
+
+            auto& params = _kernels_data[stage].kernels[kd_idx].params;
+            auto args = get_arguments(instance, stage);
+            args.scalars = &params.scalars;
+
+            for (const auto& m : instance.get_intermediates_memories()) {
+                args.intermediates.push_back(m);
+            }
+
+            stream.set_arguments(*_kernels[idx_final], _kernels_data[stage].kernels[kd_idx].params, args);
+
+            const auto& gws = params.workGroups.global;
+            const auto& lws = params.workGroups.local;
+
+            GPU_DEBUG_TRACE_DETAIL << "Enqueue stage " << stage << " kernel " << idx_final << ": gws=[" << gws[0] << ", " << gws[1] << ", " << gws[2] << "] "
+                                   << "lws=[" << lws[0] << ", " << lws[1] << ", " << lws[2] << "]"
+                                   << (needs_completion_event ? " has_completion_event=true" : "") << std::endl;
+
+            auto ev = stream.enqueue_kernel(*_kernels[idx_final], params, args, tmp_events, needs_completion_event);
+            if (_kernels_data[stage].needs_sub_kernels_sync) {
+                tmp_events = {ev};
+            }
+            all_events.push_back(ev);
+        }
+    }
+
+    event::ptr execute_impl(const std::vector<event::ptr>& events, paged_attention_inst& instance) override {
+        auto& stream = instance.get_network().get_stream();
+        std::vector<event::ptr> res_events;
+
+        execute_stage(events, instance, res_events, Stage::concat);
+
+        return aggregate_events(res_events, stream, res_events.size() > 1);
+    }
+
+    static layout get_beam_table_layout(const kernel_impl_params& impl_param) {
+        const auto& primitive = impl_param.typed_desc<paged_attention>();
+        auto kv_layout = impl_param.get_input_layout(0);
+
+        // // expected to be normalized already on primitive creation
+        // auto concat_axis = primitive->concat_axis;
+        // auto gather_axis = primitive->gather_axis;
+
+        // auto kv_shape = kv_layout.get_partial_shape();
+        // auto beam_table_shape = ov::PartialShape(std::vector<size_t>(kv_shape.size(), 1));
+        // beam_table_shape[gather_axis] = kv_shape[gather_axis];
+        // beam_table_shape[concat_axis] = kv_shape[concat_axis];
+        return kv_layout;
+    }
+
+    static kv_cache_update_kernel_params_t get_kv_cache_update_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) {
+        kv_cache_update_kernel_params_t params;
+        set_params(impl_param, params);
+
+        auto query = impl_param.get_input_layout(0);
+        auto key = impl_param.get_input_layout(1);
+        auto value = impl_param.get_input_layout(2);
+        auto key_cache = impl_param.get_input_layout(3);
+        auto value_cache = impl_param.get_input_layout(4);
+        auto slot_mapping = impl_param.get_input_layout(6);
+
+        // query_shape = [batch_size, seq_len, num_heads * head_size]
+        // key_shape, value_shape = [batch_size, seq_len, num_kv_heads * head_size]
+        // key_cache_shape = [num_blocks, num_kv_heads, head_size/x, block_size, x]
+        // value_cache_shape = [num_blocks, num_kv_heads, head_size, block_size]
+        // const auto query_shape = query.get_shape();
+        // const auto key_shape = key.get_shape();
+        // const auto key_cache_shape = key_cache.get_shape();
+        // const auto value_cache_shape = value_cache.get_shape();
+        // const size_t batch_size = query_shape[0];
+        // const size_t seq_len = query_shape[1];
+        // const size_t hidden_size = query_shape[2];
+        // const size_t num_kv_heads = value_cache_shape[1];
+        // const size_t head_size = value_cache_shape[2];
+        // const size_t num_heads = hidden_size / head_size;
+        // const size_t block_size = value_cache_shape[3];
+        // const size_t x = key_cache_shape[4];
+        // const size_t num_tokens = key_shape[0];
+
+        // Reshape from [batch_size, seq_len, num_heads * head_size] to [batch_size, seq_len, num_heads, head_size]
+        // query.set_partial_shape({batch_size, seq_len, num_heads, head_size});
+        // key.set_partial_shape({batch_size, seq_len, num_kv_heads, head_size});
+        // value.set_partial_shape(key.get_shape());
+
+        params.is_shape_agnostic = is_shape_agnostic;
+        params.stage_id = 0;
+        params.inputs.resize(3);
+        params.outputs.resize(2);
+        params.inputs[0] = convert_data_tensor(key);
+        params.inputs[1] = convert_data_tensor(value);
+        params.inputs[2] = convert_data_tensor(slot_mapping);
+        params.outputs[0] = convert_data_tensor(key_cache);
+        params.outputs[1] = convert_data_tensor(value_cache);
+        params.layerID = impl_param.desc->id;
+
+        // const auto inputs_count = 2;
+        // params.inputs.resize(inputs_count);
+        // for (size_t i = 0; i < inputs_count; ++i) {
+        //     params.inputs[i] = convert_data_tensor(impl_param.input_layouts[i]);
+        // }
+
+        // params.axis = convert_axis(axis, impl_param.get_output_layout().get_rank());
+        // params.kernelPerInput = true;
+
+        const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset;
+        std::map<size_t, size_t> in_tensor_to_offset_map = {
+            {0, in_offsets_map.at(1)},
+            {1, in_offsets_map.at(2)},
+            {2, in_offsets_map.at(6)},
+        };
+        std::map<size_t, size_t> out_tensor_to_offset_map = {
+            {0, in_offsets_map.at(3)},
+            {1, in_offsets_map.at(4)},
+        };
+
+        params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map);
+
+        return params;
+    }
+
+    static sdpa_kernel_params_t get_bt_update_kernel_params(const kernel_impl_params& impl_param, bool is_state_set = false) {
+        // auto params = get_default_params<kernel_selector::beam_table_update_params>(impl_param, true);
+
+        // auto inputs_count = 2;
+        // auto bt_present_layout = impl_param.output_layouts[1];
+        // auto bt_shape = extend_shape_to_rank_from_end(bt_present_layout.get_partial_shape(), 1);
+        // bt_present_layout.set_partial_shape(bt_shape);
+        // layout bt_past_layout = get_beam_table_layout(impl_param);
+
+        // auto beam_idx_l = impl_param.input_layouts[2];
+        // beam_idx_l.set_partial_shape(extend_shape_to_rank_from_end(beam_idx_l.get_partial_shape(), 4));
+
+        // params.inputs.resize(inputs_count);
+        // params.inputs[0] = convert_data_tensor(bt_past_layout);
+        // params.inputs[1] = convert_data_tensor(beam_idx_l);
+        // params.outputs[0] = convert_data_tensor(bt_present_layout);
+        // params.inputs.resize(inputs_count);
+        // params.is_state_set = is_state_set;
+
+        // const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset; // [kv_past, kv_new_token, [beam_idx, beam_table_past]]
+        // const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset; // [kv_present, beam_table_present]
+        // std::map<size_t, size_t> in_tensor_to_offset_map = {
+        //     {0, in_offsets_map.at(3)}, // beam_table_past
+        //     {1, in_offsets_map.at(2)}, // beam_idx
+        // };
+        // std::map<size_t, size_t> out_tensor_to_offset_map = {
+        //     {0, out_offsets_map.at(1)}, // beam_table_present
+        // };
+
+        // params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map);
+
+        return {};
+    }
+
+    static std::unique_ptr<primitive_impl> create(const typed_program_node<paged_attention>& arg, const kernel_impl_params& impl_param) {
+        std::vector<kernel_selector::kernel_data> kernels_data;
+        auto concat_kernel_params = get_kv_cache_update_kernel_params(impl_param, impl_param.is_dynamic());
+        auto& concat_kernel_selector = kv_cache_update_kernel_selector_t::Instance();
+        kernels_data.push_back(concat_kernel_selector.get_best_kernel(concat_kernel_params));
+
+        // SDPA
+        // auto& concat_kernel_selector = kv_cache_update_kernel_selector_t::Instance();
+        // kernels_data.push_back(bt_update_kernel_selector.get_best_kernel(bt_update_kernel_params));
+        //
+        return cldnn::make_unique<paged_attention_impl>(kernels_data);
+    }
+
+    void update_dispatch_data(const kernel_impl_params& impl_param) override {
+        auto paged_attention_kernel_params = get_kv_cache_update_kernel_params(impl_param, impl_param.is_dynamic());
+        (_kernels_data[Stage::concat].update_dispatch_data_func)(paged_attention_kernel_params, _kernels_data[Stage::concat]);
+        // _kernels_data[concat_stage].kernels[0].skip_execution = impl_param._can_be_optimized || impl_param.get_input_layout(0).count() == 0;
+    }
+};
+
+namespace detail {
+
+attach_paged_attention_impl::attach_paged_attention_impl() {
+    auto types = { data_types::f16, data_types::f32 };
+    auto formats = { format::bfyx };
+    implementation_map<paged_attention>::add(impl_types::ocl,
+                                             shape_types::dynamic_shape,
+                                             paged_attention_impl::create,
+                                             types,
+                                             formats);
+
+    implementation_map<paged_attention>::add(impl_types::ocl,
+                                             shape_types::static_shape,
+                                             paged_attention_impl::create,
+                                             types,
+                                             formats);
+}
+
+} // namespace detail
+} // namespace ocl
+} // namespace cldnn
+
+BIND_BINARY_BUFFER_WITH_TYPE(cldnn::ocl::paged_attention_impl)
+BIND_BINARY_BUFFER_WITH_TYPE(cldnn::paged_attention)
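
The commented-out cache shapes above (key_cache = [num_blocks, num_kv_heads, head_size/x, block_size, x], value_cache = [num_blocks, num_kv_heads, head_size, block_size]) match the vLLM paged-KV layout, in which slot_mapping assigns every token a flat slot in a pool of fixed-size blocks. Assuming those semantics (the update kernel itself is not visible in this capture), per-token addressing reduces to a divide and a modulo:

    #include <cstdint>

    // Assumed slot decoding for a vLLM-style paged KV cache; illustrative
    // only, not taken from this commit's kernels.
    struct SlotAddress {
        int64_t block_idx;     // index into the [num_blocks, ...] pool
        int64_t block_offset;  // token position inside that block
    };

    inline SlotAddress decode_slot(int64_t slot, int64_t block_size) {
        return {slot / block_size, slot % block_size};
    }
    // Example: with block_size = 16, slot 37 lands in block 2, offset 5.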

src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp (+1)

@@ -41,6 +41,7 @@ void register_implementations() {
     REGISTER_OCL(grid_sample);
     REGISTER_OCL(group_normalization);
     REGISTER_OCL(kv_cache);
+    REGISTER_OCL(paged_attention);
     REGISTER_OCL(lrn);
     REGISTER_OCL(lstm_elt);
     REGISTER_OCL(multiclass_nms);
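
Registration is driven by the detail::attach_paged_attention_impl constructor defined in the new impl file, which fills implementation_map<paged_attention> for f16/f32 bfyx in both static- and dynamic-shape flavors. The macro's definition is not part of this diff; presumably it instantiates that attach object, along the lines of:

    // Assumption about REGISTER_OCL's shape; its real definition is not shown here.
    #define REGISTER_OCL(prim) static detail::attach_##prim##_impl attach_##prim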
