11
11
import random
12
12
from collections import OrderedDict
13
13
from contextlib import contextmanager
14
- from typing import Any , Dict , Generator , List , Type
14
+ from typing import Any , Dict , Generator , List
15
15
16
16
import numpy as np
17
17
import torch
31
31
from nncf .torch .dynamic_graph .scope import Scope
32
32
from nncf .torch .dynamic_graph .scope import ScopeElement
33
33
from nncf .torch .dynamic_graph .trace_tensor import TracedTensorMixin
34
+ from nncf .torch .graph .operator_metatypes import MATMUL_METATYPES
34
35
from nncf .torch .layer_utils import _NNCFModuleMixin
36
+ from nncf .torch .model_graph_manager import get_weight_tensor_port_ids
35
37
from nncf .torch .structures import ExecutionParameters
36
38
37
39
@@ -472,14 +474,18 @@ def get_model_dtype(model: torch.nn.Module) -> torch.dtype:
472
474
return dtype
473
475
474
476
475
- def get_weight_nodes_in_inference_graph (
476
- inference_nncf_graph : NNCFGraph , mat_mul_metatypes : List [Type [om .PTOperatorMetatype ]]
477
+ def get_weight_nodes (
478
+ nncf_graph : NNCFGraph ,
479
+ inference_nncf_graph : NNCFGraph ,
477
480
) -> List [NNCFNode ]:
478
481
"""
479
482
Returns nodes that have weights.
480
483
481
484
:param nncf_graph: Instance of inference NNCFGraph,
485
+ which contains shape of and constant subgraphs.
486
+ :param inference_nncf_graph: Instance of inference NNCFGraph,
482
487
which does not contain shape of and constant subgraphs.
488
+
483
489
:return: All nodes with weights.
484
490
"""
485
491
weight_nodes_candidates = [
@@ -489,28 +495,18 @@ def get_weight_nodes_in_inference_graph(
489
495
]
490
496
weight_nodes = []
491
497
for node in weight_nodes_candidates :
492
- if node .metatype in mat_mul_metatypes and not is_matmul_with_constant_in_inference_graph (
493
- node , inference_nncf_graph
494
- ):
495
- continue
496
- weight_nodes .append (node )
498
+ if is_matmul_with_constant (node , nncf_graph ):
499
+ weight_nodes .append (node )
497
500
return weight_nodes
498
501
499
502
500
- def is_matmul_with_constant_in_inference_graph (node : NNCFNode , inference_nncf_graph : NNCFGraph ) -> bool :
503
+ def is_matmul_with_constant (node : NNCFNode , nncf_graph : NNCFGraph ) -> bool :
501
504
"""
502
505
Determines whether the given node in the NNCF graph represents a matmul with a constant input.
503
506
504
507
:param node: A NNCFNode instance.
505
- :param inference_nncf_graph: An inference NNCFGraph instance.
508
+ :param nncf_graph: Instance of inference NNCFGraph,
509
+ which contains shape of and constant subgraphs.
506
510
:return: True if given node is a matmul with a constant input, False otherwise.
507
511
"""
508
- if node .metatype == om .PTLinearMetatype :
509
- return True
510
-
511
- # Inference graph does not contain constants, so
512
- # any missed input edge means it is a constant branch.
513
- is_matmul_metatype = node .metatype in [om .PTMatMulMetatype , om .PTAddmmMetatype ]
514
- inputs_missed = 1 <= len (inference_nncf_graph .get_input_edges (node )) < len (node .metatype .weight_port_ids )
515
-
516
- return is_matmul_metatype and inputs_missed
512
+ return node .metatype in MATMUL_METATYPES and len (get_weight_tensor_port_ids (node , nncf_graph )) > 0
0 commit comments