Skip to content

Commit ec35f1b

Browse files
committed
[GPU] Fix incorrect broadcast axis for static SDPA case
1 parent ed11461 commit ec35f1b

File tree

2 files changed

+57
-12
lines changed

src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp

+6-6
Original file line number | Diff line number | Diff line change
@@ -249,12 +249,12 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
249249
const auto value_shape = transpose_pshape(impl_param.get_input_layout(2).get_partial_shape(), desc->input_v_transpose_order);
250250

251251
OPENVINO_ASSERT(key_shape == value_shape, "[GPU] The shapes of key and value inputs are expected to be equal");
252-
for (size_t i = 0; i < query_shape.size(); ++i) {
253-
if (query_shape[i].is_static() && key_shape[i].is_static() && value_shape[i].is_static()) {
254-
if (query_shape[i].get_length() > key_shape[i].get_length()) {
255-
config.broadcast_axis = desc->input_k_transpose_order[i];
256-
config.group_size = query_shape[i].get_length() / key_shape[i].get_length();
257-
}
252+
253+
const auto num_heads_dim = 1;
254+
if (query_shape[num_heads_dim].is_static() && key_shape[num_heads_dim].is_static() && value_shape[num_heads_dim].is_static()) {
255+
if (query_shape[num_heads_dim].get_length() > key_shape[num_heads_dim].get_length()) {
256+
config.broadcast_axis = desc->input_k_transpose_order[num_heads_dim];
257+
config.group_size = query_shape[num_heads_dim].get_length() / key_shape[num_heads_dim].get_length();
258258
}
259259
}
260260

src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/scaled_dot_product_attention.cpp

+51-6
Original file line numberDiff line numberDiff line change
@@ -329,18 +329,63 @@ const std::vector<std::vector<InputShape>> shapes{
329329
};
330330

331331
const std::vector<std::vector<int64_t>> disable_transpose{};
332-
const std::vector<std::vector<int64_t>> enable_transpose{{0, 1, 2, 3}, {0, 1, 2, 3}, {0, 2, 1, 3}};
332+
const std::vector<std::vector<int64_t>> transpose_value{{0, 1, 2, 3}, {0, 1, 2, 3}, {0, 2, 1, 3}};
333+
const std::vector<std::vector<int64_t>> transpose_all{{0, 2, 1, 3}, {0, 2, 1, 3}, {0, 2, 1, 3}};
333334

334335
const auto params = testing::Combine(testing::Values(ov::element::f16 /*, ov::element::f32 */),
335-
testing::ValuesIn(shapes),
336-
testing::Values(true, false),
337-
testing::Values(true, false),
338-
testing::Values(true, false),
339-
testing::ValuesIn({disable_transpose, enable_transpose}));
336+
testing::ValuesIn(shapes),
337+
testing::Values(true, false),
338+
testing::Values(true, false),
339+
testing::Values(true, false),
340+
testing::ValuesIn({disable_transpose, transpose_value}));
340341

341342
INSTANTIATE_TEST_SUITE_P(smoke_ScaledAttn_GPU,
342343
ScaledAttnLayerGPUTest,
343344
params,
344345
ScaledAttnLayerGPUTest::getTestCaseName);
345346

347+
const std::vector<std::vector<InputShape>> static_shapes{
348+
// static shapes
349+
{
350+
// q shape
351+
{ov::test::InputShape{ov::PartialShape{1, 8, 100, 128},
352+
{ov::Shape{1, 8, 100, 128}}}
353+
},
354+
// kv shape
355+
{ov::test::InputShape{ov::PartialShape{1, 8, 100, 128},
356+
{ov::Shape{1, 8, 100, 128}}}
357+
},
358+
// attn shape: [B, 1, -1, L0+L1]
359+
{ov::test::InputShape{ov::PartialShape{1, 1, 100, 100},
360+
{ov::Shape{1, 1, 100, 100}}}
361+
},
362+
},
363+
{
364+
// q shape
365+
{ov::test::InputShape{ov::PartialShape{1, 8, 64, 128},
366+
{ov::Shape{1, 8, 64, 128}}}
367+
},
368+
// kv shape
369+
{ov::test::InputShape{ov::PartialShape{1, 8, 13, 128},
370+
{ov::Shape{1, 8, 13, 128}}}
371+
},
372+
// attn shape: [B, 1, -1, L0+L1]
373+
{ov::test::InputShape{ov::PartialShape{1, 1, 64, 13},
374+
{ov::Shape{1, 1, 64, 13}}}
375+
},
376+
},
377+
};
378+
379+
const auto static_params = testing::Combine(testing::Values(ov::element::f16),
380+
testing::ValuesIn(static_shapes),
381+
testing::Values(true, false),
382+
testing::Values(true, false),
383+
testing::Values(true, false),
384+
testing::ValuesIn({disable_transpose, transpose_all}));
385+
386+
INSTANTIATE_TEST_SUITE_P(smoke_ScaledAttnStatic_GPU,
387+
ScaledAttnLayerGPUTest,
388+
static_params,
389+
ScaledAttnLayerGPUTest::getTestCaseName);
390+
346391
} // namespace

0 commit comments

Comments (0)