Commit e64244f

WIP: [GPU] Use micro-sdpa for 1st token calculation of PagedAttention
1 parent 2ea26d7 commit e64244f

File tree

8 files changed, +256 -78 lines changed


src/plugins/intel_gpu/src/graph/graph_optimizer/compile_graph.cpp

+21 -33
@@ -24,10 +24,7 @@ void compile_graph::run(program& p) {
         }
     }
 
-    auto task_executor = p.get_task_executor();
     auto& proc_order = p.get_processing_order();
-    std::vector<ov::threading::Task> tasks;
-    std::exception_ptr exception;
 
     for (size_t idx = 0; idx < proc_order.size(); idx++) {
         auto& node = *(std::next(proc_order.begin(), idx));
@@ -36,37 +33,28 @@ void compile_graph::run(program& p) {
             !(node->is_type<mutable_data>() && node->get_dependencies().empty());
 
         if (can_select_impl) {
-            tasks.push_back([node, &exception] {
-                try {
-                    const auto& params = node->get_kernel_impl_params();
-                    auto shape_type = ImplementationManager::get_shape_type(*params);
-                    auto selected_impl_manager = node->type()->choose_impl(*node, shape_type);
-                    std::string fail_reason = "";
-                    try {
-                        if (selected_impl_manager) {
-                            node->selected_impl = selected_impl_manager->create(*node, *params);
-                        }
-                    } catch (std::exception& e) {
-                        fail_reason = e.what();
-                    }
-
-                    OPENVINO_ASSERT(shape_type == shape_types::dynamic_shape || node->selected_impl != nullptr,
-                                    "[GPU] Failed to select implementation for"
-                                    "\nname:", node->id(),
-                                    "\ntype: ", node->get_primitive()->type_string(),
-                                    "\noriginal_type: ", node->get_primitive()->origin_op_type_name,
-                                    (!fail_reason.empty() ? fail_reason : ""));
-                } catch(...) {
-                    exception = std::current_exception();
+            // std::cout << "Compiling " << node->id() << "\n";
+            // if (idx + 1 < proc_order.size())
+            //     std::cout << "Compiling next id " << (*(std::next(proc_order.begin(), idx + 1)))->id() << "\n";
+
+            const auto& params = node->get_kernel_impl_params();
+            auto shape_type = ImplementationManager::get_shape_type(*params);
+            auto selected_impl_manager = node->type()->choose_impl(*node, shape_type);
+            std::string fail_reason = "";
+            try {
+                if (selected_impl_manager) {
+                    node->selected_impl = selected_impl_manager->create(*node, *params);
                 }
-            });
+            } catch (std::exception& e) {
+                fail_reason = e.what();
+            }
+
+            OPENVINO_ASSERT(shape_type == shape_types::dynamic_shape || node->selected_impl != nullptr,
+                            "[GPU] Failed to select implementation for"
+                            "\nname:", node->id(),
+                            "\ntype: ", node->get_primitive()->type_string(),
+                            "\noriginal_type: ", node->get_primitive()->origin_op_type_name,
+                            (!fail_reason.empty() ? fail_reason : ""));
         }
     }
-
-    task_executor->run_and_wait(tasks);
-    tasks.clear();
-
-    if (exception) {
-        std::rethrow_exception(exception);
-    }
 }
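
For context, the removed block followed the usual compile_graph pattern: queue one ov::threading::Task per node, run them all through the program's task executor, and re-throw a captured exception on the calling thread. The WIP change runs impl selection inline instead. A minimal sketch of the removed pattern (select_impl_for is a hypothetical stand-in for the selection body shown above):

```cpp
// Sketch of the task-executor pattern that the commit removes; only
// ov::threading::Task and run_and_wait() are real API from the original code.
std::vector<ov::threading::Task> tasks;
std::exception_ptr exception;
for (auto& node : proc_order) {
    tasks.push_back([node, &exception] {
        try {
            select_impl_for(node);                 // choose_impl + create + OPENVINO_ASSERT
        } catch (...) {
            exception = std::current_exception();  // only the last failure is kept
        }
    });
}
task_executor->run_and_wait(tasks);                // blocks until every task completes
tasks.clear();
if (exception)
    std::rethrow_exception(exception);             // surface the failure on this thread
```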

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp

+70 -11
@@ -15,6 +15,7 @@
 #include "sdpa/pa_kv_cache_rotate_kernel_ref.h"
 #include "sdpa/pa_kv_cache_update_kernel_ref.h"
 #include "sdpa/pa_sdpa_kernel_opt.h"
+#include "sdpa/sdpa_kernel_micro.h"
 
 namespace cldnn {
 namespace ocl {
@@ -66,6 +67,33 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         return stage == PagedAttentionStage::MIXED;
     }
 
+    void update_inst_params(primitive_inst& inst) const override {
+        OPENVINO_ASSERT(inst.type() == paged_attention::type_id());
+        OPENVINO_ASSERT(inst.get_impl() == this);
+
+        auto& pa_inst = reinterpret_cast<paged_attention_inst&>(inst);
+        if (is_micro_kernel_used) {
+            auto tile_q_size = get_target_seq_len_block_size(PagedAttentionStage::PREFILL);
+            pa_inst.tile_q_size = tile_q_size;
+            std::cout << "update_inst_params: from micro-sdpa tile_q_size = " << tile_q_size << "\n";
+        } else {
+            pa_inst.tile_q_size = get_target_seq_len_block_size(PagedAttentionStage::PREFILL);
+            std::cout << "update_inst_params: sdpa_opt tile_q_size = " << get_target_seq_len_block_size(PagedAttentionStage::PREFILL) << "\n";
+        }
+    }
+
+    size_t get_target_seq_len_block_size(const PagedAttentionStage& stage) const {
+        if (stage == PagedAttentionStage::PREFILL) {
+            if (is_micro_kernel_used) {
+                return kernel_selector::SDPAKernelMicro::GetTileQSize(_kernels_data[Stage::SDPA]);
+            } else {
+                return 16;
+            }
+        } else {
+            return 16;
+        }
+    }
+
     void load(BinaryInputBuffer& ib) override {
         parent::load(ib);
         ib >> make_data(&has_scores_output, sizeof(bool));
@@ -527,7 +555,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
     static kv_cache_update_kernel_params_t get_kv_cache_update_kernel_params(const kernel_impl_params& impl_param,
                                                                               const PagedAttentionStage& stage,
                                                                               const kernel_selector::MultiDataTensor& input_tensors,
-                                                                              bool is_dynamic = false) {
+                                                                              int64_t target_seq_len_block_size,
+                                                                              bool is_dynamic) {
         auto params = get_default_params<kv_cache_update_kernel_params_t>(impl_param, is_dynamic);
 
         const auto& key_tensor = input_tensors[1];
@@ -557,7 +586,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         params.is_prefill = stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED;
 
         if ((stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED) && !is_dynamic)
-            params.conf.paged_attention_aligned_seq_len = get_aligned_seq_len(impl_param, stage);
+            params.conf.paged_attention_aligned_seq_len = get_aligned_seq_len(impl_param, stage, target_seq_len_block_size);
 
         const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset;
         std::map<size_t, size_t> in_tensor_to_offset_map = {
@@ -581,13 +610,31 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
     static sdpa_kernel_params_t get_sdpa_kernel_params(const kernel_impl_params& impl_param,
                                                        const PagedAttentionStage& stage,
                                                        const kernel_selector::MultiDataTensor& input_tensors,
-                                                       bool is_dynamic = false) {
+                                                       int64_t target_seq_len_block_size,
+                                                       bool is_dynamic) {
         const auto desc = impl_param.typed_desc<paged_attention>();
         auto params = get_default_params<sdpa_kernel_params_t>(impl_param, is_dynamic);
 
-        const auto& query_tensor = input_tensors[0];
-        const auto& key_tensor = input_tensors[1];
-        const auto& value_tensor = input_tensors[2];
+        auto get_sdpa_tensor = [&](const layout& input_layout, size_t head_size) {
+            auto new_layout = input_layout;
+            auto orig_shape = new_layout.get_partial_shape();
+            auto new_shape = ov::PartialShape::dynamic(4);
+
+            new_shape[0] = 1;
+            new_shape[1] = orig_shape[0];
+            new_shape[2] = orig_shape[1] / head_size;
+            new_shape[3] = head_size;
+
+            new_layout.set_partial_shape(new_shape);
+
+            std::cout << "Convert layout: " << input_layout.to_short_string() << " -> " << new_layout.to_short_string() << "\n";
+
+            return convert_data_tensor(new_layout);
+        };
+
+        const auto query_tensor = get_sdpa_tensor(impl_param.get_input_layout(0), desc->head_size);
+        const auto key_tensor = get_sdpa_tensor(impl_param.get_input_layout(1), desc->head_size);
+        const auto value_tensor = get_sdpa_tensor(impl_param.get_input_layout(2), desc->head_size);
         const auto& subsequence_begins_tensor = input_tensors[6];
         const auto& scale_tensor = input_tensors[9];
         const auto& alibi_tensor = input_tensors[11];
@@ -616,12 +663,17 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         if (has_alibi)
            params.inputs[input_idx++] = alibi_tensor;
 
+        params.outputs[0] = get_sdpa_tensor(impl_param.get_output_layout(0), desc->head_size);
         if (has_scores_output) {
            params.outputs.resize(2);
            params.outputs[1] = convert_data_tensor(impl_param.get_output_layout(1));
         }
 
         params.conf = get_sdpa_configuration(impl_param, is_dynamic);
+        params.input0_order = {0, 2, 1, 3};
+        params.input1_order = {0, 2, 1, 3};
+        params.input2_order = {0, 2, 1, 3};
+        params.output_order = {0, 1, 2, 3};
 
         const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset;
         const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset;
@@ -643,7 +695,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
             in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(11)});
 
         if ((stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED) && !is_dynamic)
-            params.conf.paged_attention_aligned_seq_len = get_aligned_seq_len(impl_param, stage);
+            params.conf.paged_attention_aligned_seq_len = get_aligned_seq_len(impl_param, stage, target_seq_len_block_size);
 
         if (has_scores_output)
             out_tensor_to_offset_map.insert({1, out_offsets_map.at(1)});
@@ -756,11 +808,11 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
             (_kernels_data[Stage::KV_CACHE_ROTATE].update_dispatch_data_func)(kv_cache_rotate_kernel_params, _kernels_data[Stage::KV_CACHE_ROTATE]);
         }
 
-        auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
+        auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, get_target_seq_len_block_size(stage), impl_param.is_dynamic());
         (_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func)(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]);
 
         if (stage == PagedAttentionStage::PREFILL) {
-            auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
+            auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, get_target_seq_len_block_size(stage), impl_param.is_dynamic());
             (_kernels_data[Stage::SDPA].update_dispatch_data_func)(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
         }
 
@@ -779,11 +831,11 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
             input_tensors.emplace_back(convert_data_tensor(input_layout));
 
         const auto& desc = impl_param.typed_desc<paged_attention>();
-        auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
+        auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, 0, impl_param.is_dynamic());
         auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance();
         kernels_data.push_back(kv_cache_update_kernel_selector.get_best_kernel(kv_cache_update_kernel_params));
 
-        auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
+        auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, 0, impl_param.is_dynamic());
         auto& sdpa_kernel_selector = sdpa_kernel_selector_t::Instance();
         kernels_data.push_back(sdpa_kernel_selector.get_best_kernel(sdpa_kernel_params));
 
@@ -801,12 +853,19 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         impl->has_scores_output = desc->has_scores_output();
         impl->has_rotated_blocks = desc->has_rotated_blocks;
 
+        if (!kernels_data[Stage::SDPA].kernels[0].micro_kernels.empty()) {
+            std::cout << "Micro SDPA is chosen!\n";
+            std::cout << "tile_q_size = " << kernel_selector::SDPAKernelMicro::GetTileQSize(kernels_data[Stage::SDPA]) << "\n";
+            impl->is_micro_kernel_used = true;
+        }
+
         return impl;
     }
 
 private:
     bool has_scores_output = false;
     bool has_rotated_blocks = false;
+    bool is_micro_kernel_used = false;
 };
 
 namespace detail {
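
The layout conversion done by the new get_sdpa_tensor lambda is the core of reusing micro-SDPA here: PagedAttention keeps Q/K/V as 2D [token_count, num_heads * head_size] tensors, while the SDPA kernels expect a 4D view. A minimal standalone sketch of the same reshape, with illustrative sizes (the helper name and the concrete numbers are assumptions, not part of the commit):

```cpp
#include "openvino/core/partial_shape.hpp"

// Hypothetical standalone version of the reshape performed by get_sdpa_tensor.
ov::PartialShape to_sdpa_shape(const ov::PartialShape& orig, size_t head_size) {
    auto shape = ov::PartialShape::dynamic(4);
    shape[0] = 1;                    // fake batch dimension
    shape[1] = orig[0];              // total token count (may be dynamic)
    shape[2] = orig[1] / head_size;  // number of heads
    shape[3] = head_size;            // head size
    return shape;
}

// e.g. [token_count, 4096] with head_size = 128 becomes [1, token_count, 32, 128];
// the input{0,1,2}_order = {0, 2, 1, 3} set above then presents the data to the
// kernel as [1, heads, token_count, head_size], the usual SDPA layout.
```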

src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h

+1 -1
@@ -62,7 +62,7 @@ class typed_primitive_inst<paged_attention> : public typed_primitive_inst_base<paged_attention> {
     memory::ptr rotation_deltas_memory_ptr() const { return input_memory_ptr(14); }
     memory::ptr rotation_trig_lut_memory_ptr() const { return input_memory_ptr(15); }
 
-    std::shared_ptr<network> prefill_network;
+    size_t tile_q_size = 0;
 
 protected:
     void on_execute() override;

src/plugins/intel_gpu/src/graph/include/primitive_inst.h

+1
@@ -59,6 +59,7 @@ struct primitive_impl {
     virtual std::set<size_t> get_lockable_internal_buffers() const { return {}; }
     virtual void set_node_params(const program_node&) {}
     virtual const std::string& get_type_info() const = 0;
+    virtual void update_inst_params(primitive_inst& instance) const {}
     virtual void set_arguments(primitive_inst& instance) = 0;
     virtual void set_arguments(primitive_inst& instance, kernel_arguments_data& args) = 0;
     virtual event::ptr execute(const std::vector<event::ptr>& events, primitive_inst& instance) = 0;

src/plugins/intel_gpu/src/graph/paged_attention.cpp

+4 -1
@@ -114,6 +114,9 @@ void paged_attention_inst::on_execute() {
         (stage == PagedAttentionStage::GENERATE && !has_scores_output))
         return;
 
+    OPENVINO_ASSERT(_impl != nullptr, "[GPU] impl shouldn't be nullptr");
+    _impl->update_inst_params(*this);
+
     auto& stream = get_network().get_stream();
     const auto past_lens_mem = past_lens_memory_ptr();
     const auto subsequence_begins_mem = subsequence_begins_memory_ptr();
@@ -179,7 +182,7 @@ void paged_attention_inst::on_execute() {
 
     size_t index = 0;
     size_t subsequence_offsets_acc = 0;
-    const auto target_seq_len_block_size = 16; // TODO: Get block size from the impl
+    const auto target_seq_len_block_size = static_cast<int>(tile_q_size);
    for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
         const auto past_len = past_lens_mem_lock[i];
         const auto seq_start = subsequence_begins_mem_lock[i];
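
tile_q_size feeds the loop that starts right above, which cuts every subsequence into query blocks of target_seq_len_block_size tokens for the prefill kernel. A rough sketch of that tiling, under the assumption that the loop (whose full body is outside this diff) fills the blocked_indexes_start/end and gws_seq_indexes_correspondence buffers consumed by sdpa_micro.cl:

```cpp
// Illustrative tiling only; the real on_execute() writes these values into
// intermediate GPU buffers rather than std::vector.
#include <algorithm>
#include <vector>

void tile_subsequences(const std::vector<int>& subsequence_begins,
                       int target_seq_len_block_size,
                       std::vector<int>& blocked_indexes_start,
                       std::vector<int>& blocked_indexes_end,
                       std::vector<int>& gws_seq_indexes_correspondence) {
    for (size_t i = 0; i + 1 < subsequence_begins.size(); i++) {
        const int seq_start = subsequence_begins[i];
        const int seq_end = subsequence_begins[i + 1];
        // One entry per query tile of this subsequence.
        for (int pos = seq_start; pos < seq_end; pos += target_seq_len_block_size) {
            blocked_indexes_start.push_back(pos);
            blocked_indexes_end.push_back(std::min(pos + target_seq_len_block_size, seq_end));
            gws_seq_indexes_correspondence.push_back(static_cast<int>(i));
        }
    }
}
```

A mismatch between this host-side block size and the kernel's query tile is exactly what the former hard-coded 16 could cause once micro-SDPA (whose tile comes from GetTileQSize) handles prefill, hence the new tile_q_size plumbing.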

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_micro.cl

+26 -1
@@ -143,26 +143,51 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
     const global QRY_DATA_T *Q,
     const global VAL_DATA_T *V,
     global half *A,
+#if IS_PAGED_ATTENTION
+    const __global INPUT3_TYPE* subsequence_begins,
+#endif
 #if WITH_ATTN_MASK
     const global half *msk,
 #endif
 #if WITH_SCALE
     global SCALE_DATA_T *scale_ptr,
 #endif
-    int d, int k, int q
+    int d,
+#if IS_PAGED_ATTENTION
+    const __global int* blocked_indexes_start,
+    const __global int* blocked_indexes_end,
+    const __global int* gws_seq_indexes_correspondence
+#else
+    int k,
+    int q
+#endif
 #ifdef KV_COMPRESSED
     , const global KEY_ATTR_SCALES_DATA_T *K_scales
     , const global KEY_ATTR_ZP_DATA_T *K_zp
     , const global VAL_ATTR_SCALES_DATA_T *V_scales
     , const global VAL_ATTR_ZP_DATA_T *V_zp
 #endif
 ) {
+#if IS_PAGED_ATTENTION
+    const uint q_tile_idx = get_group_id(0);
+    const uint block_start_pos = blocked_indexes_start[q_tile_idx];
+    const uint block_end_pos = blocked_indexes_end[q_tile_idx];
+    const uint subsequence_q_tile_idx = block_start_pos - subsequence_begins[gws_seq_indexes_correspondence[q_tile_idx]];
+    // const uint sequence_idx_end = block_end_pos - block_start_pos;
+    const uint subsequence_begin = subsequence_begins[gws_seq_indexes_correspondence[q_tile_idx]];
+    const int k = subsequence_begins[gws_seq_indexes_correspondence[q_tile_idx] + 1] - subsequence_begins[gws_seq_indexes_correspondence[q_tile_idx]];
+    const int q = k;
+#endif
     uint sg_ij = sub_group_broadcast(get_local_id(1), 0);
     uint b0 = get_group_id(1);
     uint b1 = get_group_id(2);
     uint b0_kv = b0 / KV_GROUP_SIZE;
 
+#if IS_PAGED_ATTENTION
+    uint wg_j0 = subsequence_q_tile_idx;
+#else
     uint wg_j0 = get_group_id(0) * ugemm_kq_wg_tile_n;
+#endif
 
     /* Leading dimension for matrices */
     uint ldk = TRANSPOSE_K ? KEY_S3 : KEY_S2;
0 commit comments

Comments
 (0)