Skip to content

Commit a517444

Browse files
committed
WIP: [GPU] PagedAttention OCL kernel kv-cache compression
1 parent 70ec531 commit a517444

File tree

7 files changed

+401
-29
lines changed

7 files changed

+401
-29
lines changed

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp

+43-13
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
553553
params.outputs[1] = value_cache_tensor;
554554

555555
params.conf = get_sdpa_configuration(impl_param, is_dynamic);
556+
if (ov::element::Type(impl_param.get_input_layout(3).data_type).size() == 1) {
557+
params.conf.is_kv_compressed = true;
558+
params.conf.use_asymmetric_quantization = true;
559+
}
556560

557561
params.is_prefill = stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED;
558562

@@ -692,6 +696,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
692696
params.inputs[input_idx++] = subsequence_begins_tensor;
693697

694698
params.conf = get_sdpa_configuration(impl_param, is_dynamic);
699+
if (ov::element::Type(impl_param.get_input_layout(3).data_type).size() == 1) {
700+
params.conf.is_kv_compressed = true;
701+
params.conf.use_asymmetric_quantization = true;
702+
}
695703

696704
if (has_scale_input)
697705
params.inputs[input_idx++] = scale_tensor;
@@ -779,28 +787,50 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
779787
input_tensors.emplace_back(convert_data_tensor(input_layout));
780788

781789
const auto& desc = impl_param.typed_desc<paged_attention>();
782-
auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
783-
auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance();
784-
kernels_data.push_back(kv_cache_update_kernel_selector.get_best_kernel(kv_cache_update_kernel_params));
790+
try {
791+
auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
792+
auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance();
793+
kernels_data.push_back(kv_cache_update_kernel_selector.get_best_kernel(kv_cache_update_kernel_params));
794+
} catch (std::exception& e) {
795+
std::cout << "PagedAttention1 error: " << e.what() << "\n";
796+
std::rethrow_exception(std::current_exception());
797+
}
785798

786-
auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
787-
auto& sdpa_kernel_selector = sdpa_kernel_selector_t::Instance();
788-
kernels_data.push_back(sdpa_kernel_selector.get_best_kernel(sdpa_kernel_params));
799+
try {
800+
auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
801+
auto& sdpa_kernel_selector = sdpa_kernel_selector_t::Instance();
802+
kernels_data.push_back(sdpa_kernel_selector.get_best_kernel(sdpa_kernel_params));
803+
} catch (std::exception& e) {
804+
std::cout << "PagedAttention2 error: " << e.what() << "\n";
805+
std::rethrow_exception(std::current_exception());
806+
}
789807

790-
auto pa_sdpa_kernel_params = get_pa_sdpa_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
791-
auto& pa_sdpa_kernel_selector = pa_sdpa_kernel_selector_t::Instance();
792-
kernels_data.push_back(pa_sdpa_kernel_selector.get_best_kernel(pa_sdpa_kernel_params));
808+
try {
809+
auto pa_sdpa_kernel_params = get_pa_sdpa_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
810+
auto& pa_sdpa_kernel_selector = pa_sdpa_kernel_selector_t::Instance();
811+
kernels_data.push_back(pa_sdpa_kernel_selector.get_best_kernel(pa_sdpa_kernel_params));
812+
} catch (std::exception& e) {
813+
std::cout << "PagedAttention3 error: " << e.what() << "\n";
814+
std::rethrow_exception(std::current_exception());
815+
}
793816

794-
if (desc->has_rotated_blocks) {
795-
auto kv_cache_rotate_kernel_params = get_kv_cache_rotate_kernel_params(impl_param, input_tensors, impl_param.is_dynamic());
796-
auto& kv_cache_rotate_kernel_selector = kv_cache_rotate_kernel_selector_t::Instance();
797-
kernels_data.push_back(kv_cache_rotate_kernel_selector.get_best_kernel(kv_cache_rotate_kernel_params));
817+
try {
818+
if (desc->has_rotated_blocks) {
819+
auto kv_cache_rotate_kernel_params = get_kv_cache_rotate_kernel_params(impl_param, input_tensors, impl_param.is_dynamic());
820+
auto& kv_cache_rotate_kernel_selector = kv_cache_rotate_kernel_selector_t::Instance();
821+
kernels_data.push_back(kv_cache_rotate_kernel_selector.get_best_kernel(kv_cache_rotate_kernel_params));
822+
}
823+
} catch (std::exception& e) {
824+
std::cout << "PagedAttention4 error: " << e.what() << "\n";
825+
std::rethrow_exception(std::current_exception());
798826
}
799827

828+
800829
auto impl = std::make_unique<paged_attention_impl>(kernels_data);
801830
impl->has_scores_output = desc->has_scores_output();
802831
impl->has_rotated_blocks = desc->has_rotated_blocks;
803832

833+
804834
return impl;
805835
}
806836

src/plugins/intel_gpu/src/graph/primitive_inst.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -1947,6 +1947,9 @@ void primitive_inst::prepare_primitive() {
19471947
void primitive_inst::execute() {
19481948
GPU_DEBUG_PROFILED_STAGE(instrumentation::pipeline_stage::inference);
19491949
if (get_flag(ExecutionFlags::SKIP)) {
1950+
if (_node->is_type<read_value>())
1951+
get_network().get_stream().finish();
1952+
19501953
set_out_event(get_network().get_stream().aggregate_events(_impl_params->dep_events));
19511954
return;
19521955
}

0 commit comments

Comments
 (0)