[Torch FX] Map Namespace Names to Metatype #3237

Open — wants to merge 34 commits into base: develop from fx/metatypes.

Changes shown from 27 of the 34 commits.

Commits (34)
98cb4dc  comment changes (anzr299, Feb 4, 2025)
5991aef  update operator metatypes (anzr299, Feb 4, 2025)
44103b4  revert breaking changes (anzr299, Feb 4, 2025)
22c5e51  Update operator_metatypes.py (anzr299, Feb 5, 2025)
d53b445  fix aten op error (anzr299, Feb 6, 2025)
de1361c  fix (anzr299, Feb 6, 2025)
26e1c3a  layer norm metatypes update (anzr299, Feb 7, 2025)
541319d  include aten layernorm in backends and metatype lookup list (anzr299, Feb 7, 2025)
c3cc784  update reference files (anzr299, Feb 7, 2025)
e186b4d  pre commit (anzr299, Feb 7, 2025)
7c1e12c  fix for PT traced tensor issue (anzr299, Feb 10, 2025)
4ce8d69  fix unnecessary changes (anzr299, Feb 11, 2025)
215a529  Merge branch 'openvinotoolkit:develop' into fx/metatypes (anzr299, Feb 27, 2025)
a36c4e3  update synthetic transformer reference graphs and reference values (anzr299, Feb 28, 2025)
4ec56be  update get_all_aliases method (anzr299, Feb 28, 2025)
a0de15d  add support for edge case for int input of embedding node (anzr299, Mar 4, 2025)
fb49a78  Merge branch 'openvinotoolkit:develop' into fx/metatypes (anzr299, Mar 4, 2025)
4e716c5  fix error (anzr299, Mar 4, 2025)
56dfdda  add PT2 metatype mapping from namespace (anzr299, Mar 4, 2025)
38e825a  pre commit fix (anzr299, Mar 4, 2025)
f863b74  add pt2 operator metatype registry to known registries in PT operator… (anzr299, Mar 4, 2025)
2bdc08f  pre commit fix (anzr299, Mar 4, 2025)
441dde9  update reference graphs; fix minor debugging leftover (anzr299, Mar 5, 2025)
cbda8b6  Merge branch 'develop' into fx/metatypes (anzr299, Mar 6, 2025)
8a4e382  Merge branch 'openvinotoolkit:develop' into fx/metatypes (anzr299, Mar 7, 2025)
ab71b81  update graph builder for dynamic shapes (anzr299, Mar 7, 2025)
97dc74a  update reference graphs (anzr299, Mar 7, 2025)
8fb74e8  Remove Extra Registry (anzr299, Mar 10, 2025)
a90375c  update metatypes (anzr299, Mar 11, 2025)
11565c2  Merge branch 'openvinotoolkit:develop' into fx/metatypes (anzr299, Mar 11, 2025)
ffddec7  update reference graphs and metatype for convtranspose2d (anzr299, Mar 11, 2025)
f7fd41c  revert reference graph changes (anzr299, Mar 11, 2025)
3f2c342  update a wrong reference file (anzr299, Mar 11, 2025)
9acc6ae  update operator metatype mapping (anzr299, Mar 11, 2025)
1 change: 1 addition & 0 deletions nncf/common/graph/operator_metatypes.py
@@ -78,6 +78,7 @@ def __init__(self, name: str):
"""
super().__init__(name)
self._op_name_to_op_meta_dict: Dict[str, Type[OperatorMetatype]] = {}
self._func_name_to_op_meta_dict: Dict[str, Type[OperatorMetatype]] = {}
Collaborator: _func_name_to_op_meta_dict is unused.

Author: Done.


def register(self, name: Optional[str] = None, is_subtype: bool = False) -> Callable[..., Type[OperatorMetatype]]:
"""
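For context on this hunk: the registry maps every alias of a metatype to its class, which is why the reviewer flagged the parallel func-name dict as redundant. A minimal sketch of the pattern, simplified from the real class in this file (not a verbatim copy):

from typing import Callable, Dict, List, Optional, Type


class OperatorMetatype:
    name: str = ""

    @classmethod
    def get_all_aliases(cls) -> List[str]:
        return [cls.name]


class OperatorMetatypeRegistry:
    def __init__(self, name: str):
        self._registry_name = name
        self._op_name_to_op_meta_dict: Dict[str, Type[OperatorMetatype]] = {}

    def register(self, name: Optional[str] = None) -> Callable[..., Type[OperatorMetatype]]:
        def wrap(cls: Type[OperatorMetatype]) -> Type[OperatorMetatype]:
            # Every alias of the metatype becomes a lookup key.
            for alias in cls.get_all_aliases():
                self._op_name_to_op_meta_dict[alias] = cls
            return cls

        return wrap

    def get_operator_metatype_by_op_name(self, op_name: str) -> Type[OperatorMetatype]:
        # The real implementation returns UnknownMetatype for missing names.
        return self._op_name_to_op_meta_dict.get(op_name, OperatorMetatype)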
46 changes: 22 additions & 24 deletions nncf/experimental/torch/fx/nncf_graph_builder.py
@@ -24,6 +24,7 @@
from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
from nncf.torch.dynamic_graph.layer_attributes_handlers import apply_args_defaults
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.operator_metatypes import FX_OPERATOR_METATYPES
from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES


@@ -65,22 +66,6 @@ def _get_layer_attributes(
)
return None

def _map_fx_unique_metatypes(node: torch.fx.Node, metatype: om.OperatorMetatype) -> om.OperatorMetatype:
"""
Attempts to retrieve correct subtype for the given node.

:param node: Given node.
:param metatype: Given node metatype.
:param model: Target GraphModule instance.
:return: Correct FX metatype of the given node if it is exist or the original node metatype otherwise.
"""
if metatype in [om.PTEmbeddingMetatype]:
weight_node = node.args[0]
if weight_node.op == "get_attr":
return om.PTAtenEmbeddingMetatype

return metatype

@staticmethod
def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule) -> Tuple[str, om.OperatorMetatype]:
"""
@@ -90,6 +75,7 @@ def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule)
:param model: Given GraphModule.
:return: Node's type and metatype.
"""
node_type_name = None
if node.op == "placeholder":
node_type = "input"
node_metatype = om.PTInputNoopMetatype
@@ -101,13 +87,21 @@ def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule)
node_metatype = om.PTConstNoopMetatype
elif node.op in ("call_function",):
if hasattr(node.target, "overloadpacket"):
node_type = str(node.target.overloadpacket).split(".")[1]
node_type = str(node.target.overloadpacket)
node_type_name = node_type.split(".")[1]
elif node.target.__name__ == "getitem":
node_type = "__getitem__"
node_type = "aten.__getitem__"
node_type_name = "__getitem__"
else:
# TODO(dlyakhov): get correct nodes types from this nodes as well
node_type = str(node.target)
node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
# For FX specific metatypes not registered in PT operator metatype
node_metatype = (
FX_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
if node_metatype == UnknownMetatype
else node_metatype
)
else:
node_type = node.op
node_metatype = UnknownMetatype
@@ -118,7 +112,8 @@ def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule)
layer_attrs = GraphConverter._get_layer_attributes(node, node_metatype, model)
node_subtype = node_metatype.determine_subtype(layer_attrs)
node_metatype = node_subtype or node_metatype
return node_type, node_metatype
node_type_name = node_type_name or node_type
return node_type_name, node_metatype
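To see where the namespaced node_type and the short node_type_name come from, a small illustrative probe (it assumes only that torch is importable; the exact strings depend on the PyTorch version):

import torch

# node.target for a call_function node is an OpOverload such as
# torch.ops.aten.conv2d.default; its overload packet stringifies to the
# namespaced name used for metatype lookup.
target = torch.ops.aten.conv2d.default
node_type = str(target.overloadpacket)    # "aten.conv2d" -> registry lookup key
node_type_name = node_type.split(".")[1]  # "conv2d"      -> name stored on the NNCF node
print(node_type, node_type_name)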

@staticmethod
def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
@@ -135,7 +130,6 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
const_targets_counter = Counter([node.target for node in model.graph.nodes if node.op == "get_attr"])
for source_node in model.graph.nodes:
node_type, node_metatype = GraphConverter.get_node_type_and_metatype(source_node, model)
node_metatype = GraphConverter._map_fx_unique_metatypes(source_node, node_metatype)
is_shared_node = source_node.op in ("get_attr",) and (
const_targets_counter[source_node.target] > 1 or len(source_node.users) > 1
)
@@ -148,7 +142,7 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
source_nncf_node = nncf_graph.get_node_by_name(source_node.name)
for idx, dist_node in enumerate(source_node.users):
dist_node_id = nncf_graph.get_node_by_name(dist_node.name).node_id
input_port_id, output_port_id, tensor_shape = GraphConverter.get_edge_params(
input_port_id, output_port_id, tensor_shape, tensor_dtype = GraphConverter.get_edge_params(
model, source_node, source_nncf_node, dist_node, idx
)
nncf_graph.add_edge_between_nncf_nodes(
@@ -157,7 +151,7 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
tensor_shape=tensor_shape,
input_port_id=input_port_id,
output_port_id=output_port_id,
dtype=Dtype.FLOAT,
dtype=tensor_dtype,
)
return nncf_graph

@@ -182,8 +176,11 @@ def get_edge_params(
"""
output_port_id = 0
tensor_shape = None
tensor_dtype = Dtype.FLOAT
if source_node.op in ("get_attr",):
tensor_shape = tuple(get_tensor_constant_from_node(source_node, model).shape)
tensor = get_tensor_constant_from_node(source_node, model)
tensor_shape = tuple(tensor.shape)
tensor_dtype = Dtype.INTEGER if tensor.dtype == torch.int else tensor_dtype
elif "val" in source_node.meta:
if source_nncf_node.metatype is om.PTBatchNormMetatype and isinstance(
source_node.meta["val"], (tuple, list)
@@ -197,6 +194,7 @@ def get_edge_params(
tensor = source_node.meta["val"]
if isinstance(tensor, torch.Tensor):
tensor_shape = tuple(-1 if isinstance(i, torch.SymInt) else i for i in tensor.shape)
tensor_dtype = Dtype.INTEGER if tensor.dtype == torch.int else tensor_dtype
elif isinstance(tensor, torch.SymInt):
tensor_shape = (-1,)

@@ -205,4 +203,4 @@
nncf_logger.debug(f"Edge shape between {source_node.name} and {dist_node.name} is unknown.")

input_port_id = dist_node.all_input_nodes.index(source_node)
return input_port_id, output_port_id, tensor_shape
return input_port_id, output_port_id, tensor_shape, tensor_dtype
Collaborator: Need to update the return annotation.

Author: Done.
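The new tensor_dtype plumbing boils down to one rule; a hedged sketch (the Dtype enum below stands in for nncf.common.graph.layer_attributes.Dtype):

from enum import Enum

import torch


class Dtype(Enum):
    FLOAT = "float"
    INTEGER = "int"


def infer_edge_dtype(tensor: torch.Tensor) -> Dtype:
    # Mirrors the diff: only torch.int (torch.int32) flips the edge to
    # INTEGER; other integer widths such as int64 would keep FLOAT here.
    return Dtype.INTEGER if tensor.dtype == torch.int else Dtype.FLOAT


assert infer_edge_dtype(torch.zeros(2, dtype=torch.int32)) is Dtype.INTEGER
assert infer_edge_dtype(torch.zeros(2)) is Dtype.FLOAT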

2 changes: 1 addition & 1 deletion nncf/experimental/torch2/function_hook/extractor.py
@@ -176,7 +176,7 @@ def extract_conv(
if input_node == output_node:
return conv_module

if output_node.metatype is not om.PTBatchNormMetatype:
if output_node.metatype != om.PT2BatchNormMetatype:
msg = f"Support only PTBatchNormMetatype as output node, actual: {output_node.metatype}"
raise nncf.InternalError(msg)

8 changes: 8 additions & 0 deletions nncf/experimental/torch2/function_hook/graph/graph_utils.py
@@ -82,6 +82,14 @@ class FunctionMeta:
def func_name(self) -> str:
return self.func.__name__

@property
def func_namespace(self) -> str:
Collaborator: Change the func_name property instead; there should be only one way to get the name.

Author: Done.

if self.func.__qualname__.split(".")[0] == "TensorBase":
return f"torch.tensor.{str(self.func.__name__)}"
Collaborator: No need to use str() within an f-string.

Author: Oh, I missed that when changing from an older implementation. Thank you very much!

elif self.func.__qualname__ == self.func.__name__:
return f"torch.nn.functional.{str(self.func.__name__)}"
Collaborator: This does not work correctly for functions from torchvision such as deform_conv2d, or for symmetric_quantize in NNCF. Better to use f"{func.__module__}.{func.__name__}" as the op_name for functions and func.__qualname__ for methods. Also, using NamespaceTarget in metatype definitions is relevant only for patched functions; it should not be used for FX and TorchFunctionMode tracing. @alexsu52

Author: Sorry, I did not understand the suggestion to use f"{func.__module__}.{func.__name__}" for functions and func.__qualname__ for methods — do you suggest using these directly? For example, torch.tensor.to does not have a __module__, and its __qualname__ returns TensorBase.to.

Collaborator: Yes. torch.tensor.to is a method: it cannot be called as torch.tensor.to, and methods do not carry __module__. It can be accessed only on a Tensor instance, e.g. torch.tensor([1]).to.__qualname__, so I suggest using __qualname__ for it.

Author: Okay, so should I modify the metatype definition to be tensorbase.to instead of torch.tensor.to so that it matches the correct metatype, or should I create a fake tensor and call torch.tensor([1]).to.__qualname__ on it?

return f"{str(self.func.__module__)}.{str(self.func.__name__)}"


@dataclass
class EdgeMeta:
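To ground the naming discussion above, a hypothetical probe (illustration only; the exact qualnames vary across PyTorch versions):

import torch

# Free function: __module__ plus __name__ gives a usable namespaced op_name.
fn = torch.nn.functional.relu
print(f"{fn.__module__}.{fn.__name__}")  # torch.nn.functional.relu

# Tensor method: reachable only through an instance, and without a reliable
# __module__, so __qualname__ (e.g. "Tensor.to" or "TensorBase.to") is the
# stable handle the reviewer suggests.
method = torch.tensor([1]).to
print(method.__qualname__)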
@@ -23,6 +23,7 @@
from nncf.common.graph.layer_attributes import BaseLayerAttributes
from nncf.common.graph.layer_attributes import ConstantLayerAttributes
from nncf.common.graph.layer_attributes import Dtype
from nncf.common.graph.operator_metatypes import UnknownMetatype
from nncf.experimental.torch2.function_hook.graph.build_graph_mode import build_graph
from nncf.experimental.torch2.function_hook.graph.graph_utils import ConstMeta
from nncf.experimental.torch2.function_hook.graph.graph_utils import EdgeMeta
@@ -48,7 +49,7 @@ def get_node_type(type: NodeType, meta: Union[ConstMeta, FunctionMeta, InOutMeta
if isinstance(meta, ConstMeta):
return "nncf_model_const"
if isinstance(meta, FunctionMeta):
return meta.func_name
return meta.func_namespace
msg = "Unexpected metadata type"
raise nncf.InternalError(msg)

@@ -89,9 +90,14 @@ def get_meta_type(node_type: str, meta: Union[ConstMeta, FunctionMeta, InOutMeta
:param meta: The metadata associated with the node.
:return: The PTOperatorMetatype object.
"""
node_metatype = cast(
type[om.PTOperatorMetatype], om.PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
metatype = om.PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
metatype = (
om.PT2_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
if metatype == UnknownMetatype
else metatype
)

node_metatype = cast(type[om.PTOperatorMetatype], metatype)
node_sub_meta_type: Optional[type[om.PTOperatorMetatype]] = None
if node_metatype.get_subtypes() and isinstance(meta, FunctionMeta):
node_sub_meta_type = node_metatype.determine_subtype(function_args=meta.args, functions_kwargs=meta.kwargs)
@@ -187,7 +193,7 @@ def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> PTNNCFGraph:
layer_name=node_name,
node_metatype=meta_type,
node_name=node_name,
node_type=node_type,
node_type=node_type.split(".")[-1],
)
map_nx_node_to_nncf_node[node] = nncf_node

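The two-registry lookup in get_meta_type (and the analogous FX fallback in nncf_graph_builder.py above) follows one pattern; a self-contained sketch with stub registries (stand-ins for illustration, not NNCF's classes):

from typing import Dict


class UnknownMetatype:
    pass


class StubRegistry:
    def __init__(self, table: Dict[str, type]):
        self._table = table

    def get_operator_metatype_by_op_name(self, op_name: str) -> type:
        return self._table.get(op_name, UnknownMetatype)


class PTLinearMetatype: ...


class PT2LinearMetatype: ...


# Primary registry keyed by short names, fallback keyed by namespaced names.
pt_registry = StubRegistry({"linear": PTLinearMetatype})
pt2_registry = StubRegistry({"torch.nn.functional.linear": PT2LinearMetatype})


def resolve_metatype(node_type: str) -> type:
    metatype = pt_registry.get_operator_metatype_by_op_name(node_type)
    if metatype is UnknownMetatype:
        # Only names unknown to the PT registry reach the PT2
        # namespace-keyed registry, exactly as in the diff.
        metatype = pt2_registry.get_operator_metatype_by_op_name(node_type)
    return metatype


assert resolve_metatype("linear") is PTLinearMetatype
assert resolve_metatype("torch.nn.functional.linear") is PT2LinearMetatype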
2 changes: 2 additions & 0 deletions nncf/quantization/algorithms/min_max/torch_backend.py
@@ -311,11 +311,13 @@ def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[O
om.PTMaxMetatype,
om.PTSqueezeMetatype,
om.PTLayerNormMetatype,
om.PTAtenLayerNormMetatype,
om.PTModuleLayerNormMetatype,
om.PTGroupNormMetatype,
om.PTModuleGroupNormMetatype,
# Batchnorm
om.PTBatchNormMetatype,
om.PT2BatchNormMetatype,
om.PTModuleBatchNormMetatype,
# Comparison operations
om.PTGreaterEqualMetatype,
1 change: 1 addition & 0 deletions nncf/quantization/algorithms/min_max/torch_fx_backend.py
@@ -286,6 +286,7 @@ def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[O
om.PTMaxMetatype,
om.PTSqueezeMetatype,
om.PTLayerNormMetatype,
om.PTAtenLayerNormMetatype,
om.PTModuleLayerNormMetatype,
om.PTGroupNormMetatype,
om.PTModuleGroupNormMetatype,
2 changes: 1 addition & 1 deletion nncf/torch/dynamic_graph/layer_attributes_handlers.py
@@ -39,7 +39,7 @@
om.PTConvTranspose1dMetatype, om.PTConvTranspose2dMetatype, om.PTConvTranspose3dMetatype
)
LINEAR_OP_NAMES = get_all_aliases(om.PTLinearMetatype)
BATCHNORM_OP_NAMES = get_all_aliases(om.PTBatchNormMetatype)
BATCHNORM_OP_NAMES = get_all_aliases(om.PTBatchNormMetatype, om.PT2BatchNormMetatype)
EMBEDDING_OP_NAMES = get_all_aliases(om.PTEmbeddingMetatype, om.PTEmbeddingBagMetatype)
GROUP_NORM_OP_NAMES = get_all_aliases(om.PTGroupNormMetatype)
LAYER_NORM_OP_NAMES = get_all_aliases(om.PTLayerNormMetatype)
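These call sites now pass several metatypes at once; a sketch of what the updated get_all_aliases plausibly does (assuming each metatype class exposes a get_all_aliases() classmethod, as elsewhere in NNCF — not the verbatim implementation):

from typing import List, Type


def get_all_aliases(*metatypes: Type) -> List[str]:
    aliases: List[str] = []
    for metatype in metatypes:
        aliases.extend(metatype.get_all_aliases())
    # Deduplicate while keeping a stable order for reproducible matching.
    return list(dict.fromkeys(aliases))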