Skip to content

Commit f0dfed7

Browse files
committed
[GPU] Fix incorrect broadcast axis for static SDPA case
1 parent ed11461 commit f0dfed7

File tree

2 files changed: +48 -18 lines changed

2 files changed: +48 -18 lines changed

src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp

+6 -6
Original file line number | Diff line number | Diff line change
@@ -249,12 +249,12 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
249249
const auto value_shape = transpose_pshape(impl_param.get_input_layout(2).get_partial_shape(), desc->input_v_transpose_order);
250250

251251
OPENVINO_ASSERT(key_shape == value_shape, "[GPU] The shapes of key and value inputs are expected to be equal");
252-
for (size_t i = 0; i < query_shape.size(); ++i) {
253-
if (query_shape[i].is_static() && key_shape[i].is_static() && value_shape[i].is_static()) {
254-
if (query_shape[i].get_length() > key_shape[i].get_length()) {
255-
config.broadcast_axis = desc->input_k_transpose_order[i];
256-
config.group_size = query_shape[i].get_length() / key_shape[i].get_length();
257-
}
252+
253+
const auto num_heads_dim = 1;
254+
if (query_shape[num_heads_dim].is_static() && key_shape[num_heads_dim].is_static() && value_shape[num_heads_dim].is_static()) {
255+
if (query_shape[num_heads_dim].get_length() > key_shape[num_heads_dim].get_length()) {
256+
config.broadcast_axis = desc->input_k_transpose_order[num_heads_dim];
257+
config.group_size = query_shape[num_heads_dim].get_length() / key_shape[num_heads_dim].get_length();
258258
}
259259
}
260260

src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/scaled_dot_product_attention.cpp

+42 -12
Original file line number | Diff line number | Diff line change
@@ -310,7 +310,26 @@ const std::vector<std::vector<InputShape>> shapes{
310310
{ov::test::InputShape{ov::PartialShape{-1, 1, -1, -1},
311311
{ov::Shape{1, 1, 7, 7}, ov::Shape{1, 1, 1, 1}, ov::Shape{2, 1, 10, 10}}}
312312
},
313-
},
313+
}
314+
};
315+
316+
const std::vector<std::vector<int64_t>> disable_transpose{};
317+
const std::vector<std::vector<int64_t>> transpose_value{{0, 1, 2, 3}, {0, 1, 2, 3}, {0, 2, 1, 3}};
318+
const std::vector<std::vector<int64_t>> transpose_all{{0, 2, 1, 3}, {0, 2, 1, 3}, {0, 2, 1, 3}};
319+
320+
const auto dynamic_shape_params = testing::Combine(testing::Values(ov::element::f16 /*, ov::element::f32 */),
321+
testing::ValuesIn(shapes),
322+
testing::Values(true, false),
323+
testing::Values(true, false),
324+
testing::Values(true, false),
325+
testing::ValuesIn({disable_transpose, transpose_value}));
326+
327+
INSTANTIATE_TEST_SUITE_P(smoke_ScaledAttn_GPU,
328+
ScaledAttnLayerGPUTest,
329+
dynamic_shape_params,
330+
ScaledAttnLayerGPUTest::getTestCaseName);
331+
332+
const std::vector<std::vector<InputShape>> static_shapes{
314333
// static shapes
315334
{
316335
// q shape
@@ -326,21 +345,32 @@ const std::vector<std::vector<InputShape>> shapes{
326345
{ov::Shape{1, 1, 100, 100}}}
327346
},
328347
},
348+
{
349+
// q shape
350+
{ov::test::InputShape{ov::PartialShape{1, 8, 64, 128},
351+
{ov::Shape{1, 8, 64, 128}}}
352+
},
353+
// kv shape
354+
{ov::test::InputShape{ov::PartialShape{1, 8, 13, 128},
355+
{ov::Shape{1, 8, 13, 128}}}
356+
},
357+
// attn shape: [B, 1, -1, L0+L1]
358+
{ov::test::InputShape{ov::PartialShape{1, 1, 64, 13},
359+
{ov::Shape{1, 1, 64, 13}}}
360+
},
361+
},
329362
};
330363

331-
const std::vector<std::vector<int64_t>> disable_transpose{};
332-
const std::vector<std::vector<int64_t>> enable_transpose{{0, 1, 2, 3}, {0, 1, 2, 3}, {0, 2, 1, 3}};
364+
const auto static_shape_params = testing::Combine(testing::Values(ov::element::f16),
365+
testing::ValuesIn(static_shapes),
366+
testing::Values(true, false),
367+
testing::Values(true, false),
368+
testing::Values(true, false),
369+
testing::ValuesIn({disable_transpose, transpose_all}));
333370

334-
const auto params = testing::Combine(testing::Values(ov::element::f16 /*, ov::element::f32 */),
335-
testing::ValuesIn(shapes),
336-
testing::Values(true, false),
337-
testing::Values(true, false),
338-
testing::Values(true, false),
339-
testing::ValuesIn({disable_transpose, enable_transpose}));
340-
341-
INSTANTIATE_TEST_SUITE_P(smoke_ScaledAttn_GPU,
371+
INSTANTIATE_TEST_SUITE_P(smoke_ScaledAttnStatic_GPU,
342372
ScaledAttnLayerGPUTest,
343-
params,
373+
static_shape_params,
344374
ScaledAttnLayerGPUTest::getTestCaseName);
345375

346376
} // namespace

0 commit comments

Comments
 (0)