@@ -543,10 +543,18 @@ void GraphOptimizer::FuseGatherAndWeightsDecompression(Graph &graph) {
543
543
NodePtr subtractNode = mulParent;
544
544
if (!expectedNode (subtractNode, Type::Eltwise))
545
545
continue ;
546
- auto subtractConstNode = subtractNode->getParentEdgeAt (1 )->getParent ();
546
+ NodePtr subtractConvertNode, subtractConstNode;
547
+ NodePtr subtractParent = subtractNode->getParentEdgeAt (1 )->getParent ();
548
+ if (expectedNode (subtractParent, Type::Convert)) {
549
+ subtractConvertNode = subtractParent;
550
+ subtractParent = subtractConvertNode->getParentEdgeAt (0 )->getParent ();
551
+ }
552
+ subtractConstNode = subtractParent;
547
553
if (!expectedNode (subtractConstNode, Type::Input))
548
554
continue ;
549
555
556
+ const bool withSubtractConvert = subtractConvertNode != nullptr ;
557
+
550
558
auto convertNode = subtractNode->getParentEdgeAt (0 )->getParent ();
551
559
if (!expectedNode (convertNode, Type::Convert))
552
560
continue ;
@@ -597,13 +605,20 @@ void GraphOptimizer::FuseGatherAndWeightsDecompression(Graph &graph) {
597
605
gatherNode->addOriginalLayer (multiplyNode->getOriginalLayers ());
598
606
gatherNode->addOriginalLayer (convertNode->getOriginalLayers ());
599
607
608
+ if (withSubtractConvert) {
609
+ gatherNode->addOriginalLayer (subtractConvertNode->getOriginalLayers ());
610
+ auto subtractConvertEdge = subtractConvertNode->getChildEdges ()[0 ].lock ();
611
+ graph.RemoveEdge (subtractConvertEdge);
612
+ }
600
613
gatherNode->addOriginalLayer (subtractNode->getOriginalLayers ());
601
614
auto subtractConstEdge = subtractConstNode->getChildEdges ()[0 ].lock ();
602
615
graph.RemoveEdge (subtractConstEdge);
603
616
604
617
auto multiplyConstEdge = multiplyConstNode->getChildEdges ()[0 ].lock ();
605
618
graph.RemoveEdge (multiplyConstEdge);
606
619
620
+ if (withSubtractConvert)
621
+ graph.DropNode (subtractConvertNode);
607
622
graph.DropNode (convertNode);
608
623
graph.DropNode (subtractNode);
609
624
graph.DropNode (multiplyNode);
0 commit comments