Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 88a3159

Browse files
committed Feb 13, 2024 ·
Fixed compressed Gather optimization
1 parent 70d762c commit 88a3159

File tree

2 files changed: +17 -7 lines changed

2 files changed

+17
-7
lines changed
 

‎src/plugins/intel_cpu/src/graph_optimizer.cpp

+16 -1
Original file line number | Diff line number | Diff line change
@@ -543,10 +543,18 @@ void GraphOptimizer::FuseGatherAndWeightsDecompression(Graph &graph) {
543543
NodePtr subtractNode = mulParent;
544544
if (!expectedNode(subtractNode, Type::Eltwise))
545545
continue;
546-
auto subtractConstNode = subtractNode->getParentEdgeAt(1)->getParent();
546+
NodePtr subtractConvertNode, subtractConstNode;
547+
NodePtr subtractParent = subtractNode->getParentEdgeAt(1)->getParent();
548+
if (expectedNode(subtractParent, Type::Convert)) {
549+
subtractConvertNode = subtractParent;
550+
subtractParent = subtractConvertNode->getParentEdgeAt(0)->getParent();
551+
}
552+
subtractConstNode = subtractParent;
547553
if (!expectedNode(subtractConstNode, Type::Input))
548554
continue;
549555

556+
const bool withSubtractConvert = subtractConvertNode != nullptr;
557+
550558
auto convertNode = subtractNode->getParentEdgeAt(0)->getParent();
551559
if (!expectedNode(convertNode, Type::Convert))
552560
continue;
@@ -597,13 +605,20 @@ void GraphOptimizer::FuseGatherAndWeightsDecompression(Graph &graph) {
597605
gatherNode->addOriginalLayer(multiplyNode->getOriginalLayers());
598606
gatherNode->addOriginalLayer(convertNode->getOriginalLayers());
599607

608+
if (withSubtractConvert) {
609+
gatherNode->addOriginalLayer(subtractConvertNode->getOriginalLayers());
610+
auto subtractConvertEdge = subtractConvertNode->getChildEdges()[0].lock();
611+
graph.RemoveEdge(subtractConvertEdge);
612+
}
600613
gatherNode->addOriginalLayer(subtractNode->getOriginalLayers());
601614
auto subtractConstEdge = subtractConstNode->getChildEdges()[0].lock();
602615
graph.RemoveEdge(subtractConstEdge);
603616

604617
auto multiplyConstEdge = multiplyConstNode->getChildEdges()[0].lock();
605618
graph.RemoveEdge(multiplyConstEdge);
606619

620+
if (withSubtractConvert)
621+
graph.DropNode(subtractConvertNode);
607622
graph.DropNode(convertNode);
608623
graph.DropNode(subtractNode);
609624
graph.DropNode(multiplyNode);

‎src/plugins/intel_cpu/src/nodes/gather.cpp

+1 -6
Original file line number | Diff line number | Diff line change
@@ -715,12 +715,7 @@ void Gather::fuseDecompressionConstant(const MemoryCPtr& memory, MemoryCPtr& dec
715715
} else {
716716
DnnlBlockedMemoryDesc memoryDesc(decompression_prc, memory->getShape());
717717
decompressionValuesPtr = std::make_shared<Memory>(getEngine(), memoryDesc, nullptr, false);
718-
const auto elementsCount = memory->getDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
719-
cpu_convert(memory->getData(),
720-
decompressionValuesPtr->getData(),
721-
DnnlExtensionUtils::DataTypeToElementType(memory->getDataType()),
722-
ov::element::f32,
723-
elementsCount);
718+
decompressionValuesPtr->load(*memory);
724719
}
725720
}
726721

0 commit comments

Comments
 (0)
Please sign in to comment.