Skip to content

Commit e285715

Browse files
committed
modify test to cover the change
1 parent 2c7b964 commit e285715

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/concat_multiple_query_sdp.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,9 @@ class ConcatMultiQuerySDPTest : public testing::WithParamInterface<ConcatMultiQu
152152
auto unsqueezeK = std::make_shared<ov::op::v0::Unsqueeze>(concatK, unsquezeAxis);
153153
auto unsqueezeV = std::make_shared<ov::op::v0::Unsqueeze>(concatV, unsquezeAxis);
154154

155-
auto targetShape = ov::op::v0::Constant::create(qkvType, {1, 1, 1, 4, 1}, {1});
156-
auto broadcastK = std::make_shared<ov::op::v1::Multiply>(unsqueezeK, targetShape);
157-
auto broadcastV = std::make_shared<ov::op::v1::Multiply>(unsqueezeV, targetShape);
155+
auto targetShape = ov::op::v0::Constant::create(element::i32, {5}, {1, 1, 1, 4, 1});
156+
auto broadcastK = std::make_shared<ov::op::v3::Broadcast>(unsqueezeK, targetShape, op::BroadcastType::BIDIRECTIONAL);
157+
auto broadcastV = std::make_shared<ov::op::v3::Broadcast>(unsqueezeV, targetShape, op::BroadcastType::BIDIRECTIONAL);
158158

159159
auto target4D = ov::op::v0::Constant::create(ov::element::i32, {4}, {0, 0, 8, 64});
160160

src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/concat_transpose_sdp_transpose.cpp

+11-3
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class ConcatSDPTransposeTestBase : public testing::WithParamInterface<ConcatSDPT
7171
result << ")_";
7272
}
7373
result << "Prc=" << inType << "_";
74-
result << "HasShapeOf=" << hasShapeof;
74+
result << "HasShapeOf=" << hasShapeof << "_";
7575
result << "TransposeOrder=";
7676
result << "(";
7777
for (const auto& itr : transposeOrder) {
@@ -85,7 +85,6 @@ class ConcatSDPTransposeTestBase : public testing::WithParamInterface<ConcatSDPT
8585
void SetUp() override {
8686
ElementType inType;
8787
InputShapeAndTransposeOrder inputShapeAndOrders;
88-
bool hasShapeOf;
8988
std::tie(inType, inputShapeAndOrders, hasShapeOf) = this->GetParam();
9089
std::vector<InputShape>& inputShapes = inputShapeAndOrders.first;
9190
transposeOrder = inputShapeAndOrders.second;
@@ -124,6 +123,10 @@ class ConcatSDPTransposeTestBase : public testing::WithParamInterface<ConcatSDPT
124123
// pre SDPA transpose
125124
auto preOrder = ov::op::v0::Constant::create(ov::element::i32, {4}, transposeOrder);
126125
auto transposeQ = std::make_shared<ov::op::v1::Transpose>(inputParams[0], preOrder);
126+
std::shared_ptr<ov::Node> transposeQ_shapeof;
127+
if (hasShapeOf) {
128+
transposeQ_shapeof = std::make_shared<ov::op::v0::ShapeOf>(transposeQ);
129+
}
127130

128131
auto concat_axis = transposeOrder[2];
129132
auto beam_idx = std::make_shared<ov::op::v0::Parameter>(ElementType::i32, ov::PartialShape{-1});
@@ -166,6 +169,7 @@ class ConcatSDPTransposeTestBase : public testing::WithParamInterface<ConcatSDPT
166169
if (hasShapeOf) {
167170
results.push_back(pastk_shapeof);
168171
results.push_back(pastv_shapeof);
172+
results.push_back(transposeQ_shapeof);
169173
}
170174
ov::SinkVector sinks{pastk_assign, pastv_assign};
171175
function = std::make_shared<ov::Model>(results, sinks, inputParams, "ConcatTranposeSDP");
@@ -237,6 +241,7 @@ class ConcatSDPTransposeTestBase : public testing::WithParamInterface<ConcatSDPT
237241
}
238242
}
239243
std::vector<size_t> transposeOrder;
244+
bool hasShapeOf;
240245
};
241246

242247
class ConcatSDPTransposeTest : public ConcatSDPTransposeTestBase {
@@ -287,7 +292,10 @@ TEST_P(ConcatSDPTransposeTest, CompareWithRefs) {
287292
CheckNumberOfNodesWithType(compiledModel, "Concatenation", 0);
288293
CheckNumberOfNodesWithType(compiledModel, "Reorder", 0);
289294
CheckNumberOfNodesWithType(compiledModel, "Transpose", 1);
290-
CheckNumberOfNodesWithType(compiledModel, "Gather", 0);
295+
// Transformation TSShapeOfForward will change:
296+
// ?->transpose->shapeof ==> ?-->shapeof->gather
297+
// |->transpose
298+
CheckNumberOfNodesWithType(compiledModel, "Gather", hasShapeOf ? 1 : 0);
291299
auto expectedOutputs = run_test(functionRefs);
292300
CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 0);
293301
for (size_t i = 0; i < actualOutputs.size(); i++) {

0 commit comments

Comments
 (0)