@@ -69,7 +69,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
69
69
void load (BinaryInputBuffer& ib) override {
70
70
parent::load (ib);
71
71
ib >> make_data (&has_scores_output, sizeof (bool ));
72
- ib >> make_data (&has_rotation_coefficients , sizeof (bool ));
72
+ ib >> make_data (&has_rotated_blocks , sizeof (bool ));
73
73
if (is_dynamic ()) {
74
74
auto & kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance ();
75
75
auto kv_cache_update_kernel_impl = kv_cache_update_kernel_selector.GetImplementation (_kernels_data[Stage::KV_CACHE_UPDATE].kernelName );
@@ -83,7 +83,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
83
83
auto pa_sdpa_kernel_impl = pa_sdpa_kernel_selector.GetImplementation (_kernels_data[Stage::PA_SDPA].kernelName );
84
84
pa_sdpa_kernel_impl->GetUpdateDispatchDataFunc (_kernels_data[Stage::PA_SDPA]);
85
85
86
- if (has_rotation_coefficients ) {
86
+ if (has_rotated_blocks ) {
87
87
auto & kv_cache_rotate_kernel_selector = kv_cache_rotate_kernel_selector_t::Instance ();
88
88
auto kv_cache_rotate_kernel_impl = kv_cache_rotate_kernel_selector.GetImplementation (_kernels_data[Stage::KV_CACHE_ROTATE].kernelName );
89
89
kv_cache_rotate_kernel_impl->GetUpdateDispatchDataFunc (_kernels_data[Stage::KV_CACHE_ROTATE]);
@@ -94,7 +94,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
94
94
void save (BinaryOutputBuffer& ob) const override {
95
95
parent::save (ob);
96
96
ob << make_data (&has_scores_output, sizeof (bool ));
97
- ob << make_data (&has_rotation_coefficients , sizeof (bool ));
97
+ ob << make_data (&has_rotated_blocks , sizeof (bool ));
98
98
}
99
99
100
100
std::vector<layout> get_internal_buffer_layouts_impl () const override {
@@ -347,7 +347,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
347
347
348
348
std::vector<event::ptr> res_events;
349
349
std::vector<event::ptr> dep_events = events;
350
- if (has_rotation_coefficients ) {
350
+ if (has_rotated_blocks ) {
351
351
execute_stage (dep_events, instance, res_events, Stage::KV_CACHE_ROTATE, is_mixed_mode);
352
352
dep_events = res_events;
353
353
}
@@ -472,7 +472,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
472
472
config.has_const_scale_val = false ;
473
473
}
474
474
475
- config.has_rotation_coefficients_input = desc->has_rotation_coefficients ;
475
+ config.has_rotated_blocks = desc->has_rotated_blocks ;
476
476
477
477
if (desc->heads_num != desc->kv_heads_num ) {
478
478
config.broadcast_axis = 1 ;
@@ -752,7 +752,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
752
752
for (const auto & input_layout : impl_param.input_layouts )
753
753
input_tensors.emplace_back (convert_data_tensor (input_layout));
754
754
755
- if (has_rotation_coefficients ) {
755
+ if (has_rotated_blocks ) {
756
756
auto kv_cache_rotate_kernel_params = get_kv_cache_rotate_kernel_params (impl_param, input_tensors, impl_param.is_dynamic ());
757
757
(_kernels_data[Stage::KV_CACHE_ROTATE].update_dispatch_data_func )(kv_cache_rotate_kernel_params, _kernels_data[Stage::KV_CACHE_ROTATE]);
758
758
}
@@ -792,22 +792,22 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
792
792
auto & pa_sdpa_kernel_selector = pa_sdpa_kernel_selector_t::Instance ();
793
793
kernels_data.push_back (pa_sdpa_kernel_selector.get_best_kernel (pa_sdpa_kernel_params));
794
794
795
- if (desc->has_rotation_coefficients ) {
795
+ if (desc->has_rotated_blocks ) {
796
796
auto kv_cache_rotate_kernel_params = get_kv_cache_rotate_kernel_params (impl_param, input_tensors, impl_param.is_dynamic ());
797
797
auto & kv_cache_rotate_kernel_selector = kv_cache_rotate_kernel_selector_t::Instance ();
798
798
kernels_data.push_back (kv_cache_rotate_kernel_selector.get_best_kernel (kv_cache_rotate_kernel_params));
799
799
}
800
800
801
801
auto impl = cldnn::make_unique<paged_attention_impl>(kernels_data);
802
802
impl->has_scores_output = desc->has_scores_output ();
803
- impl->has_rotation_coefficients = desc->has_rotation_coefficients ;
803
+ impl->has_rotated_blocks = desc->has_rotated_blocks ;
804
804
805
805
return impl;
806
806
}
807
807
808
808
private:
809
809
bool has_scores_output = false ;
810
- bool has_rotation_coefficients = false ;
810
+ bool has_rotated_blocks = false ;
811
811
};
812
812
813
813
namespace detail {
0 commit comments