Skip to content

Commit 2becda4

Browse files
Committed
[GPU] Update PagedAttention output shape, add dynamic paddings support for mixed kernel mode execution
1 parent cfbc998 commit 2becda4

File tree

3 files changed

+63
-25
lines changed

3 files changed

+63
-25
lines changed

src/plugins/intel_gpu/src/graph/paged_attention.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ layout paged_attention_inst::calc_output_layout(const paged_attention_node& /*no
4949
template<typename ShapeType>
5050
std::vector<layout> paged_attention_inst::calc_output_layouts(paged_attention_node const& /*node*/, kernel_impl_params const& impl_param) {
5151
auto data_layout = impl_param.get_input_layout(0);
52+
data_layout.data_padding = padding();
5253

5354
const auto& key_cache_ps = impl_param.get_input_layout(3).get_partial_shape();
5455
bool valid_block_size = key_cache_ps[3].is_dynamic() || key_cache_ps[3].get_length() == paged_attention::block_size;
@@ -71,7 +72,7 @@ std::vector<layout> paged_attention_inst::calc_output_layouts(paged_attention_no
7172
total_size += past_lens_mem_lock[i];
7273
}
7374

74-
total_size += static_cast<long int>(impl_param.get_input_layout(0).get_shape()[0]);
75+
total_size += static_cast<long int>(data_layout.get_shape()[0]);
7576

7677
output_layouts.push_back(layout{ov::PartialShape{total_size}, output_dt, format::bfyx});
7778
} else {

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl

+4-2
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ KERNEL(pa_sdpa_opt)(
118118
{
119119
#if STORE_QUERY_TO_SLM
120120
const uint query_idx_local = sgid * SUBGROUP_SIZE + sglid;
121-
const uint query_idx = seq_idx * HEAD_SIZE * HEADS_NUM +
121+
const uint query_idx = INPUT0_OFFSET +
122+
seq_idx * (HEAD_SIZE * HEADS_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM) +
122123
head_num_idx * HEAD_SIZE +
123124
query_idx_local;
124125

@@ -137,7 +138,8 @@ KERNEL(pa_sdpa_opt)(
137138
#else
138139
INPUT0_TYPE q_val[HEAD_SIZE / SUBGROUP_SIZE];
139140
unroll_for (uint i = 0; i < HEAD_SIZE / SUBGROUP_SIZE; i++) {
140-
const uint query_idx = seq_idx * HEAD_SIZE * HEADS_NUM +
141+
const uint query_idx = INPUT0_OFFSET +
142+
seq_idx * (HEAD_SIZE * HEADS_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM) +
141143
head_num_idx * HEAD_SIZE +
142144
i * SUBGROUP_SIZE;
143145
q_val[i] = BLOCK_READN(INPUT0_TYPE, 1, query, query_idx);

src/plugins/intel_gpu/tests/unit/test_cases/paged_attention_gpu_test.cpp

+57-22
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,37 @@ struct PagedAttentionTest : public ::testing::TestWithParam<T> {
730730
rotation_deltas_layout.set_partial_shape(ov::PartialShape{ -1, -1 });
731731
rotation_trig_lut_layout.set_partial_shape(ov::PartialShape{ -1, p.head_size });
732732

733+
if (p.dynamic_paddings) {
734+
const auto padding_axis = 1;
735+
const auto pad_before = p.head_size;
736+
const auto pad_after = p.head_size * 2;
737+
738+
query_layout.data_padding._dynamic_dims_mask[padding_axis] = 1;
739+
740+
auto query_data_layout = query_mem->get_layout();
741+
auto padded_query_data_layout = query_data_layout;
742+
padded_query_data_layout.data_padding._lower_size[padding_axis] = pad_before;
743+
padded_query_data_layout.data_padding._upper_size[padding_axis] = pad_after;
744+
745+
auto new_query_memory = get_test_engine().allocate_memory(padded_query_data_layout, false);
746+
747+
mem_lock<ov::float16> query_mem_lock(query_mem, get_test_stream());
748+
mem_lock<ov::float16> new_query_mem_lock(new_query_memory, get_test_stream());
749+
750+
auto query_data_shape = query_data_layout.get_shape();
751+
for (size_t b = 0; b < query_data_shape[0]; b++) {
752+
for (size_t f = 0; f < query_data_shape[1]; f++) {
753+
auto input_offset =
754+
query_data_layout.get_linear_offset(cldnn::tensor(static_cast<int32_t>(b), static_cast<int32_t>(f), 0, 0, 0, 0));
755+
auto output_offset =
756+
padded_query_data_layout.get_linear_offset(cldnn::tensor(static_cast<int32_t>(b), static_cast<int32_t>(f), 0, 0, 0, 0));
757+
758+
new_query_mem_lock[output_offset] = query_mem_lock[input_offset];
759+
}
760+
}
761+
query_mem = new_query_memory;
762+
}
763+
733764
std::vector<input_info> pa_inputs = {
734765
input_info("query"),
735766
input_info("key"),
@@ -857,6 +888,7 @@ struct paged_attention_test_params {
857888
int num_heads;
858889
int head_size;
859890
int block_size;
891+
bool dynamic_paddings;
860892
bool scores_output;
861893
CacheRotationDescriptor rotation_config;
862894
};
@@ -873,31 +905,34 @@ const auto DISABLE_SCORES = false;
873905
const auto PER_BLOCK_ROTATION = CacheRotationDescriptor{ true, true };
874906
const auto PER_TOKEN_ROTATION = CacheRotationDescriptor{ true, false };
875907
const auto DISABLE_ROTATION = CacheRotationDescriptor{ false, false };
908+
const auto STATIC_INPUT_PAD = false;
909+
const auto DYNAMIC_INPUT_PAD = true;
876910

877911
INSTANTIATE_TEST_SUITE_P(smoke_paged_attention, paged_attention_test, ::testing::ValuesIn(std::vector<paged_attention_test_params>{
878912
/* with scores output */
879-
paged_attention_test_params{ {{10, 0}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token
880-
paged_attention_test_params{ {{36, 0}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token
881-
paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token long
882-
paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token + 1st token
883-
paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token + 1st token
884-
paged_attention_test_params{ {{1, 10}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 2nd token
885-
paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // 2nd token + 2nd token
886-
paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, ENABLE_SCORES, DISABLE_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
887-
/* without scores output */
888-
paged_attention_test_params{ {{10, 0}}, 2, 64, 16, DISABLE_SCORES, DISABLE_ROTATION }, // 1st token
889-
paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, DISABLE_SCORES, DISABLE_ROTATION }, // 1st token long
890-
paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, DISABLE_SCORES, DISABLE_ROTATION }, // 2nd token + 2nd token
913+
paged_attention_test_params{ {{10, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token
914+
paged_attention_test_params{ {{36, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token
915+
paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token long
916+
paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token + 1st token
917+
paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 1st token + 1st token
918+
paged_attention_test_params{ {{1, 10}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 2nd token
919+
paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // 2nd token + 2nd token
920+
paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, DISABLE_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
921+
/* without scores output, dynamic input query paddings */
922+
paged_attention_test_params{ {{10, 0}}, 2, 64, 16, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION }, // 1st token
923+
paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION }, // 1st token long
924+
paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION }, // 2nd token + 2nd token
925+
paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, DYNAMIC_INPUT_PAD, DISABLE_SCORES, DISABLE_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
891926
/* with scores, per_block rotation */
892-
paged_attention_test_params{ {{10, 0}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token
893-
paged_attention_test_params{ {{36, 0}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token
894-
paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token long
895-
paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token + 1st token
896-
paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token + 1st token
897-
paged_attention_test_params{ {{1, 10}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 2nd token
898-
paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 2nd token + 2nd token
899-
paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, ENABLE_SCORES, PER_BLOCK_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
927+
paged_attention_test_params{ {{10, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token
928+
paged_attention_test_params{ {{36, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token
929+
paged_attention_test_params{ {{1024, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token long
930+
paged_attention_test_params{ {{10, 0}, {30, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token + 1st token
931+
paged_attention_test_params{ {{128, 0}, {256, 0}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 1st token + 1st token
932+
paged_attention_test_params{ {{1, 10}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 2nd token
933+
paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // 2nd token + 2nd token
934+
paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_BLOCK_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
900935
/* with scores, per_token rotation */
901-
paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, ENABLE_SCORES, PER_TOKEN_ROTATION }, // 2nd token + 2nd token
902-
paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, ENABLE_SCORES, PER_TOKEN_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
936+
paged_attention_test_params{ {{1, 34}, {1, 515}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION }, // 2nd token + 2nd token
937+
paged_attention_test_params{ {{1, 34}, {25, 0}, {10, 34}}, 2, 64, 16, STATIC_INPUT_PAD, ENABLE_SCORES, PER_TOKEN_ROTATION }, // mixed: 2nd token + 1st token + part of 1st token
903938
}));

0 commit comments

Comments (0)