-
Notifications
You must be signed in to change notification settings - Fork 248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Torch FX] Map Namespace Names to Metatype #3237
base: develop
Are you sure you want to change the base?
Changes from 3 commits
98cb4dc
5991aef
44103b4
22c5e51
d53b445
de1361c
26e1c3a
541319d
c3cc784
e186b4d
7c1e12c
4ce8d69
215a529
a36c4e3
4ec56be
a0de15d
fb49a78
4e716c5
56dfdda
38e825a
f863b74
2bdc08f
441dde9
cbda8b6
8a4e382
ab71b81
97dc74a
8fb74e8
a90375c
11565c2
ffddec7
f7fd41c
3f2c342
9acc6ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -78,6 +78,7 @@ def __init__(self, name: str): | |
""" | ||
super().__init__(name) | ||
self._op_name_to_op_meta_dict: Dict[str, Type[OperatorMetatype]] = {} | ||
self._func_name_to_op_meta_dict: Dict[str, Type[OperatorMetatype]] = {} | ||
|
||
def register(self, name: Optional[str] = None, is_subtype: bool = False) -> Callable[..., Type[OperatorMetatype]]: | ||
""" | ||
|
@@ -111,6 +112,17 @@ def wrap(obj: Type[OperatorMetatype]) -> Type[OperatorMetatype]: | |
) | ||
raise nncf.InternalError(msg) | ||
self._op_name_to_op_meta_dict[name] = obj | ||
if hasattr(obj, "module_to_function_names"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suggest to remove module_to_function_names and directly fill op_names but use full name
There is need to update patcher of function and function FunctionHookMode to collect full name of function and use it to determinate metatype. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So you suggest changing the tracer for PT backend to collect full name(Namespace + Op Name) instead of just the name? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are namespace + op name and op name in the op_names dict. |
||
for namespace, function_names in obj.module_to_function_names.items(): | ||
for function_name in function_names: | ||
target_function_name = f"{namespace.value}.{function_name}" | ||
if target_function_name in self._func_name_to_op_meta_dict: | ||
msg = ( | ||
"Inconsistent operator metatype registry - single patched " | ||
f"op name `{target_function_name}` maps to multiple metatypes!" | ||
) | ||
raise nncf.InternalError(msg) | ||
self._func_name_to_op_meta_dict[target_function_name] = obj | ||
return obj | ||
|
||
return wrap | ||
|
@@ -126,6 +138,18 @@ def get_operator_metatype_by_op_name(self, op_name: str) -> Type[OperatorMetatyp | |
return UnknownMetatype | ||
return self._op_name_to_op_meta_dict[op_name] | ||
|
||
def get_operator_metatype_by_func(self, func_name: str) -> Type[OperatorMetatype]: | ||
""" | ||
Returns the operator metatype by function name. | ||
|
||
:param func_name: The function name. | ||
:return: The operator metatype. | ||
""" | ||
if func_name not in self._func_name_to_op_meta_dict: | ||
return UnknownMetatype | ||
obj = self._func_name_to_op_meta_dict[func_name] | ||
return obj | ||
|
||
|
||
NOOP_METATYPES = Registry("noop_metatypes") | ||
INPUT_NOOP_METATYPES = Registry("input_noop_metatypes") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unused
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done