Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch FX] Map Namespace Names to Metatype #3237

Open
wants to merge 34 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
98cb4dc
comment changes
anzr299 Feb 4, 2025
5991aef
update operator metatypes
anzr299 Feb 4, 2025
44103b4
revert breaking changes
anzr299 Feb 4, 2025
22c5e51
Update operator_metatypes.py
anzr299 Feb 5, 2025
d53b445
fix aten op error
anzr299 Feb 6, 2025
de1361c
fix
anzr299 Feb 6, 2025
26e1c3a
layer norm metatypes update
anzr299 Feb 7, 2025
541319d
include aten layernorm in backends and metatype lookup list
anzr299 Feb 7, 2025
c3cc784
update reference files
anzr299 Feb 7, 2025
e186b4d
pre commit
anzr299 Feb 7, 2025
7c1e12c
fix for PT traced tensor issue
anzr299 Feb 10, 2025
4ce8d69
fix unnecesary changes
anzr299 Feb 11, 2025
215a529
Merge branch 'openvinotoolkit:develop' into fx/metatypes
anzr299 Feb 27, 2025
a36c4e3
update synthetic transformer reference graphs and reference values
anzr299 Feb 28, 2025
4ec56be
update get_all_aliases method
anzr299 Feb 28, 2025
a0de15d
add suport for edge case for int input of embedding node
anzr299 Mar 4, 2025
fb49a78
Merge branch 'openvinotoolkit:develop' into fx/metatypes
anzr299 Mar 4, 2025
4e716c5
fix error
anzr299 Mar 4, 2025
56dfdda
add PT2 metatype mapping from namespace
anzr299 Mar 4, 2025
38e825a
pre commit fix
anzr299 Mar 4, 2025
f863b74
add pt2 operator metatype registry to known registries in PT operator…
anzr299 Mar 4, 2025
2bdc08f
pre commit fix
anzr299 Mar 4, 2025
441dde9
update reference graphs; fix minor debugging leftover
anzr299 Mar 5, 2025
cbda8b6
Merge branch 'develop' into fx/metatypes
anzr299 Mar 6, 2025
8a4e382
Merge branch 'openvinotoolkit:develop' into fx/metatypes
anzr299 Mar 7, 2025
ab71b81
update graph builder for dynamic shapes
anzr299 Mar 7, 2025
97dc74a
update reference graphs
anzr299 Mar 7, 2025
8fb74e8
Remove Extra Registry;
anzr299 Mar 10, 2025
a90375c
update metatypes
anzr299 Mar 11, 2025
11565c2
Merge branch 'openvinotoolkit:develop' into fx/metatypes
anzr299 Mar 11, 2025
ffddec7
update reference graphs and metatype for convtranspose2d
anzr299 Mar 11, 2025
f7fd41c
revert reference graph changes
anzr299 Mar 11, 2025
3f2c342
update a wrong reference file
anzr299 Mar 11, 2025
9acc6ae
update operator metatype mapping
anzr299 Mar 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions nncf/common/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(self, name: str):
"""
super().__init__(name)
self._op_name_to_op_meta_dict: Dict[str, Type[OperatorMetatype]] = {}
self._func_name_to_op_meta_dict: Dict[str, Type[OperatorMetatype]] = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


def register(self, name: Optional[str] = None, is_subtype: bool = False) -> Callable[..., Type[OperatorMetatype]]:
"""
Expand Down Expand Up @@ -111,6 +112,17 @@ def wrap(obj: Type[OperatorMetatype]) -> Type[OperatorMetatype]:
)
raise nncf.InternalError(msg)
self._op_name_to_op_meta_dict[name] = obj
if hasattr(obj, "module_to_function_names"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest to remove module_to_function_names and directly fill op_names but use full name

module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv2d"], NamespaceTarget.ATEN: ["conv2d"]}

op_names = [
    "torch.nn.functional.conv2".
    "torch.ops.aten.conv2d",
] 

There is need to update patcher of function and function FunctionHookMode to collect full name of function and use it to determinate metatype.
Need to change existed tracers and check op_names, but it's better to have second function to determinate metatype that will used only for fx.

@alexsu52 @daniil-lyakhov

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you suggest changing the tracer for PT backend to collect full name(Namespace + Op Name) instead of just the name?

Copy link
Contributor

@alexsu52 alexsu52 Feb 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are namespace + op name and op name in the op_names dict.

for namespace, function_names in obj.module_to_function_names.items():
for function_name in function_names:
target_function_name = f"{namespace.value}.{function_name}"
if target_function_name in self._func_name_to_op_meta_dict:
msg = (
"Inconsistent operator metatype registry - single patched "
f"op name `{target_function_name}` maps to multiple metatypes!"
)
raise nncf.InternalError(msg)
self._func_name_to_op_meta_dict[target_function_name] = obj
return obj

return wrap
Expand All @@ -126,6 +138,18 @@ def get_operator_metatype_by_op_name(self, op_name: str) -> Type[OperatorMetatyp
return UnknownMetatype
return self._op_name_to_op_meta_dict[op_name]

def get_operator_metatype_by_func(self, func_name: str) -> Type[OperatorMetatype]:
"""
Returns the operator metatype by function name.

:param func_name: The function name.
:return: The operator metatype.
"""
if func_name not in self._func_name_to_op_meta_dict:
return UnknownMetatype
obj = self._func_name_to_op_meta_dict[func_name]
return obj


NOOP_METATYPES = Registry("noop_metatypes")
INPUT_NOOP_METATYPES = Registry("input_noop_metatypes")
Expand Down
29 changes: 8 additions & 21 deletions nncf/experimental/torch/fx/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,6 @@ def _get_layer_attributes(
)
return None

def _map_fx_unique_metatypes(node: torch.fx.Node, metatype: om.OperatorMetatype) -> om.OperatorMetatype:
"""
Attempts to retrieve correct subtype for the given node.

:param node: Given node.
:param metatype: Given node metatype.
:param model: Target GraphModule instance.
:return: Correct FX metatype of the given node if it is exist or the original node metatype otherwise.
"""
if metatype in [om.PTEmbeddingMetatype]:
weight_node = node.args[0]
if weight_node.op == "get_attr":
return om.PTAtenEmbeddingMetatype

return metatype

@staticmethod
def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule) -> Tuple[str, om.OperatorMetatype]:
"""
Expand All @@ -90,6 +74,7 @@ def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule)
:param model: Given GraphModule.
:return: Node's type and metatype.
"""
node_type_name = None
if node.op == "placeholder":
node_type = "input"
node_metatype = om.PTInputNoopMetatype
Expand All @@ -101,13 +86,15 @@ def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule)
node_metatype = om.PTConstNoopMetatype
elif node.op in ("call_function",):
if hasattr(node.target, "overloadpacket"):
node_type = str(node.target.overloadpacket).split(".")[1]
node_type = str(node.target.overloadpacket)
node_type_name = node_type.split(".")[1]
elif node.target.__name__ == "getitem":
node_type = "__getitem__"
node_type = "aten.__getitem__"
node_type_name = "__getitem__"
else:
# TODO(dlyakhov): get correct nodes types from this nodes as well
node_type = str(node.target)
node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_func(node_type)
else:
node_type = node.op
node_metatype = UnknownMetatype
Expand All @@ -118,7 +105,8 @@ def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule)
layer_attrs = GraphConverter._get_layer_attributes(node, node_metatype, model)
node_subtype = node_metatype.determine_subtype(layer_attrs)
node_metatype = node_subtype or node_metatype
return node_type, node_metatype
node_type_name = node_type_name or node_type
return node_type_name, node_metatype

@staticmethod
def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
Expand All @@ -135,7 +123,6 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
const_targets_counter = Counter([node.target for node in model.graph.nodes if node.op == "get_attr"])
for source_node in model.graph.nodes:
node_type, node_metatype = GraphConverter.get_node_type_and_metatype(source_node, model)
node_metatype = GraphConverter._map_fx_unique_metatypes(source_node, node_metatype)
is_shared_node = source_node.op in ("get_attr",) and (
const_targets_counter[source_node.target] > 1 or len(source_node.users) > 1
)
Expand Down
Loading
Loading