@@ -553,6 +553,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
553
553
params.outputs [1 ] = value_cache_tensor;
554
554
555
555
params.conf = get_sdpa_configuration (impl_param, is_dynamic);
556
+ if (ov::element::Type (impl_param.get_input_layout (3 ).data_type ).size () == 1 ) {
557
+ params.conf .is_kv_compressed = true ;
558
+ params.conf .use_asymmetric_quantization = true ;
559
+ }
556
560
557
561
params.is_prefill = stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED;
558
562
@@ -692,6 +696,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
692
696
params.inputs [input_idx++] = subsequence_begins_tensor;
693
697
694
698
params.conf = get_sdpa_configuration (impl_param, is_dynamic);
699
+ if (ov::element::Type (impl_param.get_input_layout (3 ).data_type ).size () == 1 ) {
700
+ params.conf .is_kv_compressed = true ;
701
+ params.conf .use_asymmetric_quantization = true ;
702
+ }
695
703
696
704
if (has_scale_input)
697
705
params.inputs [input_idx++] = scale_tensor;
@@ -779,28 +787,50 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
779
787
input_tensors.emplace_back (convert_data_tensor (input_layout));
780
788
781
789
const auto & desc = impl_param.typed_desc <paged_attention>();
782
- auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params (impl_param, stage, input_tensors, impl_param.is_dynamic ());
783
- auto & kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance ();
784
- kernels_data.push_back (kv_cache_update_kernel_selector.get_best_kernel (kv_cache_update_kernel_params));
790
+ try {
791
+ auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params (impl_param, stage, input_tensors, impl_param.is_dynamic ());
792
+ auto & kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance ();
793
+ kernels_data.push_back (kv_cache_update_kernel_selector.get_best_kernel (kv_cache_update_kernel_params));
794
+ } catch (std::exception & e) {
795
+ std::cout << " PagedAttention1 error: " << e.what () << " \n " ;
796
+ std::rethrow_exception (std::current_exception ());
797
+ }
785
798
786
- auto sdpa_kernel_params = get_sdpa_kernel_params (impl_param, stage, input_tensors, impl_param.is_dynamic ());
787
- auto & sdpa_kernel_selector = sdpa_kernel_selector_t::Instance ();
788
- kernels_data.push_back (sdpa_kernel_selector.get_best_kernel (sdpa_kernel_params));
799
+ try {
800
+ auto sdpa_kernel_params = get_sdpa_kernel_params (impl_param, stage, input_tensors, impl_param.is_dynamic ());
801
+ auto & sdpa_kernel_selector = sdpa_kernel_selector_t::Instance ();
802
+ kernels_data.push_back (sdpa_kernel_selector.get_best_kernel (sdpa_kernel_params));
803
+ } catch (std::exception & e) {
804
+ std::cout << " PagedAttention2 error: " << e.what () << " \n " ;
805
+ std::rethrow_exception (std::current_exception ());
806
+ }
789
807
790
- auto pa_sdpa_kernel_params = get_pa_sdpa_params (impl_param, stage, input_tensors, impl_param.is_dynamic ());
791
- auto & pa_sdpa_kernel_selector = pa_sdpa_kernel_selector_t::Instance ();
792
- kernels_data.push_back (pa_sdpa_kernel_selector.get_best_kernel (pa_sdpa_kernel_params));
808
+ try {
809
+ auto pa_sdpa_kernel_params = get_pa_sdpa_params (impl_param, stage, input_tensors, impl_param.is_dynamic ());
810
+ auto & pa_sdpa_kernel_selector = pa_sdpa_kernel_selector_t::Instance ();
811
+ kernels_data.push_back (pa_sdpa_kernel_selector.get_best_kernel (pa_sdpa_kernel_params));
812
+ } catch (std::exception & e) {
813
+ std::cout << " PagedAttention3 error: " << e.what () << " \n " ;
814
+ std::rethrow_exception (std::current_exception ());
815
+ }
793
816
794
- if (desc->has_rotated_blocks ) {
795
- auto kv_cache_rotate_kernel_params = get_kv_cache_rotate_kernel_params (impl_param, input_tensors, impl_param.is_dynamic ());
796
- auto & kv_cache_rotate_kernel_selector = kv_cache_rotate_kernel_selector_t::Instance ();
797
- kernels_data.push_back (kv_cache_rotate_kernel_selector.get_best_kernel (kv_cache_rotate_kernel_params));
817
+ try {
818
+ if (desc->has_rotated_blocks ) {
819
+ auto kv_cache_rotate_kernel_params = get_kv_cache_rotate_kernel_params (impl_param, input_tensors, impl_param.is_dynamic ());
820
+ auto & kv_cache_rotate_kernel_selector = kv_cache_rotate_kernel_selector_t::Instance ();
821
+ kernels_data.push_back (kv_cache_rotate_kernel_selector.get_best_kernel (kv_cache_rotate_kernel_params));
822
+ }
823
+ } catch (std::exception & e) {
824
+ std::cout << " PagedAttention4 error: " << e.what () << " \n " ;
825
+ std::rethrow_exception (std::current_exception ());
798
826
}
799
827
828
+
800
829
auto impl = std::make_unique<paged_attention_impl>(kernels_data);
801
830
impl->has_scores_output = desc->has_scores_output ();
802
831
impl->has_rotated_blocks = desc->has_rotated_blocks ;
803
832
833
+
804
834
return impl;
805
835
}
806
836
0 commit comments