Skip to content

Commit cdf7208

Browse files
[PT2] Introduce PT2OpLayerAttribute (#3178)
### Changes - Introduced `PT2OpLayerAttribute`, to collect called function, attributes and constant ports - `FunctionMeta` stored function instead of function name ### Reason for changes Needs to implement subgraph extractor for FBC ### Related tickets 152996 ### Tests tests/torch2/function_hook/nncf_graph/test_layer_attributes.py
1 parent 1b6af84 commit cdf7208

File tree

10 files changed

+256
-24
lines changed

10 files changed

+256
-24
lines changed

nncf/experimental/torch2/function_hook/graph/build_graph_mode.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from __future__ import annotations
1313

14-
from typing import Any, Dict, Tuple, Union, cast
14+
from typing import Any, Callable, Dict, Optional, Tuple, Union, cast
1515

1616
import networkx as nx # type: ignore[import-untyped]
1717
import torch
@@ -185,12 +185,12 @@ def process_tensor_attributes(self, output: torch.Tensor, op_meta: OpMeta) -> No
185185
:param output: The output tensor.
186186
:param op_meta: Metadata about the operation.
187187
"""
188-
fn_name = None
188+
func: Optional[Callable[..., Any]] = None
189189
fn_kwargs = None
190190

191191
if output.grad_fn is not None:
192192
if output.grad_fn.name() == "TransposeBackward0":
193-
fn_name = "transpose"
193+
func = torch.transpose
194194
# grad_fn collect arguments as _saved_dim0=18446744073709551614
195195
# Use static arguments for .mT
196196
# https://pytorch.org/docs/stable/tensors.html#torch.Tensor.mT
@@ -199,11 +199,11 @@ def process_tensor_attributes(self, output: torch.Tensor, op_meta: OpMeta) -> No
199199
"dim1": -1,
200200
}
201201
if output.grad_fn.name() == "PermuteBackward0":
202-
fn_name = "permute"
202+
func = torch.permute
203203
fn_kwargs = {"dims": output.grad_fn._saved_dims} # type: ignore[attr-defined]
204204

205-
if fn_name is not None and fn_kwargs is not None:
206-
self.graph.nodes[op_meta.extra_info["node_id"]]["meta"].fn_name = fn_name
205+
if func is not None and fn_kwargs is not None:
206+
self.graph.nodes[op_meta.extra_info["node_id"]]["meta"].func = func
207207
self.graph.nodes[op_meta.extra_info["node_id"]]["meta"].kwargs = fn_kwargs
208208

209209
def execute_post_hooks(self, outputs: Any, op_meta: OpMeta) -> Any:
@@ -322,7 +322,7 @@ def register_op_node(self, args: Tuple[Any], kwargs: Dict[str, Any], op_meta: Op
322322
self.graph.add_node(
323323
node_id,
324324
type=NodeType.fn_call,
325-
meta=FunctionMeta(op_name=op_name, fn_name=op_meta.func.__name__, args=tuple(op_attrs), kwargs=op_kwargs),
325+
meta=FunctionMeta(op_name=op_name, func=op_meta.func, args=tuple(op_attrs), kwargs=op_kwargs),
326326
)
327327

328328
logger.debug(f"GraphBuilderMode.process_op_inputs: {node_id=} {op_name=} {op_attrs=} {op_kwargs=}")

nncf/experimental/torch2/function_hook/graph/graph_utils.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from dataclasses import dataclass
1515
from enum import Enum
16-
from typing import Any, Dict, Optional, Tuple
16+
from typing import Any, Callable, Dict, Optional, Tuple
1717

1818
import torch
1919

@@ -75,10 +75,14 @@ def from_tensor(tensor: torch.Tensor, name: str) -> InOutMeta:
7575
@dataclass
7676
class FunctionMeta:
7777
op_name: str
78-
fn_name: str
78+
func: Callable[..., Any]
7979
args: Tuple[Any, ...]
8080
kwargs: Dict[str, Any]
8181

82+
@property
83+
def func_name(self) -> str:
84+
return self.func.__name__
85+
8286

8387
@dataclass
8488
class EdgeMeta:

nncf/experimental/torch2/function_hook/graph/graph_visualization.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def get_label_from_node_data(node_data: Dict[str, Any], style: PydotStyleTemplat
108108
rows = [
109109
f"type: {node_type}",
110110
f"op_name: {meta.op_name}",
111-
f"fn_name: {meta.fn_name}",
111+
f"fn_name: {meta.func_name}",
112112
f"args: {args_to_label(meta.args)}",
113113
f"kwargs: {kwargs_to_label(meta.kwargs)}",
114114
]
@@ -195,7 +195,7 @@ def get_style(node: Dict[str, Any], style: PydotStyleTemplate) -> Dict[str, str]
195195
}
196196
if isinstance(meta, FunctionMeta):
197197
return {
198-
"fillcolor": color_picker(meta.fn_name),
198+
"fillcolor": color_picker(meta.func_name),
199199
"fontcolor": "#000000",
200200
"shape": "record",
201201
"style": '"filled,rounded"',
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) 2025 Intel Corporation
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from dataclasses import dataclass
13+
from typing import Any, Callable, Dict, Set, Tuple
14+
15+
from nncf.common.graph.layer_attributes import BaseLayerAttributes
16+
17+
18+
@dataclass(frozen=True)
19+
class PT2OpLayerAttributes(BaseLayerAttributes):
20+
"""
21+
This class stores information about operation.
22+
23+
:param func: Function that the operation represents.
24+
:param op_args: Tuple of positional arguments for the operation.
25+
:param op_kwargs: Dictionary of keyword arguments for the operation.
26+
:param constant_port_ids: Set of input port indices with constants.
27+
"""
28+
29+
func: Callable[..., Any]
30+
op_args: Tuple[Any, ...]
31+
op_kwargs: Dict[str, Any]
32+
constant_port_ids: Set[int]

nncf/experimental/torch2/function_hook/nncf_graph/nncf_graph_builder.py

+79-8
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
from collections import defaultdict
14-
from typing import Any, Dict, List, Optional, Tuple, Union, cast
14+
from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
1515

1616
import networkx as nx # type: ignore
1717
import torch
@@ -21,13 +21,15 @@
2121
import nncf.torch.graph.operator_metatypes as om
2222
from nncf.common.graph.graph import NNCFGraph
2323
from nncf.common.graph.graph import NNCFNode
24+
from nncf.common.graph.layer_attributes import BaseLayerAttributes
2425
from nncf.common.graph.layer_attributes import Dtype
2526
from nncf.experimental.torch2.function_hook.graph.build_graph_mode import build_graph
2627
from nncf.experimental.torch2.function_hook.graph.graph_utils import ConstMeta
2728
from nncf.experimental.torch2.function_hook.graph.graph_utils import EdgeMeta
2829
from nncf.experimental.torch2.function_hook.graph.graph_utils import FunctionMeta
2930
from nncf.experimental.torch2.function_hook.graph.graph_utils import InOutMeta
3031
from nncf.experimental.torch2.function_hook.graph.graph_utils import NodeType
32+
from nncf.experimental.torch2.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes
3133

3234

3335
def get_node_type(type: NodeType, meta: Union[ConstMeta, FunctionMeta, InOutMeta]) -> str:
@@ -45,7 +47,7 @@ def get_node_type(type: NodeType, meta: Union[ConstMeta, FunctionMeta, InOutMeta
4547
if isinstance(meta, ConstMeta):
4648
return "nncf_model_const"
4749
if isinstance(meta, FunctionMeta):
48-
return meta.fn_name
50+
return meta.func_name
4951
raise nncf.InternalError("Unexpected metadata type")
5052

5153

@@ -77,20 +79,86 @@ def get_dtype(dtype: torch.dtype) -> Dtype:
7779
return Dtype.INTEGER
7880

7981

80-
def get_meta_type(node_type: str, meta: Union[ConstMeta, FunctionMeta, InOutMeta]) -> om.PTOperatorMetatype:
82+
def get_meta_type(node_type: str, meta: Union[ConstMeta, FunctionMeta, InOutMeta]) -> type[om.PTOperatorMetatype]:
8183
"""
8284
Converts the node type and metadata into a PTOperatorMetatype object.
8385
:param node_type: The type of the node.
8486
:param meta: The metadata associated with the node.
8587
:return: The PTOperatorMetatype object.
8688
"""
87-
node_metatype = cast(om.PTOperatorMetatype, om.PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type))
88-
node_sub_meta_type: Optional[om.PTOperatorMetatype] = None
89+
node_metatype = cast(
90+
type[om.PTOperatorMetatype], om.PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
91+
)
92+
node_sub_meta_type: Optional[type[om.PTOperatorMetatype]] = None
8993
if node_metatype.get_subtypes() and isinstance(meta, FunctionMeta):
9094
node_sub_meta_type = node_metatype.determine_subtype(function_args=meta.args, functions_kwargs=meta.kwargs)
9195
return node_sub_meta_type or node_metatype
9296

9397

98+
def is_constant_input_node(nx_graph: nx.MultiDiGraph, node: int) -> bool:
99+
"""
100+
Check if a node is a constant input node or constant subgraph:
101+
102+
1) constant
103+
2) quantize_function -> constant
104+
105+
:param nx_graph: The graph to check the node from.
106+
:param node: The node to check.
107+
:return: True if the node is a constant input node, False otherwise.
108+
"""
109+
meta = nx_graph.nodes[node]["meta"]
110+
111+
# 1) Input node is a constant node (parameter or buffer)
112+
if isinstance(meta, ConstMeta):
113+
return True
114+
115+
# 2) Quantize node with constant input
116+
if (
117+
isinstance(meta, FunctionMeta)
118+
and meta.func_name in om.QUANTIZE_NODE_TYPES
119+
and isinstance(nx_graph.nodes[node]["meta"], FunctionMeta)
120+
):
121+
return all(isinstance(nx_graph.nodes[s_node]["meta"], ConstMeta) for s_node, _ in nx_graph.in_edges(node))
122+
123+
return False
124+
125+
126+
def get_constant_port_ids(nx_graph: nx.MultiDiGraph, node: int) -> Set[int]:
127+
"""
128+
Get the indices of input ports corresponding to the constant node or subgraph.
129+
130+
:param nx_graph: The graph to get the constant port IDs from.
131+
:param node: The node to get the constant port IDs from.
132+
:return: The list of input port indices with constants.
133+
"""
134+
constant_port_ids: Set[int] = set()
135+
136+
for s_node, _, data in nx_graph.in_edges(node, data=True):
137+
if is_constant_input_node(nx_graph, s_node):
138+
meta = cast(EdgeMeta, data["meta"])
139+
constant_port_ids.add(meta.input_port)
140+
141+
return constant_port_ids
142+
143+
144+
def get_layer_attributes(
145+
nx_graph: nx.MultiDiGraph, node: int, meta: Union[ConstMeta, FunctionMeta, InOutMeta]
146+
) -> Optional[BaseLayerAttributes]:
147+
"""
148+
Get the layer attributes of a node in the graph.
149+
150+
:param nx_graph: The graph to get the layer attributes from.
151+
:param node: The node to get the layer attributes from.
152+
:param meta: The metadata associated with the node.
153+
:return: The layer attributes of the node.
154+
"""
155+
if isinstance(meta, FunctionMeta):
156+
constant_port_ids = get_constant_port_ids(nx_graph, node)
157+
return PT2OpLayerAttributes(meta.func, meta.args, meta.kwargs, constant_port_ids)
158+
159+
return None
160+
161+
94162
def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph:
95163
"""
96164
Converts a graph to an NNCFGraph.
@@ -102,15 +170,18 @@ def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph:
102170

