@@ -730,6 +730,37 @@ struct PagedAttentionTest : public ::testing::TestWithParam<T> {
730
730
rotation_deltas_layout.set_partial_shape (ov::PartialShape{ -1 , -1 });
731
731
rotation_trig_lut_layout.set_partial_shape (ov::PartialShape{ -1 , p.head_size });
732
732
733
+ if (p.dynamic_paddings ) { // Optional path: replace the query input with a copy stored in a padded layout.
734
+ const auto padding_axis = 1 ; // index into the per-dimension padding arrays used below
735
+ const auto pad_before = p.head_size ; // becomes _lower_size[padding_axis]
736
+ const auto pad_after = p.head_size * 2 ; // becomes _upper_size[padding_axis]; deliberately differs from pad_before
737
+ // Flag the padded axis in the dynamic-dims mask of the query layout.
738
+ query_layout.data_padding ._dynamic_dims_mask [padding_axis] = 1 ;
739
+ // Build a padded variant of the layout currently backing query_mem.
740
+ auto query_data_layout = query_mem->get_layout ();
741
+ auto padded_query_data_layout = query_data_layout;
742
+ padded_query_data_layout.data_padding ._lower_size [padding_axis] = pad_before;
743
+ padded_query_data_layout.data_padding ._upper_size [padding_axis] = pad_after;
744
+ // NOTE(review): the second argument is `false` — presumably the "reset" flag, in which case the padding area is not zero-filled; confirm.
745
+ auto new_query_memory = get_test_engine ().allocate_memory (padded_query_data_layout, false );
746
+ // Both buffers are accessed as ov::float16 — assumes the query data type is f16; TODO confirm.
747
+ mem_lock<ov::float16> query_mem_lock (query_mem, get_test_stream ());
748
+ mem_lock<ov::float16> new_query_mem_lock (new_query_memory, get_test_stream ());
749
+ // Copy element by element, translating unpadded linear offsets into padded ones.
750
+ auto query_data_shape = query_data_layout.get_shape ();
751
+ for (size_t b = 0 ; b < query_data_shape[0 ]; b++) {
752
+ for (size_t f = 0 ; f < query_data_shape[1 ]; f++) {
753
+ auto input_offset =
754
+ query_data_layout.get_linear_offset (cldnn::tensor (static_cast <int32_t >(b), static_cast <int32_t >(f), 0 , 0 , 0 , 0 ));
755
+ auto output_offset =
756
+ padded_query_data_layout.get_linear_offset (cldnn::tensor (static_cast <int32_t >(b), static_cast <int32_t >(f), 0 , 0 , 0 , 0 ));
757
+ // Only coordinate (b, f, 0, 0, 0, 0) is copied — assumes all remaining dimensions have size 1; verify against the query shape.
758
+ new_query_mem_lock[output_offset] = query_mem_lock[input_offset];
759
+ }
760
+ }
761
+ query_mem = new_query_memory; // from here on, query_mem refers to the padded copy
762
+ }
763
+
733
764
std::vector<input_info> pa_inputs = {
734
765
input_info (" query" ),
735
766
input_info (" key" ),
@@ -857,6 +888,7 @@ struct paged_attention_test_params {
857
888
int num_heads;
858
889
int head_size;
859
890
int block_size;
891
+ bool dynamic_paddings;
860
892
bool scores_output;
861
893
CacheRotationDescriptor rotation_config;
862
894
};
@@ -873,31 +905,34 @@ const auto DISABLE_SCORES = false;
873
905
const auto PER_BLOCK_ROTATION = CacheRotationDescriptor{ true , true };
874
906
const auto PER_TOKEN_ROTATION = CacheRotationDescriptor{ true , false };
875
907
const auto DISABLE_ROTATION = CacheRotationDescriptor{ false , false };
908
+ const auto STATIC_INPUT_PAD = false ; // test-case value that leaves the padded-query path disabled (lines up with paged_attention_test_params::dynamic_paddings — TODO confirm field order)
909
+ const auto DYNAMIC_INPUT_PAD = true ; // test-case value that enables the padded-query path
876
910
877
911
INSTANTIATE_TEST_SUITE_P (smoke_paged_attention, paged_attention_test, ::testing::ValuesIn(std::vector<paged_attention_test_params>{ // Added (+) entries appear to be: leading {..} list (field not visible here), num_heads, head_size, block_size, dynamic_paddings, scores_output, rotation_config — TODO confirm against the struct.
878
912
/* with scores output */
879
- paged_attention_test_params{ {{10 , 0 }}, 2 , 64 , 16 , ENABLE_SCORES, DISABLE_ROTATION }, // 1st token
880
- paged_attention_test_params{ {{36 , 0 }}, 2 , 64 , 16 , ENABLE_SCORES, DISABLE_ROTATION }, // 1st token
881
- paged_attention_test_params{ {{1024 , 0 }}, 2 , 64 , 16 , ENABLE_SCORES, DISABLE_ROTATION }, // 1st token long
882
- paged_attention_test_params{ {{10 , 0 }, {30 , 0 }}, 2 , 64 , 16 , ENABLE_SCORES, DISABLE_ROTATION }, // 1st token + 1st token
883
- paged_attention_test_params{ {{128 , 0 }, {256 , 0 }}, 2 , 64 , 16 , ENABLE_SCORES, DISABLE_ROTATION }, // 1st token + 1st token
884
- paged_attention_test_params{ {{1 , 10 }}, 2 , 64 , 16 , ENABLE_SCORES, DISABLE_ROTATION }, // 2nd token
885
- paged_attention_test_params{ {{1 , 34 }, {1 , 515 }}, 2 , 64 , 16 , ENABLE_SCORES, DISABLE_ROTATION }, // 2nd token + 2nd token
886
- paged_attention_test_params{ {{1 , 34 }, {25 , 0 }, {10 , 34 }}, 2 , 64 , 16 , ENABLE_SCORES, DISABLE_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
887
- /* without scores output */
888
- paged_attention_test_params{ {{10 , 0 }}, 2 , 64 , 16 , DISABLE_SCORES, DISABLE_ROTATION }, // 1st token
889
- paged_attention_test_params{ {{1024 , 0 }}, 2 , 64 , 16 , DISABLE_SCORES, DISABLE_ROTATION }, // 1st token long
890
- paged_attention_test_params{ {{1 , 34 }, {1 , 515 }}, 2 , 64 , 16 , DISABLE_SCORES, DISABLE_ROTATION }, // 2nd token + 2nd token
913
+ paged_attention_test_params{ {{10 , 0 }}, 2 , 64 , 16 , STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token
914
+ paged_attention_test_params{ {{36 , 0 }}, 2 , 64 , 16 , STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token
915
+ paged_attention_test_params{ {{1024 , 0 }}, 2 , 64 , 16 , STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token long
916
+ paged_attention_test_params{ {{10 , 0 }, {30 , 0 }}, 2 , 64 , 16 , STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token + 1st token
917
+ paged_attention_test_params{ {{128 , 0 }, {256 , 0 }}, 2 , 64 , 16 , STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token + 1st token
918
+ paged_attention_test_params{ {{1 , 10 }}, 2 , 64 , 16 , STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 2nd token
919
+ paged_attention_test_params{ {{1 , 34 }, {1 , 515 }}, 2 , 64 , 16 , STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 2nd token + 2nd token
920
+ paged_attention_test_params{ {{1 , 34 }, {25 , 0 }, {10 , 34 }}, 2 , 64 , 16 , STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
921
+ /* without scores output, dynamic input query paddings */ // NOTE(review): these replace the previous unpadded DISABLE_SCORES cases, so in this list DYNAMIC_INPUT_PAD is only ever combined with DISABLE_SCORES + DISABLE_ROTATION, and no STATIC_INPUT_PAD + DISABLE_SCORES case remains.
922
+ paged_attention_test_params{ {{10 , 0 }}, 2 , 64 , 16 , DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION }, // 1st token
923
+ paged_attention_test_params{ {{1024 , 0 }}, 2 , 64 , 16 , DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION }, // 1st token long
924
+ paged_attention_test_params{ {{1 , 34 }, {1 , 515 }}, 2 , 64 , 16 , DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION }, // 2nd token + 2nd token
925
+ paged_attention_test_params{ {{1 , 34 }, {25 , 0 }, {10 , 34 }}, 2 , 64 , 16 , DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
891
926
/* with scores, per_block rotation */
892
- paged_attention_test_params{ {{10 , 0 }}, 2 , 64 , 16 , ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token
893
- paged_attention_test_params{ {{36 , 0 }}, 2 , 64 , 16 , ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token
894
- paged_attention_test_params{ {{1024 , 0 }}, 2 , 64 , 16 , ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token long
895
- paged_attention_test_params{ {{10 , 0 }, {30 , 0 }}, 2 , 64 , 16 , ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token + 1st token
896
- paged_attention_test_params{ {{128 , 0 }, {256 , 0 }}, 2 , 64 , 16 , ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token + 1st token
897
- paged_attention_test_params{ {{1 , 10 }}, 2 , 64 , 16 , ENABLE_SCORES, PER_BLOCK_ROTATION }, // 2nd token
898
- paged_attention_test_params{ {{1 , 34 }, {1 , 515 }}, 2 , 64 , 16 , ENABLE_SCORES, PER_BLOCK_ROTATION }, // 2nd token + 2nd token
899
- paged_attention_test_params{ {{1 , 34 }, {25 , 0 }, {10 , 34 }}, 2 , 64 , 16 , ENABLE_SCORES, PER_BLOCK_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
927
+ paged_attention_test_params{ {{10 , 0 }}, 2 , 64 , 16 , STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token
928
+ paged_attention_test_params{ {{36 , 0 }}, 2 , 64 , 16 , STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token
929
+ paged_attention_test_params{ {{1024 , 0 }}, 2 , 64 , 16 , STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token long
930
+ paged_attention_test_params{ {{10 , 0 }, {30 , 0 }}, 2 , 64 , 16 , STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token + 1st token
931
+ paged_attention_test_params{ {{128 , 0 }, {256 , 0 }}, 2 , 64 , 16 , STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token + 1st token
932
+ paged_attention_test_params{ {{1 , 10 }}, 2 , 64 , 16 , STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 2nd token
933
+ paged_attention_test_params{ {{1 , 34 }, {1 , 515 }}, 2 , 64 , 16 , STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 2nd token + 2nd token
934
+ paged_attention_test_params{ {{1 , 34 }, {25 , 0 }, {10 , 34 }}, 2 , 64 , 16 , STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
900
935
/* with scores, per_token rotation */
901
- paged_attention_test_params{ {{1 , 34 }, {1 , 515 }}, 2 , 64 , 16 , ENABLE_SCORES, PER_TOKEN_ROTATION }, // 2nd token + 2nd token
902
- paged_attention_test_params{ {{1 , 34 }, {25 , 0 }, {10 , 34 }}, 2 , 64 , 16 , ENABLE_SCORES, PER_TOKEN_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
936
+ paged_attention_test_params{ {{1 , 34 }, {1 , 515 }}, 2 , 64 , 16 , STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION }, // 2nd token + 2nd token
937
+ paged_attention_test_params{ {{1 , 34 }, {25 , 0 }, {10 , 34 }}, 2 , 64 , 16 , STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
903
938
}));
0 commit comments