-
Notifications
You must be signed in to change notification settings - Fork 248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Torch FX] Map Namespace Names to Metatype #3237
base: develop
Are you sure you want to change the base?
Conversation
remove unnecesary print remove unnecesary FX metatype Registry pre commit fix
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@anzr299, please fix the pre-commit: |
A test case have caught the issue with an embedding model here ynimmaga/executorch#26 |
… metatypes. Update reference graphs
@@ -78,6 +78,7 @@ def __init__(self, name: str): | |||
""" | |||
super().__init__(name) | |||
self._op_name_to_op_meta_dict: Dict[str, Type[OperatorMetatype]] = {} | |||
self._func_name_to_op_meta_dict: Dict[str, Type[OperatorMetatype]] = {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unused
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@@ -205,4 +203,4 @@ def get_edge_params( | |||
nncf_logger.debug(f"Edge shape between {source_node.name} and {dist_node.name} is unknown.") | |||
|
|||
input_port_id = dist_node.all_input_nodes.index(source_node) | |||
return input_port_id, output_port_id, tensor_shape | |||
return input_port_id, output_port_id, tensor_shape, tensor_dtype |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to update return annotation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
from nncf.torch.dynamic_graph.graph import DynamicGraph | ||
from nncf.torch.dynamic_graph.structs import NamespaceTarget | ||
|
||
ModuleAttributes = TypeVar("ModuleAttributes", bound=BaseLayerAttributes) | ||
|
||
PT_OPERATOR_METATYPES = OperatorMetatypeRegistry("operator_metatypes") | ||
FX_OPERATOR_METATYPES = OperatorMetatypeRegistry("operator_metatypes") | ||
PT2_OPERATOR_METATYPES = OperatorMetatypeRegistry("operator_metatypes") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to add extra registry use only one, but use full operation name as op_name
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@property | ||
def func_namespace(self) -> str: | ||
if self.func.__qualname__.split(".")[0] == "TensorBase": | ||
return f"torch.tensor.{str(self.func.__name__)}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need to use str within f-string
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh, I missed that when changing from an older implementation. Thank you very much!
@@ -82,6 +82,14 @@ class FunctionMeta: | |||
def func_name(self) -> str: | |||
return self.func.__name__ | |||
|
|||
@property | |||
def func_namespace(self) -> str: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change func_name property, there is need only one way to get name
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
if self.func.__qualname__.split(".")[0] == "TensorBase": | ||
return f"torch.tensor.{str(self.func.__name__)}" | ||
elif self.func.__qualname__ == self.func.__name__: | ||
return f"torch.nn.functional.{str(self.func.__name__)}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not correctly works for function from torchvision like deform_conv2d
or symmetric_quantize
in nncf
So better use f"{func.__module__}.{func.__name__}"
as op_name
for functions and func.__qualname__
for methods.
And using NamespaceTarget in definition of metatypes is actual only for patched functions, for FX and TorchFunctionMode tracing is it should not be used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I did not understand the suggestion "f"{func.module}.{func.name}" as op_name for functions and func.qualname for methods" do you suggest using these directly? because for example torch.tensor.to
does not have a module but qualname returns TensorBase.to
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes,
torch.tensor.to is method and can not be call as torch.tensor.to
and methods is not contain __module__
It can used only for instance of Tensor torch.tensor([1]).to.__qualname__
So i have suggest to use qualname for it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, so I will modify the metatype definition to be tensorbase.to instead of torch.tensor.to so that it can match with the correct metatype? or do I create a fake tensor and call Tensor([1]).to.__qualname__
on it?
review fixes
Changes
Created a new dictionary in operator metatypes registry to maintain mapping of namespace and function name to metatype.
Reason for changes
For a more accurate retrieval of operation metatype.
Tests
original graph tests and metatypes were used to check the correctness of the change.