11
11
12
12
13
13
from collections import defaultdict
14
- from typing import Any , Dict , List , Optional , Tuple , Union , cast
14
+ from typing import Any , Dict , List , Optional , Set , Tuple , Union , cast
15
15
16
16
import networkx as nx # type: ignore
17
17
import torch
21
21
import nncf .torch .graph .operator_metatypes as om
22
22
from nncf .common .graph .graph import NNCFGraph
23
23
from nncf .common .graph .graph import NNCFNode
24
+ from nncf .common .graph .layer_attributes import BaseLayerAttributes
24
25
from nncf .common .graph .layer_attributes import Dtype
25
26
from nncf .experimental .torch2 .function_hook .graph .build_graph_mode import build_graph
26
27
from nncf .experimental .torch2 .function_hook .graph .graph_utils import ConstMeta
27
28
from nncf .experimental .torch2 .function_hook .graph .graph_utils import EdgeMeta
28
29
from nncf .experimental .torch2 .function_hook .graph .graph_utils import FunctionMeta
29
30
from nncf .experimental .torch2 .function_hook .graph .graph_utils import InOutMeta
30
31
from nncf .experimental .torch2 .function_hook .graph .graph_utils import NodeType
32
+ from nncf .experimental .torch2 .function_hook .nncf_graph .layer_attributes import PT2OpLayerAttributes
31
33
32
34
33
35
def get_node_type (type : NodeType , meta : Union [ConstMeta , FunctionMeta , InOutMeta ]) -> str :
@@ -45,7 +47,7 @@ def get_node_type(type: NodeType, meta: Union[ConstMeta, FunctionMeta, InOutMeta
45
47
if isinstance (meta , ConstMeta ):
46
48
return "nncf_model_const"
47
49
if isinstance (meta , FunctionMeta ):
48
- return meta .fn_name
50
+ return meta .func_name
49
51
raise nncf .InternalError ("Unexpected metadata type" )
50
52
51
53
@@ -77,20 +79,86 @@ def get_dtype(dtype: torch.dtype) -> Dtype:
77
79
return Dtype .INTEGER
78
80
79
81
80
- def get_meta_type (node_type : str , meta : Union [ConstMeta , FunctionMeta , InOutMeta ]) -> om .PTOperatorMetatype :
82
+ def get_meta_type (node_type : str , meta : Union [ConstMeta , FunctionMeta , InOutMeta ]) -> type [ om .PTOperatorMetatype ] :
81
83
"""
82
84
Converts the node type and metadata into a PTOperatorMetatype object.
83
85
:param node_type: The type of the node.
84
86
:param meta: The metadata associated with the node.
85
87
:return: The PTOperatorMetatype object.
86
88
"""
87
- node_metatype = cast (om .PTOperatorMetatype , om .PT_OPERATOR_METATYPES .get_operator_metatype_by_op_name (node_type ))
88
- node_sub_meta_type : Optional [om .PTOperatorMetatype ] = None
89
+ node_metatype = cast (
90
+ type [om .PTOperatorMetatype ], om .PT_OPERATOR_METATYPES .get_operator_metatype_by_op_name (node_type )
91
+ )
92
+ node_sub_meta_type : Optional [type [om .PTOperatorMetatype ]] = None
89
93
if node_metatype .get_subtypes () and isinstance (meta , FunctionMeta ):
90
94
node_sub_meta_type = node_metatype .determine_subtype (function_args = meta .args , functions_kwargs = meta .kwargs )
91
95
return node_sub_meta_type or node_metatype
92
96
93
97
98
+ def is_constant_input_node (nx_graph : nx .MultiDiGraph , node : int ) -> bool :
99
+ """
100
+ Check if a node is a constant input node or constant subgraph:
101
+
102
+ 1) constant
103
+ 2) quantize_function -> constant
104
+
105
+ :param nx_graph: The graph to check the node from.
106
+ :param node: The node to check.
107
+ :return: True if the node is a constant input node, False otherwise.
108
+ """
109
+ meta = nx_graph .nodes [node ]["meta" ]
110
+
111
+ # 1) Input node is a constant node (parameter or buffer)
112
+ if isinstance (meta , ConstMeta ):
113
+ return True
114
+
115
+ # 2) Quantize node with constant input
116
+ if (
117
+ isinstance (meta , FunctionMeta )
118
+ and meta .func_name in om .QUANTIZE_NODE_TYPES
119
+ and isinstance (nx_graph .nodes [node ]["meta" ], FunctionMeta )
120
+ ):
121
+ return all (isinstance (nx_graph .nodes [s_node ]["meta" ], ConstMeta ) for s_node , _ in nx_graph .in_edges (node ))
122
+
123
+ return False
124
+
125
+
126
+ def get_constant_port_ids (nx_graph : nx .MultiDiGraph , node : int ) -> Set [int ]:
127
+ """
128
+ Get the indices of input ports corresponding to the constant node or subgraph.
129
+
130
+ :param nx_graph: The graph to get the constant port IDs from.
131
+ :param node: The node to get the constant port IDs from.
132
+ :return: The list of input port indices with constants.
133
+ """
134
+ constant_port_ids : Set [int ] = set ()
135
+
136
+ for s_node , _ , data in nx_graph .in_edges (node , data = True ):
137
+ if is_constant_input_node (nx_graph , s_node ):
138
+ meta = cast (EdgeMeta , data ["meta" ])
139
+ constant_port_ids .add (meta .input_port )
140
+
141
+ return constant_port_ids
142
+
143
+
144
+ def get_layer_attributes (
145
+ nx_graph : nx .MultiDiGraph , node : int , meta : Union [ConstMeta , FunctionMeta , InOutMeta ]
146
+ ) -> Optional [BaseLayerAttributes ]:
147
+ """
148
+ Get the layer attributes of a node in the graph.
149
+
150
+ :param nx_graph: The graph to get the layer attributes from.
151
+ :param node: The node to get the layer attributes from.
152
+ :param meta: The metadata associated with the node.
153
+ :return: The layer attributes of the node.
154
+ """
155
+ if isinstance (meta , FunctionMeta ):
156
+ constant_port_ids = get_constant_port_ids (nx_graph , node )
157
+ return PT2OpLayerAttributes (meta .func , meta .args , meta .kwargs , constant_port_ids )
158
+
159
+ return None
160
+
161
+
94
162
def convert_to_nncf_graph (nx_graph : nx .MultiDiGraph ) -> NNCFGraph :
95
163
"""
96
164
Converts a graph to an NNCFGraph.
@@ -102,15 +170,18 @@ def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph:
102
170
103
171
map_nx_node_to_nncf_node : Dict [int , NNCFNode ] = {}
104
172
for node , data in nx_graph .nodes (data = True ):
105
- meta : Union [ConstMeta , FunctionMeta , InOutMeta ] = data ["meta" ]
173
+ meta = data ["meta" ]
174
+ if not isinstance (meta , (ConstMeta , FunctionMeta , InOutMeta )):
175
+ raise nncf .InternalError (f"Unknown metadata type: { type (meta )} " )
106
176
node_name = get_name_of_node (meta )
107
177
node_type = get_node_type (data ["type" ], meta )
108
178
meta_type = get_meta_type (node_type , meta )
109
-
179
+ layer_attributes = get_layer_attributes ( nx_graph , node , meta )
110
180
nncf_node = nncf_graph .add_nncf_node (
111
181
node_name = node_name ,
112
182
node_type = node_type ,
113
- node_metatype = meta_type , # type: ignore[arg-type]
183
+ node_metatype = meta_type ,
184
+ layer_attributes = layer_attributes ,
114
185
)
115
186
map_nx_node_to_nncf_node [node ] = nncf_node
116
187
0 commit comments