103171
map_nx_node_to_nncf_node: Dict[int, NNCFNode] = {}
104172
for node, data in nx_graph.nodes(data=True):
105-
meta: Union[ConstMeta, FunctionMeta, InOutMeta] = data["meta"]
173+
meta = data["meta"]
174+
if not isinstance(meta, (ConstMeta, FunctionMeta, InOutMeta)):
175+
raise nncf.InternalError(f"Unknown metadata type: {type(meta)}")
106176
node_name = get_name_of_node(meta)
107177
node_type = get_node_type(data["type"], meta)
108178
meta_type = get_meta_type(node_type, meta)
109-
179+
layer_attributes = get_layer_attributes(nx_graph, node, meta)
110180
nncf_node = nncf_graph.add_nncf_node(
111181
node_name=node_name,
112182
node_type=node_type,
113-
node_metatype=meta_type, # type: ignore[arg-type]
183+
node_metatype=meta_type,
184+
layer_attributes=layer_attributes,
114185
)
115186
map_nx_node_to_nncf_node[node] = nncf_node
116187

nncf/torch/graph/operator_metatypes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def get_all_aliases(cls) -> List[str]:
8484
@classmethod
8585
def determine_subtype(
8686
cls, layer_attributes: Optional[BaseLayerAttributes] = None, function_args=None, functions_kwargs=None
87-
) -> Optional["PTOperatorSubtype"]:
87+
) -> Optional["type[PTOperatorSubtype]"]:
8888
matches = []
8989
for subtype in cls.get_subtypes():
9090
if subtype.matches(layer_attributes, function_args, functions_kwargs):

