@@ -310,7 +310,26 @@ const std::vector<std::vector<InputShape>> shapes{
310
310
{ov::test::InputShape{ov::PartialShape{-1 , 1 , -1 , -1 },
311
311
{ov::Shape{1 , 1 , 7 , 7 }, ov::Shape{1 , 1 , 1 , 1 }, ov::Shape{2 , 1 , 10 , 10 }}}
312
312
},
313
- },
313
+ }
314
+ };
315
+
316
+ const std::vector<std::vector<int64_t >> disable_transpose{};
317
+ const std::vector<std::vector<int64_t >> transpose_value{{0 , 1 , 2 , 3 }, {0 , 1 , 2 , 3 }, {0 , 2 , 1 , 3 }};
318
+ const std::vector<std::vector<int64_t >> transpose_all{{0 , 2 , 1 , 3 }, {0 , 2 , 1 , 3 }, {0 , 2 , 1 , 3 }};
319
+
320
+ const auto dynamic_shape_params = testing::Combine(testing::Values(ov::element::f16 /* , ov::element::f32 */ ),
321
+ testing::ValuesIn (shapes),
322
+ testing::Values(true , false ),
323
+ testing::Values(true , false ),
324
+ testing::Values(true , false ),
325
+ testing::ValuesIn({disable_transpose, transpose_value}));
326
+
327
+ INSTANTIATE_TEST_SUITE_P (smoke_ScaledAttn_GPU,
328
+ ScaledAttnLayerGPUTest,
329
+ dynamic_shape_params,
330
+ ScaledAttnLayerGPUTest::getTestCaseName);
331
+
332
+ const std::vector<std::vector<InputShape>> static_shapes{
314
333
// static shapes
315
334
{
316
335
// q shape
@@ -326,21 +345,32 @@ const std::vector<std::vector<InputShape>> shapes{
326
345
{ov::Shape{1 , 1 , 100 , 100 }}}
327
346
},
328
347
},
348
+ {
349
+ // q shape
350
+ {ov::test::InputShape{ov::PartialShape{1 , 8 , 64 , 128 },
351
+ {ov::Shape{1 , 8 , 64 , 128 }}}
352
+ },
353
+ // kv shape
354
+ {ov::test::InputShape{ov::PartialShape{1 , 8 , 13 , 128 },
355
+ {ov::Shape{1 , 8 , 13 , 128 }}}
356
+ },
357
+ // attn shape: [B, 1, -1, L0+L1]
358
+ {ov::test::InputShape{ov::PartialShape{1 , 1 , 64 , 13 },
359
+ {ov::Shape{1 , 1 , 64 , 13 }}}
360
+ },
361
+ },
329
362
};
330
363
331
- const std::vector<std::vector<int64_t >> disable_transpose{};
332
- const std::vector<std::vector<int64_t >> enable_transpose{{0 , 1 , 2 , 3 }, {0 , 1 , 2 , 3 }, {0 , 2 , 1 , 3 }};
364
+ const auto static_shape_params = testing::Combine(testing::Values(ov::element::f16),
365
+ testing::ValuesIn (static_shapes),
366
+ testing::Values(true , false ),
367
+ testing::Values(true , false ),
368
+ testing::Values(true , false ),
369
+ testing::ValuesIn({disable_transpose, transpose_all}));
333
370
334
- const auto params = testing::Combine(testing::Values(ov::element::f16 /* , ov::element::f32 */ ),
335
- testing::ValuesIn (shapes),
336
- testing::Values(true , false ),
337
- testing::Values(true , false ),
338
- testing::Values(true , false ),
339
- testing::ValuesIn({disable_transpose, enable_transpose}));
340
-
341
- INSTANTIATE_TEST_SUITE_P (smoke_ScaledAttn_GPU,
371
+ INSTANTIATE_TEST_SUITE_P (smoke_ScaledAttnStatic_GPU,
342
372
ScaledAttnLayerGPUTest,
343
- params ,
373
+ static_shape_params ,
344
374
ScaledAttnLayerGPUTest::getTestCaseName);
345
375
346
376
} // namespace
0 commit comments