[Torch FX] Map Namespace Names to Metatype #3237
base: develop
@@ -24,6 +24,7 @@
 from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
 from nncf.torch.dynamic_graph.layer_attributes_handlers import apply_args_defaults
 from nncf.torch.graph.graph import PTNNCFGraph
+from nncf.torch.graph.operator_metatypes import FX_OPERATOR_METATYPES
 from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES
@@ -65,22 +66,6 @@ def _get_layer_attributes(
         )
         return None

-    @staticmethod
-    def _map_fx_unique_metatypes(node: torch.fx.Node, metatype: om.OperatorMetatype) -> om.OperatorMetatype:
-        """
-        Attempts to retrieve correct subtype for the given node.
-
-        :param node: Given node.
-        :param metatype: Given node metatype.
-        :param model: Target GraphModule instance.
-        :return: Correct FX metatype of the given node if it is exist or the original node metatype otherwise.
-        """
-        if metatype in [om.PTEmbeddingMetatype]:
-            weight_node = node.args[0]
-            if weight_node.op == "get_attr":
-                return om.PTAtenEmbeddingMetatype
-
-        return metatype
-
     @staticmethod
     def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule) -> Tuple[str, om.OperatorMetatype]:
         """
@@ -90,6 +75,7 @@ def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule)
         :param model: Given GraphModule.
         :return: Node's type and metatype.
         """
+        node_type_name = None
         if node.op == "placeholder":
             node_type = "input"
             node_metatype = om.PTInputNoopMetatype
@@ -101,13 +87,21 @@ def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule)
             node_metatype = om.PTConstNoopMetatype
         elif node.op in ("call_function",):
             if hasattr(node.target, "overloadpacket"):
-                node_type = str(node.target.overloadpacket).split(".")[1]
+                node_type = str(node.target.overloadpacket)
+                node_type_name = node_type.split(".")[1]
             elif node.target.__name__ == "getitem":
-                node_type = "__getitem__"
+                node_type = "aten.__getitem__"
+                node_type_name = "__getitem__"
             else:
                 # TODO(dlyakhov): get correct nodes types from this nodes as well
                 node_type = str(node.target)
             node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
+            # For FX specific metatypes not registered in PT operator metatype
+            node_metatype = (
+                FX_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
+                if node_metatype == UnknownMetatype
+                else node_metatype
+            )
         else:
             node_type = node.op
             node_metatype = UnknownMetatype
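For context, the fallback introduced in this hunk can be restated as a minimal standalone sketch. This is not the exact NNCF code; it assumes UnknownMetatype is importable from nncf.common.graph.operator_metatypes, that both registries return it on a miss (as the diff's own comparison implies), and "aten.conv2d" is a hypothetical op name:

# Sketch of the two-registry lookup: try the common PT registry first,
# then fall back to the FX-specific registry on a miss.
from nncf.common.graph.operator_metatypes import UnknownMetatype
from nncf.torch.graph.operator_metatypes import FX_OPERATOR_METATYPES
from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES

node_type = "aten.conv2d"  # hypothetical namespaced op name
node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
if node_metatype == UnknownMetatype:
    # FX-specific metatypes are registered separately and are consulted
    # only when the common PT registry has no match.
    node_metatype = FX_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)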
@@ -118,7 +112,8 @@ def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule)
             layer_attrs = GraphConverter._get_layer_attributes(node, node_metatype, model)
             node_subtype = node_metatype.determine_subtype(layer_attrs)
             node_metatype = node_subtype or node_metatype
-        return node_type, node_metatype
+        node_type_name = node_type_name or node_type
+        return node_type_name, node_metatype

     @staticmethod
     def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
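To illustrate the split this hunk introduces: the full namespaced string now drives the registry lookup, while the short name is what the method returns to callers. Hypothetical values for an aten convolution node:

node_type = "aten.conv2d"                 # full name, keys the metatype registries
node_type_name = node_type.split(".")[1]  # "conv2d", returned as the node's type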
@@ -135,7 +130,6 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
        const_targets_counter = Counter([node.target for node in model.graph.nodes if node.op == "get_attr"])
        for source_node in model.graph.nodes:
            node_type, node_metatype = GraphConverter.get_node_type_and_metatype(source_node, model)
-            node_metatype = GraphConverter._map_fx_unique_metatypes(source_node, node_metatype)
            is_shared_node = source_node.op in ("get_attr",) and (
                const_targets_counter[source_node.target] > 1 or len(source_node.users) > 1
            )
@@ -148,7 +142,7 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
            source_nncf_node = nncf_graph.get_node_by_name(source_node.name)
            for idx, dist_node in enumerate(source_node.users):
                dist_node_id = nncf_graph.get_node_by_name(dist_node.name).node_id
-                input_port_id, output_port_id, tensor_shape = GraphConverter.get_edge_params(
+                input_port_id, output_port_id, tensor_shape, tensor_dtype = GraphConverter.get_edge_params(
                    model, source_node, source_nncf_node, dist_node, idx
                )
                nncf_graph.add_edge_between_nncf_nodes(
@@ -157,7 +151,7 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
                    tensor_shape=tensor_shape,
                    input_port_id=input_port_id,
                    output_port_id=output_port_id,
-                    dtype=Dtype.FLOAT,
+                    dtype=tensor_dtype,
                )
        return nncf_graph
@@ -182,8 +176,11 @@
        """
        output_port_id = 0
        tensor_shape = None
+        tensor_dtype = Dtype.FLOAT
        if source_node.op in ("get_attr",):
-            tensor_shape = tuple(get_tensor_constant_from_node(source_node, model).shape)
+            tensor = get_tensor_constant_from_node(source_node, model)
+            tensor_shape = tuple(tensor.shape)
+            tensor_dtype = Dtype.INTEGER if tensor.dtype == torch.int else tensor_dtype
        elif "val" in source_node.meta:
            if source_nncf_node.metatype is om.PTBatchNormMetatype and isinstance(
                source_node.meta["val"], (tuple, list)
@@ -197,6 +194,7 @@ def get_edge_params(
                tensor = source_node.meta["val"]
                if isinstance(tensor, torch.Tensor):
                    tensor_shape = tuple(-1 if isinstance(i, torch.SymInt) else i for i in tensor.shape)
+                    tensor_dtype = Dtype.INTEGER if tensor.dtype == torch.int else tensor_dtype
                elif isinstance(tensor, torch.SymInt):
                    tensor_shape = (-1,)
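Both branches apply the same rule: default the edge dtype to FLOAT and switch to INTEGER only for int tensors. A minimal sketch of that rule, assuming Dtype is importable from nncf.common.graph.layer_attributes and using a hypothetical helper name infer_edge_dtype:

import torch
from nncf.common.graph.layer_attributes import Dtype

def infer_edge_dtype(tensor: torch.Tensor) -> Dtype:
    # Mirrors the diff's check: torch.int is an alias of torch.int32,
    # so only int32 tensors map to INTEGER here; everything else stays FLOAT.
    return Dtype.INTEGER if tensor.dtype == torch.int else Dtype.FLOAT

assert infer_edge_dtype(torch.ones(2, dtype=torch.int32)) == Dtype.INTEGER
assert infer_edge_dtype(torch.ones(2)) == Dtype.FLOAT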
@@ -205,4 +203,4 @@
            nncf_logger.debug(f"Edge shape between {source_node.name} and {dist_node.name} is unknown.")

        input_port_id = dist_node.all_input_nodes.index(source_node)
-        return input_port_id, output_port_id, tensor_shape
+        return input_port_id, output_port_id, tensor_shape, tensor_dtype
Review comment: Need to update the return annotation.
Reply: Done.
@@ -82,6 +82,14 @@ class FunctionMeta:
     def func_name(self) -> str:
         return self.func.__name__

+    @property
+    def func_namespace(self) -> str:
+        if self.func.__qualname__.split(".")[0] == "TensorBase":
+            return f"torch.tensor.{str(self.func.__name__)}"
+        elif self.func.__qualname__ == self.func.__name__:
+            return f"torch.nn.functional.{str(self.func.__name__)}"
+        return f"{str(self.func.__module__)}.{str(self.func.__name__)}"
+

 @dataclass
 class EdgeMeta:

Review comment (on func_namespace): Change the func_name property instead; there needs to be only one way to get the name.
Reply: Done.

Review comment (on the f-strings): No need to use str within an f-string.
Reply: Oh, I missed that when changing from an older implementation. Thank you very much!

Review comment: This does not work correctly for functions from torchvision, like … Also, using NamespaceTarget in the definition of metatypes is relevant only for patched functions; for FX and TorchFunctionMode tracing it should not be used.
Reply: Sorry, I did not understand the suggestion (f"{func.__module__}.{func.__name__}" as op_name for functions and func.__qualname__ for methods). Do you suggest using these directly? Because, for example, …
Reply: Yes.
Reply: Okay, so should I modify the metatype definition to be tensorbase.to instead of torch.tensor.to so that it can match the correct metatype? Or do I create a fake tensor and call …
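To make the three branches and the review discussion concrete, here is a standalone restatement of the heuristic as a free function. This is a sketch, not the NNCF code itself; the "TensorBase" comment reflects the assumption stated in the diff about how tensor methods appear under TorchFunctionMode tracing:

import torch

def func_namespace(func) -> str:
    # Tensor methods traced through TorchFunctionMode are assumed to carry
    # a "TensorBase.<name>" qualified name.
    if func.__qualname__.split(".")[0] == "TensorBase":
        return f"torch.tensor.{func.__name__}"
    # Plain functions have __qualname__ == __name__; the property assumes
    # they live in torch.nn.functional, which, as the review notes, does
    # not hold for e.g. torchvision operators.
    if func.__qualname__ == func.__name__:
        return f"torch.nn.functional.{func.__name__}"
    return f"{func.__module__}.{func.__name__}"

print(func_namespace(torch.nn.functional.relu))  # torch.nn.functional.relu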
Review comment: Unused.
Reply: Done.