@@ -329,18 +329,63 @@ const std::vector<std::vector<InputShape>> shapes{
329
329
};
330
330
331
331
// Candidate transpose-order lists for the test parameter grid. Each non-empty list holds three
// 4-element axis permutations.
// NOTE(review): presumably one permutation per attention input (q, k, v) and the empty list means
// "no transpose" -- confirm against the test class, which is not visible here.
const std::vector<std::vector<int64_t >> disable_transpose{};
332
- const std::vector<std::vector<int64_t >> enable_transpose{{0 , 1 , 2 , 3 }, {0 , 1 , 2 , 3 }, {0 , 2 , 1 , 3 }};
332
// Identity order for the first two entries; only the third entry swaps axes 1 and 2.
+ const std::vector<std::vector<int64_t >> transpose_value{{0 , 1 , 2 , 3 }, {0 , 1 , 2 , 3 }, {0 , 2 , 1 , 3 }};
333
// All three entries swap axes 1 and 2.
+ const std::vector<std::vector<int64_t >> transpose_all{{0 , 2 , 1 , 3 }, {0 , 2 , 1 , 3 }, {0 , 2 , 1 , 3 }};
333
334
334
335
// Parameter grid for the smoke_ScaledAttn_GPU suite: the cartesian product of
//   - element type: f16 only (f32 is commented out),
//   - every entry of `shapes`,
//   - three independent boolean switches,
//   - a transpose-order list: either `disable_transpose` or `transpose_value`.
// NOTE(review): the meaning of each boolean is fixed by the test's parameter tuple, which is not
// visible in this chunk -- confirm before reordering.
const auto params = testing::Combine(testing::Values(ov::element::f16 /* , ov::element::f32 */ ),
335
- testing::ValuesIn (shapes),
336
- testing::Values(true , false ),
337
- testing::Values(true , false ),
338
- testing::Values(true , false ),
339
- testing::ValuesIn({disable_transpose, enable_transpose }));
336
+ testing::ValuesIn (shapes),
337
+ testing::Values(true , false ),
338
+ testing::Values(true , false ),
339
+ testing::Values(true , false ),
340
+ testing::ValuesIn({disable_transpose, transpose_value }));
340
341
341
342
// Instantiates ScaledAttnLayerGPUTest once per combination in `params`, under the
// "smoke_ScaledAttn_GPU" prefix; per-case names are produced by getTestCaseName.
INSTANTIATE_TEST_SUITE_P (smoke_ScaledAttn_GPU,
342
343
ScaledAttnLayerGPUTest,
343
344
params,
344
345
ScaledAttnLayerGPUTest::getTestCaseName);
345
346
347
// Fully static shape sets: in every entry the PartialShape equals its single target Shape.
// Each case lists three inputs in the order q, kv, attn.
//   case 1: q and kv have the same shape.
//   case 2: q and kv differ in dim 2 (64 vs 13).
+ const std::vector<std::vector<InputShape>> static_shapes{
348
+ // static shapes
349
+ {
350
+ // q shape
351
+ {ov::test::InputShape{ov::PartialShape{1 , 8 , 100 , 128 },
352
+ {ov::Shape{1 , 8 , 100 , 128 }}}
353
+ },
354
+ // kv shape
355
+ {ov::test::InputShape{ov::PartialShape{1 , 8 , 100 , 128 },
356
+ {ov::Shape{1 , 8 , 100 , 128 }}}
357
+ },
358
+ // attn shape: [1, 1, 100, 100] -- last two dims equal dim 2 of q and dim 2 of kv above
359
+ {ov::test::InputShape{ov::PartialShape{1 , 1 , 100 , 100 },
360
+ {ov::Shape{1 , 1 , 100 , 100 }}}
361
+ },
362
+ },
363
+ {
364
+ // q shape
365
+ {ov::test::InputShape{ov::PartialShape{1 , 8 , 64 , 128 },
366
+ {ov::Shape{1 , 8 , 64 , 128 }}}
367
+ },
368
+ // kv shape
369
+ {ov::test::InputShape{ov::PartialShape{1 , 8 , 13 , 128 },
370
+ {ov::Shape{1 , 8 , 13 , 128 }}}
371
+ },
372
+ // attn shape: [1, 1, 64, 13] -- last two dims equal dim 2 of q and dim 2 of kv above
373
+ {ov::test::InputShape{ov::PartialShape{1 , 1 , 64 , 13 },
374
+ {ov::Shape{1 , 1 , 64 , 13 }}}
375
+ },
376
+ },
377
+ };
378
+
379
// Parameter grid for the smoke_ScaledAttnStatic_GPU suite: f16 only, every entry of `static_shapes`,
// three independent boolean switches, and a transpose-order list that is either
// `disable_transpose` or `transpose_all`.
// NOTE(review): the meaning of each boolean is fixed by the test's parameter tuple, which is not
// visible in this chunk -- confirm before reordering.
+ const auto static_params = testing::Combine(testing::Values(ov::element::f16),
380
+ testing::ValuesIn (static_shapes),
381
+ testing::Values(true , false ),
382
+ testing::Values(true , false ),
383
+ testing::Values(true , false ),
384
+ testing::ValuesIn({disable_transpose, transpose_all}));
385
+
386
// Instantiates ScaledAttnLayerGPUTest once per combination in `static_params`, under the
// "smoke_ScaledAttnStatic_GPU" prefix; per-case names are produced by getTestCaseName.
+ INSTANTIATE_TEST_SUITE_P (smoke_ScaledAttnStatic_GPU,
387
+ ScaledAttnLayerGPUTest,
388
+ static_params,
389
+ ScaledAttnLayerGPUTest::getTestCaseName);
390
+
346
391
} // namespace
0 commit comments