Skip to content

Commit ecae4d8

Browse files
committed
comment changes
remove unnecesary print remove unnecesary FX metatype Registry pre commit fix
1 parent c861de0 commit ecae4d8

File tree

3 files changed

+15
-33
lines changed

3 files changed

+15
-33
lines changed

nncf/common/graph/operator_metatypes.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -112,16 +112,17 @@ def wrap(obj: Type[OperatorMetatype]) -> Type[OperatorMetatype]:
112112
)
113113
raise nncf.InternalError(msg)
114114
self._op_name_to_op_meta_dict[name] = obj
115-
if hasattr(obj, "module_to_function_names"):
116-
func_name = [
117-
f"{namespace.value}.{func}"
118-
for namespace, functions in obj.module_to_function_names.items()
119-
if functions
120-
for func in functions
121-
]
122-
if len(func_name):
123-
for func in func_name:
124-
self._func_name_to_op_meta_dict[func] = obj
115+
if hasattr(obj, "module_to_function_names"):
116+
for namespace, function_names in obj.module_to_function_names.items():
117+
for function_name in function_names:
118+
target_function_name = f"{namespace.value}.{function_name}"
119+
if target_function_name in self._func_name_to_op_meta_dict:
120+
msg = (
121+
"Inconsistent operator metatype registry - single patched "
122+
f"op name `{target_function_name}` maps to multiple metatypes!"
123+
)
124+
raise nncf.InternalError(msg)
125+
self._func_name_to_op_meta_dict[target_function_name] = obj
125126
return obj
126127

127128
return wrap

nncf/experimental/torch/fx/nncf_graph_builder.py

+1-19
Original file line numberDiff line numberDiff line change
@@ -65,22 +65,6 @@ def _get_layer_attributes(
6565
)
6666
return None
6767

68-
def _map_fx_unique_metatypes(node: torch.fx.Node, metatype: om.OperatorMetatype) -> om.OperatorMetatype:
69-
"""
70-
Attempts to retrieve correct subtype for the given node.
71-
72-
:param node: Given node.
73-
:param metatype: Given node metatype.
74-
:param model: Target GraphModule instance.
75-
:return: Correct FX metatype of the given node if it is exist or the original node metatype otherwise.
76-
"""
77-
if metatype in [om.PTEmbeddingMetatype]:
78-
weight_node = node.args[0]
79-
if weight_node.op == "get_attr":
80-
return om.PTAtenEmbeddingMetatype
81-
82-
return metatype
83-
8468
@staticmethod
8569
def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule) -> Tuple[str, om.OperatorMetatype]:
8670
"""
@@ -121,8 +105,7 @@ def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule)
121105
layer_attrs = GraphConverter._get_layer_attributes(node, node_metatype, model)
122106
node_subtype = node_metatype.determine_subtype(layer_attrs)
123107
node_metatype = node_subtype or node_metatype
124-
if not node_type_name:
125-
node_type_name = node_type
108+
node_type_name = node_type_name or node_type
126109
return node_type_name, node_metatype
127110

128111
@staticmethod
@@ -140,7 +123,6 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
140123
const_targets_counter = Counter([node.target for node in model.graph.nodes if node.op == "get_attr"])
141124
for source_node in model.graph.nodes:
142125
node_type, node_metatype = GraphConverter.get_node_type_and_metatype(source_node, model)
143-
node_metatype = GraphConverter._map_fx_unique_metatypes(source_node, node_metatype)
144126
is_shared_node = source_node.op in ("get_attr",) and (
145127
const_targets_counter[source_node.target] > 1 or len(source_node.users) > 1
146128
)

nncf/torch/graph/operator_metatypes.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
ModuleAttributes = TypeVar("ModuleAttributes", bound=BaseLayerAttributes)
2929

3030
PT_OPERATOR_METATYPES = OperatorMetatypeRegistry("operator_metatypes")
31-
FX_OPERATOR_METATYPES = OperatorMetatypeRegistry("operator_metatypes")
3231

3332

3433
class PTOperatorMetatype(OperatorMetatype):
@@ -962,15 +961,15 @@ class PTModuleEmbeddingMetatype(PTModuleOperatorSubtype):
962961
@PT_OPERATOR_METATYPES.register()
963962
class PTEmbeddingMetatype(PTOperatorMetatype):
964963
name = "EmbeddingOp"
965-
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["embedding"], NamespaceTarget.ATEN: ["embedding"]}
964+
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["embedding"]}
966965
hw_config_names = [HWConfigOpName.EMBEDDING]
967966
subtypes = [PTModuleEmbeddingMetatype]
968967
weight_port_ids = [1]
969968

970969

971-
@FX_OPERATOR_METATYPES.register()
970+
@PT_OPERATOR_METATYPES.register()
972971
class PTAtenEmbeddingMetatype(OperatorMetatype):
973-
name = "EmbeddingOp"
972+
name = "AtenEmbeddingOp"
974973
module_to_function_names = {NamespaceTarget.ATEN: ["embedding"]}
975974
hw_config_names = [HWConfigOpName.EMBEDDING]
976975
weight_port_ids = [0]

0 commit comments

Comments
 (0)