tests/torch2/function_hook/graph/test_build_graph_mode.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def test_execute_pre_hooks():
101101
"type": NodeType.fn_call,
102102
"meta": FunctionMeta(
103103
op_name="/relu/0",
104-
fn_name="relu",
104+
func=torch.relu,
105105
args=(
106106
TensorMeta(dtype=torch.float32, shape=(1,), requires_grad=False),
107107
TensorMeta(dtype=torch.float32, shape=(1, 1, 1, 1), requires_grad=True),
@@ -190,14 +190,14 @@ def test_tensor_attributes(attr):
190190
if attr == ".T":
191191
ref_meta = FunctionMeta(
192192
op_name="/__get__/0",
193-
fn_name="permute",
193+
func=torch.permute,
194194
args=(TensorMeta(dtype=torch.float32, shape=(2, 3), requires_grad=True),),
195195
kwargs={"dims": (1, 0)},
196196
)
197197
else:
198198
ref_meta = FunctionMeta(
199199
op_name="/__get__/0",
200-
fn_name="transpose",
200+
func=torch.transpose,
201201
args=(TensorMeta(dtype=torch.float32, shape=(2, 3), requires_grad=True),),
202202
kwargs={"dim0": -2, "dim1": -1},
203203
)

tests/torch2/function_hook/helpers.py

+33
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,36 @@ def __init__(self) -> None:
103103
def forward(self, x: torch.Tensor) -> torch.Tensor:
104104
x = x + x
105105
return x
106+
107+
108+
class MatMulLeft(torch.nn.Module):
109+
def __init__(self):
110+
super().__init__()
111+
self.w = torch.nn.Parameter(torch.tensor([1], dtype=torch.float32))
112+
113+
@staticmethod
114+
def get_example_inputs():
115+
return torch.ones([1, 1])
116+
117+
def forward(self, x):
118+
return torch.matmul(x, self.w)
119+
120+
121+
class MatMulRight(MatMulLeft):
122+
def forward(self, x):
123+
return torch.matmul(self.w, x)
124+
125+
126+
class QuantizedConvModel(nn.Module):
127+
@staticmethod
128+
def get_example_inputs():
129+
return torch.ones([1, 1, 3, 3])
130+
131+
def __init__(self) -> None:
132+
super().__init__()
133+
self.conv = nn.Conv2d(1, 1, 1)
134+
135+
def forward(self, x: torch.Tensor):
136+
x = self.conv(x)
137+
x = torch.relu(x)
138+
return x

0 commit comments

Comments
 (0)