#include "sdpa/pa_kv_cache_rotate_kernel_ref.h"
#include "sdpa/pa_kv_cache_update_kernel_ref.h"
#include "sdpa/pa_sdpa_kernel_opt.h"
+#include "sdpa/sdpa_kernel_micro.h"

namespace cldnn {
namespace ocl {
@@ -66,6 +67,33 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
        return stage == PagedAttentionStage::MIXED;
    }

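+    // Propagates the query tile size of the selected prefill SDPA kernel to the primitive
+    // instance, so that later sequence-length block calculations use the same value.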
+    void update_inst_params(primitive_inst& inst) const override {
+        OPENVINO_ASSERT(inst.type() == paged_attention::type_id());
+        OPENVINO_ASSERT(inst.get_impl() == this);
+
+        auto& pa_inst = reinterpret_cast<paged_attention_inst&>(inst);
+        if (is_micro_kernel_used) {
+            auto tile_q_size = get_target_seq_len_block_size(PagedAttentionStage::PREFILL);
+            pa_inst.tile_q_size = tile_q_size;
+            std::cout << "update_inst_params: from micro-sdpa tile_q_size = " << tile_q_size << "\n";
+        } else {
+            pa_inst.tile_q_size = get_target_seq_len_block_size(PagedAttentionStage::PREFILL);
+            std::cout << "update_inst_params: sdpa_opt tile_q_size = " << get_target_seq_len_block_size(PagedAttentionStage::PREFILL) << "\n";
+        }
+    }
+
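+    // Returns the target sequence-length block size: for the prefill stage with micro SDPA
+    // selected it is the micro-kernel's query tile size, otherwise the default block size of 16.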
+    size_t get_target_seq_len_block_size(const PagedAttentionStage& stage) const {
+        if (stage == PagedAttentionStage::PREFILL) {
+            if (is_micro_kernel_used) {
+                return kernel_selector::SDPAKernelMicro::GetTileQSize(_kernels_data[Stage::SDPA]);
+            } else {
+                return 16;
+            }
+        } else {
+            return 16;
+        }
+    }
+
    void load(BinaryInputBuffer& ib) override {
        parent::load(ib);
        ib >> make_data(&has_scores_output, sizeof(bool));
@@ -527,7 +555,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
    static kv_cache_update_kernel_params_t get_kv_cache_update_kernel_params(const kernel_impl_params& impl_param,
                                                                             const PagedAttentionStage& stage,
                                                                             const kernel_selector::MultiDataTensor& input_tensors,
-                                                                             bool is_dynamic = false) {
+                                                                             int64_t target_seq_len_block_size,
+                                                                             bool is_dynamic) {
        auto params = get_default_params<kv_cache_update_kernel_params_t>(impl_param, is_dynamic);

        const auto& key_tensor = input_tensors[1];
@@ -557,7 +586,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
        params.is_prefill = stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED;

        if ((stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED) && !is_dynamic)
-            params.conf.paged_attention_aligned_seq_len = get_aligned_seq_len(impl_param, stage);
+            params.conf.paged_attention_aligned_seq_len = get_aligned_seq_len(impl_param, stage, target_seq_len_block_size);

        const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset;
        std::map<size_t, size_t> in_tensor_to_offset_map = {
@@ -581,13 +610,31 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
    static sdpa_kernel_params_t get_sdpa_kernel_params(const kernel_impl_params& impl_param,
                                                       const PagedAttentionStage& stage,
                                                       const kernel_selector::MultiDataTensor& input_tensors,
-                                                       bool is_dynamic = false) {
+                                                       int64_t target_seq_len_block_size,
+                                                       bool is_dynamic) {
        const auto desc = impl_param.typed_desc<paged_attention>();
        auto params = get_default_params<sdpa_kernel_params_t>(impl_param, is_dynamic);

-        const auto& query_tensor = input_tensors[0];
-        const auto& key_tensor = input_tensors[1];
-        const auto& value_tensor = input_tensors[2];
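+        // Reshapes a 2D PagedAttention input layout [token_count, num_heads * head_size]
+        // into the 4D layout [1, token_count, num_heads, head_size] expected by the SDPA kernels.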
+        auto get_sdpa_tensor = [&](const layout& input_layout, size_t head_size) {
+            auto new_layout = input_layout;
+            auto orig_shape = new_layout.get_partial_shape();
+            auto new_shape = ov::PartialShape::dynamic(4);
+
+            new_shape[0] = 1;
+            new_shape[1] = orig_shape[0];
+            new_shape[2] = orig_shape[1] / head_size;
+            new_shape[3] = head_size;
+
+            new_layout.set_partial_shape(new_shape);
+
+            std::cout << "Convert layout: " << input_layout.to_short_string() << " -> " << new_layout.to_short_string() << "\n";
+
+            return convert_data_tensor(new_layout);
+        };
+
+        const auto query_tensor = get_sdpa_tensor(impl_param.get_input_layout(0), desc->head_size);
+        const auto key_tensor = get_sdpa_tensor(impl_param.get_input_layout(1), desc->head_size);
+        const auto value_tensor = get_sdpa_tensor(impl_param.get_input_layout(2), desc->head_size);
        const auto& subsequence_begins_tensor = input_tensors[6];
        const auto& scale_tensor = input_tensors[9];
        const auto& alibi_tensor = input_tensors[11];
@@ -616,12 +663,17 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
        if (has_alibi)
            params.inputs[input_idx++] = alibi_tensor;

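+        // The output layout gets the same 4D reshape as the query/key/value inputs.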
+        params.outputs[0] = get_sdpa_tensor(impl_param.get_output_layout(0), desc->head_size);
        if (has_scores_output) {
            params.outputs.resize(2);
            params.outputs[1] = convert_data_tensor(impl_param.get_output_layout(1));
        }

        params.conf = get_sdpa_configuration(impl_param, is_dynamic);
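+        // The reshaped Q/K/V tensors are [1, token_count, num_heads, head_size], so the input
+        // orders mark the sequence and head dimensions as transposed relative to the kernel's
+        // default [batch, heads, sequence, head_size] layout; the output keeps the default order.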
+        params.input0_order = {0, 2, 1, 3};
+        params.input1_order = {0, 2, 1, 3};
+        params.input2_order = {0, 2, 1, 3};
+        params.output_order = {0, 1, 2, 3};

        const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset;
        const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset;
@@ -643,7 +695,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
            in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(11)});

        if ((stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED) && !is_dynamic)
-            params.conf.paged_attention_aligned_seq_len = get_aligned_seq_len(impl_param, stage);
+            params.conf.paged_attention_aligned_seq_len = get_aligned_seq_len(impl_param, stage, target_seq_len_block_size);

        if (has_scores_output)
            out_tensor_to_offset_map.insert({1, out_offsets_map.at(1)});
@@ -756,11 +808,11 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
            (_kernels_data[Stage::KV_CACHE_ROTATE].update_dispatch_data_func)(kv_cache_rotate_kernel_params, _kernels_data[Stage::KV_CACHE_ROTATE]);
        }

-        auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
+        auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, get_target_seq_len_block_size(stage), impl_param.is_dynamic());
        (_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func)(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]);

        if (stage == PagedAttentionStage::PREFILL) {
-            auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
+            auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, get_target_seq_len_block_size(stage), impl_param.is_dynamic());
            (_kernels_data[Stage::SDPA].update_dispatch_data_func)(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
        }
@@ -779,11 +831,11 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
            input_tensors.emplace_back(convert_data_tensor(input_layout));

        const auto& desc = impl_param.typed_desc<paged_attention>();
-        auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
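+        // At kernel-selection time the kernel (and its tile size) is not known yet, so a block
+        // size of 0 is passed here; the real value is supplied later via get_target_seq_len_block_size().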
+        auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, 0, impl_param.is_dynamic());
        auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance();
        kernels_data.push_back(kv_cache_update_kernel_selector.get_best_kernel(kv_cache_update_kernel_params));

-        auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
+        auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, 0, impl_param.is_dynamic());
        auto& sdpa_kernel_selector = sdpa_kernel_selector_t::Instance();
        kernels_data.push_back(sdpa_kernel_selector.get_best_kernel(sdpa_kernel_params));
@@ -801,12 +853,19 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
        impl->has_scores_output = desc->has_scores_output();
        impl->has_rotated_blocks = desc->has_rotated_blocks;

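+        // A non-empty micro_kernels list in the selected SDPA kernel data means the micro SDPA
+        // implementation was chosen for the prefill stage.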
+        if (!kernels_data[Stage::SDPA].kernels[0].micro_kernels.empty()) {
+            std::cout << "Micro SDPA is chosen!\n";
+            std::cout << "tile_q_size = " << kernel_selector::SDPAKernelMicro::GetTileQSize(kernels_data[Stage::SDPA]) << "\n";
+            impl->is_micro_kernel_used = true;
+        }
+
        return impl;
    }

private:
    bool has_scores_output = false;
    bool has_rotated_blocks = false;
+    bool is_micro_kernel_used = false;
};

namespace detail {