[Torch FX] Map Namespace Names to Metatype #3237

Open
anzr299 wants to merge 34 commits into openvinotoolkit:develop from fx/metatypes

Changes from all commits (34 commits):
98cb4dc  comment changes (anzr299, Feb 4, 2025)
5991aef  update operator metatypes (anzr299, Feb 4, 2025)
44103b4  revert breaking changes (anzr299, Feb 4, 2025)
22c5e51  Update operator_metatypes.py (anzr299, Feb 5, 2025)
d53b445  fix aten op error (anzr299, Feb 6, 2025)
de1361c  fix (anzr299, Feb 6, 2025)
26e1c3a  layer norm metatypes update (anzr299, Feb 7, 2025)
541319d  include aten layernorm in backends and metatype lookup list (anzr299, Feb 7, 2025)
c3cc784  update reference files (anzr299, Feb 7, 2025)
e186b4d  pre commit (anzr299, Feb 7, 2025)
7c1e12c  fix for PT traced tensor issue (anzr299, Feb 10, 2025)
4ce8d69  fix unnecesary changes (anzr299, Feb 11, 2025)
215a529  Merge branch 'openvinotoolkit:develop' into fx/metatypes (anzr299, Feb 27, 2025)
a36c4e3  update synthetic transformer reference graphs and reference values (anzr299, Feb 28, 2025)
4ec56be  update get_all_aliases method (anzr299, Feb 28, 2025)
a0de15d  add suport for edge case for int input of embedding node (anzr299, Mar 4, 2025)
fb49a78  Merge branch 'openvinotoolkit:develop' into fx/metatypes (anzr299, Mar 4, 2025)
4e716c5  fix error (anzr299, Mar 4, 2025)
56dfdda  add PT2 metatype mapping from namespace (anzr299, Mar 4, 2025)
38e825a  pre commit fix (anzr299, Mar 4, 2025)
f863b74  add pt2 operator metatype registry to known registries in PT operator… (anzr299, Mar 4, 2025)
2bdc08f  pre commit fix (anzr299, Mar 4, 2025)
441dde9  update reference graphs; fix minor debugging leftover (anzr299, Mar 5, 2025)
cbda8b6  Merge branch 'develop' into fx/metatypes (anzr299, Mar 6, 2025)
8a4e382  Merge branch 'openvinotoolkit:develop' into fx/metatypes (anzr299, Mar 7, 2025)
ab71b81  update graph builder for dynamic shapes (anzr299, Mar 7, 2025)
97dc74a  update reference graphs (anzr299, Mar 7, 2025)
8fb74e8  Remove Extra Registry; (anzr299, Mar 10, 2025)
a90375c  update metatypes (anzr299, Mar 11, 2025)
11565c2  Merge branch 'openvinotoolkit:develop' into fx/metatypes (anzr299, Mar 11, 2025)
ffddec7  update reference graphs and metatype for convtranspose2d (anzr299, Mar 11, 2025)
f7fd41c  revert reference graph changes (anzr299, Mar 11, 2025)
3f2c342  update a wrong reference file (anzr299, Mar 11, 2025)
9acc6ae  update operator metatype mapping (anzr299, Mar 11, 2025)

41 changes: 16 additions & 25 deletions nncf/experimental/torch/fx/nncf_graph_builder.py
@@ -65,22 +65,6 @@ def _get_layer_attributes(
         )
         return None

-    def _map_fx_unique_metatypes(node: torch.fx.Node, metatype: om.OperatorMetatype) -> om.OperatorMetatype:
-        """
-        Attempts to retrieve correct subtype for the given node.
-
-        :param node: Given node.
-        :param metatype: Given node metatype.
-        :param model: Target GraphModule instance.
-        :return: Correct FX metatype of the given node if it is exist or the original node metatype otherwise.
-        """
-        if metatype in [om.PTEmbeddingMetatype]:
-            weight_node = node.args[0]
-            if weight_node.op == "get_attr":
-                return om.PTAtenEmbeddingMetatype
-
-        return metatype
-
     @staticmethod
     def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule) -> Tuple[str, om.OperatorMetatype]:
         """
@@ -90,6 +74,7 @@ def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule)
         :param model: Given GraphModule.
         :return: Node's type and metatype.
         """
+        node_type_name = None
         if node.op == "placeholder":
            node_type = "input"
            node_metatype = om.PTInputNoopMetatype
@@ -101,9 +86,11 @@ def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule)
            node_metatype = om.PTConstNoopMetatype
         elif node.op in ("call_function",):
            if hasattr(node.target, "overloadpacket"):
-                node_type = str(node.target.overloadpacket).split(".")[1]
+                node_type = str(node.target.overloadpacket)
+                node_type_name = node_type.split(".")[1]
            elif node.target.__name__ == "getitem":
-                node_type = "__getitem__"
+                node_type = "aten.__getitem__"
+                node_type_name = "__getitem__"
            else:
                # TODO(dlyakhov): get correct nodes types from this nodes as well
                node_type = str(node.target)
@@ -118,7 +105,8 @@ def get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule)
            layer_attrs = GraphConverter._get_layer_attributes(node, node_metatype, model)
            node_subtype = node_metatype.determine_subtype(layer_attrs)
            node_metatype = node_subtype or node_metatype
-        return node_type, node_metatype
+        node_type_name = node_type_name or node_type
+        return node_type_name, node_metatype

     @staticmethod
     def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
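
Note (not part of the diff): the split above relies on how torch stringifies overload packets. A minimal standalone sketch, assuming a recent PyTorch where an OpOverloadPacket stringifies to its namespaced name:

import torch

# An FX call_function target is typically an OpOverload such as
# torch.ops.aten.conv2d.default; its .overloadpacket stringifies to the
# namespaced name, which the PR now keeps as the node type.
target = torch.ops.aten.conv2d.default
node_type = str(target.overloadpacket)    # "aten.conv2d", used for metatype lookup
node_type_name = node_type.split(".")[1]  # "conv2d", the short display name
assert node_type == "aten.conv2d"
assert node_type_name == "conv2d"
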
@@ -135,7 +123,6 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
         const_targets_counter = Counter([node.target for node in model.graph.nodes if node.op == "get_attr"])
         for source_node in model.graph.nodes:
             node_type, node_metatype = GraphConverter.get_node_type_and_metatype(source_node, model)
-            node_metatype = GraphConverter._map_fx_unique_metatypes(source_node, node_metatype)
             is_shared_node = source_node.op in ("get_attr",) and (
                 const_targets_counter[source_node.target] > 1 or len(source_node.users) > 1
             )
@@ -148,7 +135,7 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
             source_nncf_node = nncf_graph.get_node_by_name(source_node.name)
             for idx, dist_node in enumerate(source_node.users):
                 dist_node_id = nncf_graph.get_node_by_name(dist_node.name).node_id
-                input_port_id, output_port_id, tensor_shape = GraphConverter.get_edge_params(
+                input_port_id, output_port_id, tensor_shape, tensor_dtype = GraphConverter.get_edge_params(
                     model, source_node, source_nncf_node, dist_node, idx
                 )
                 nncf_graph.add_edge_between_nncf_nodes(
@@ -157,7 +144,7 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
                     tensor_shape=tensor_shape,
                     input_port_id=input_port_id,
                     output_port_id=output_port_id,
-                    dtype=Dtype.FLOAT,
+                    dtype=tensor_dtype,
                 )
         return nncf_graph

@@ -168,7 +155,7 @@ def get_edge_params(
         source_nncf_node: NNCFNode,
         dist_node: torch.fx.Node,
         output_idx: int,
-    ) -> Tuple[int, int, Tuple[int, ...]]:
+    ) -> Tuple[int, int, Tuple[int, ...], Dtype]:
         """
         Retrieves edge params from the given source_node and dist_node pair.

@@ -182,8 +169,11 @@ def get_edge_params(
         """
         output_port_id = 0
         tensor_shape = None
+        tensor_dtype = Dtype.FLOAT
         if source_node.op in ("get_attr",):
-            tensor_shape = tuple(get_tensor_constant_from_node(source_node, model).shape)
+            tensor = get_tensor_constant_from_node(source_node, model)
+            tensor_shape = tuple(tensor.shape)
+            tensor_dtype = Dtype.INTEGER if tensor.dtype == torch.int else tensor_dtype
         elif "val" in source_node.meta:
             if source_nncf_node.metatype is om.PTBatchNormMetatype and isinstance(
                 source_node.meta["val"], (tuple, list)
@@ -197,6 +187,7 @@ def get_edge_params(
             tensor = source_node.meta["val"]
             if isinstance(tensor, torch.Tensor):
                 tensor_shape = tuple(-1 if isinstance(i, torch.SymInt) else i for i in tensor.shape)
+                tensor_dtype = Dtype.INTEGER if tensor.dtype == torch.int else tensor_dtype
             elif isinstance(tensor, torch.SymInt):
                 tensor_shape = (-1,)

@@ -205,4 +196,4 @@ def get_edge_params(
             nncf_logger.debug(f"Edge shape between {source_node.name} and {dist_node.name} is unknown.")

         input_port_id = dist_node.all_input_nodes.index(source_node)
-        return input_port_id, output_port_id, tensor_shape
+        return input_port_id, output_port_id, tensor_shape, tensor_dtype
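
Note (not part of the diff): a minimal sketch of the edge-dtype inference added above; the import path for Dtype is an assumption. Since torch.int is an alias for torch.int32, only 32-bit integer tensors map to Dtype.INTEGER under this check, and every other dtype keeps the Dtype.FLOAT default:

import torch

from nncf.common.graph.layer_attributes import Dtype  # assumed import path

def infer_edge_dtype(tensor: torch.Tensor) -> Dtype:
    # Mirrors the check in get_edge_params: torch.int == torch.int32,
    # so e.g. an int64 tensor still falls through to FLOAT.
    return Dtype.INTEGER if tensor.dtype == torch.int else Dtype.FLOAT

assert infer_edge_dtype(torch.zeros(2, dtype=torch.int32)) is Dtype.INTEGER
assert infer_edge_dtype(torch.zeros(2, dtype=torch.float32)) is Dtype.FLOAT
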
Collaborator:
Need to update return annotation

Collaborator (Author):
Done

2 changes: 1 addition & 1 deletion nncf/experimental/torch2/function_hook/extractor.py
@@ -176,7 +176,7 @@ def extract_conv(
     if input_node == output_node:
         return conv_module

-    if output_node.metatype is not om.PTBatchNormMetatype:
+    if output_node.metatype != om.PT2BatchNormMetatype:
         msg = f"Support only PTBatchNormMetatype as output node, actual: {output_node.metatype}"
         raise nncf.InternalError(msg)

@@ -80,7 +80,11 @@ class FunctionMeta:

     @property
     def func_name(self) -> str:
-        return self.func.__name__
+        if self.func.__qualname__.split(".")[0] == "TensorBase":
+            return f"torch.tensor.{self.func.__name__}"
+        elif self.func.__qualname__ == self.func.__name__:
+            return f"torch.nn.functional.{self.func.__name__}"
+        return f"{self.func.__module__}.{self.func.__name__}"


 @dataclass
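
Note (not part of the diff): the branching in func_name relies on standard Python __qualname__ semantics, where a top-level function's __qualname__ equals its __name__, while a method's __qualname__ is prefixed with its class. A minimal sketch; the TensorBase class below is only an illustrative stand-in for torch's tensor base class:

def relu(x):  # top-level function: __qualname__ == __name__
    return x

class TensorBase:  # illustrative stand-in, not torch's actual class
    def add(self, other):
        return other

assert relu.__qualname__ == relu.__name__ == "relu"
assert TensorBase.add.__qualname__ == "TensorBase.add"
assert TensorBase.add.__qualname__.split(".")[0] == "TensorBase"
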
@@ -89,9 +89,9 @@ def get_meta_type(node_type: str, meta: Union[ConstMeta, FunctionMeta, InOutMeta
     :param meta: The metadata associated with the node.
     :return: The PTOperatorMetatype object.
     """
-    node_metatype = cast(
-        type[om.PTOperatorMetatype], om.PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
-    )
+    metatype = om.PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
+
+    node_metatype = cast(type[om.PTOperatorMetatype], metatype)
     node_sub_meta_type: Optional[type[om.PTOperatorMetatype]] = None
     if node_metatype.get_subtypes() and isinstance(meta, FunctionMeta):
         node_sub_meta_type = node_metatype.determine_subtype(function_args=meta.args, functions_kwargs=meta.kwargs)
@@ -187,7 +187,7 @@ def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> PTNNCFGraph:
             layer_name=node_name,
             node_metatype=meta_type,
             node_name=node_name,
-            node_type=node_type,
+            node_type=node_type.split(".")[-1],
         )
         map_nx_node_to_nncf_node[node] = nncf_node
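
Note (not part of the diff): node types are now namespaced strings, so only the trailing dot-separated segment is stored on the NNCF node. A one-line sketch with illustrative values:

# The full namespaced name is kept for metatype lookup, while the NNCF
# node stores just the last segment (values illustrative).
assert "torch.nn.functional.relu".split(".")[-1] == "relu"
assert "aten.conv2d".split(".")[-1] == "conv2d"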

2 changes: 2 additions & 0 deletions nncf/quantization/algorithms/min_max/torch_backend.py
@@ -311,11 +311,13 @@ def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[O
             om.PTMaxMetatype,
             om.PTSqueezeMetatype,
             om.PTLayerNormMetatype,
+            om.PTAtenLayerNormMetatype,
             om.PTModuleLayerNormMetatype,
             om.PTGroupNormMetatype,
             om.PTModuleGroupNormMetatype,
             # Batchnorm
             om.PTBatchNormMetatype,
+            om.PT2BatchNormMetatype,
             om.PTModuleBatchNormMetatype,
             # Comparison operations
             om.PTGreaterEqualMetatype,
1 change: 1 addition & 0 deletions nncf/quantization/algorithms/min_max/torch_fx_backend.py
@@ -286,6 +286,7 @@ def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[O
             om.PTMaxMetatype,
             om.PTSqueezeMetatype,
             om.PTLayerNormMetatype,
+            om.PTAtenLayerNormMetatype,
             om.PTModuleLayerNormMetatype,
             om.PTGroupNormMetatype,
             om.PTModuleGroupNormMetatype,
2 changes: 1 addition & 1 deletion nncf/torch/dynamic_graph/layer_attributes_handlers.py
@@ -39,7 +39,7 @@
     om.PTConvTranspose1dMetatype, om.PTConvTranspose2dMetatype, om.PTConvTranspose3dMetatype
 )
 LINEAR_OP_NAMES = get_all_aliases(om.PTLinearMetatype)
-BATCHNORM_OP_NAMES = get_all_aliases(om.PTBatchNormMetatype)
+BATCHNORM_OP_NAMES = get_all_aliases(om.PTBatchNormMetatype, om.PT2BatchNormMetatype)
 EMBEDDING_OP_NAMES = get_all_aliases(om.PTEmbeddingMetatype, om.PTEmbeddingBagMetatype)
 GROUP_NORM_OP_NAMES = get_all_aliases(om.PTGroupNormMetatype)
 LAYER_NORM_OP_NAMES = get_all_aliases(om.PTLayerNormMetatype)
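
Note (not part of the diff): a plausible shape for the get_all_aliases helper called here, under the assumption that each metatype exposes a get_all_aliases() classmethod returning its operator-name aliases; the actual helper lives in layer_attributes_handlers.py and may differ:

from itertools import chain
from typing import List

def get_all_aliases(*metatypes) -> List[str]:
    # Merge the op-name aliases of every metatype into one flat lookup list,
    # so PT and PT2 batch-norm aliases are handled by the same branch.
    return list(chain.from_iterable(m.get_all_aliases() for m in metatypes))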