Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PT2] Use unique inport_port_id for list of tensors #3271

Merged
merged 2 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 27 additions & 19 deletions nncf/experimental/torch2/function_hook/graph/build_graph_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,22 +274,13 @@ def register_op_input(self, arg: Any, node_id: int, port_id: int, op_meta: OpMet
:param node_id: Id if operation node.
:param port_id: Port id of input argument.
:param op_meta: Metadata about the operation.
:return: Descriptor of the input. For a Tensor, this is a `TensorMeta` object.
For a collection of Tensors, a collection of `TensorMeta` objects is returned.
For other types, the original input `arg` is returned as-is.
:return: Descriptor of the input.
For a Tensor, this is a `TensorMeta` object.
For other types, the original input `arg` is returned as-is.
"""
if isinstance(arg, torch.Tensor):
self.register_op_input_tensor(arg, node_id, port_id, op_meta)
return TensorMeta.from_tensor(arg)
elif isinstance(arg, (list, tuple, set)):
op_attr = []
for x in arg:
if isinstance(x, torch.Tensor):
self.register_op_input_tensor(x, node_id, port_id, op_meta)
op_attr.append(TensorMeta.from_tensor(x))
else:
op_attr.append(x)
return op_attr
return arg

def register_op_node(self, args: Tuple[Any], kwargs: Dict[str, Any], op_meta: OpMeta) -> None:
Expand All @@ -311,13 +302,30 @@ def register_op_node(self, args: Tuple[Any], kwargs: Dict[str, Any], op_meta: Op

op_attrs = []
op_kwargs = {}
for port_id, arg in enumerate(args):
op_attr = self.register_op_input(arg, node_id, port_id, op_meta)
op_attrs.append(op_attr)

for port_id, (name, arg) in enumerate(kwargs.items(), start=len(args)):
op_attr = self.register_op_input(arg, node_id, port_id, op_meta)
op_kwargs[name] = op_attr
port_id = 0

for value in args:
if isinstance(value, (list, tuple)) and all(isinstance(v, torch.Tensor) for v in value):
list_attr = [None] * len(value)
for idx, tensor in enumerate(value):
op_attr = self.register_op_input(tensor, node_id, port_id, op_meta)
list_attr[idx] = op_attr
port_id += 1
op_attrs.append(tuple(list_attr) if isinstance(value, tuple) else list_attr)
else:
op_attr = self.register_op_input(value, node_id, port_id, op_meta)
op_attrs.append(op_attr)
port_id += 1

for kw_name, value in kwargs.items():
if isinstance(value, (list, tuple)) and all(isinstance(v, torch.Tensor) for v in value):
op_kwargs[kw_name] = [None] * len(value)
for tensor_idx, tensor in enumerate(value):
op_kwargs[kw_name][tensor_idx] = self.register_op_input(value, node_id, port_id, op_meta)
port_id += 1
else:
op_kwargs[kw_name] = self.register_op_input(value, node_id, port_id, op_meta)
port_id += 1

self.graph.add_node(
node_id,
Expand Down
35 changes: 28 additions & 7 deletions nncf/experimental/torch2/function_hook/hook_executor_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from itertools import chain
from types import MethodType
from types import TracebackType
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, cast
from weakref import ReferenceType
from weakref import ref

Expand Down Expand Up @@ -348,13 +348,34 @@ def execute_pre_hooks(
with self:
_args, kwargs = self.process_parameters(_args, kwargs)

port_id = 0
for idx, value in enumerate(_args):
_args[idx] = self.hook_storage.execute_pre_function_hooks(op_meta.op_name, idx, value)

for port_id, kw_name in enumerate(kwargs, start=len(_args)):
kwargs[kw_name] = self.hook_storage.execute_pre_function_hooks(
op_meta.op_name, port_id, kwargs[kw_name]
)
if isinstance(value, (list, tuple)) and all(isinstance(v, torch.Tensor) for v in value):
is_tuple = isinstance(value, tuple)
list_args = cast(List[Tensor], list(value) if is_tuple else value)
for tensor_idx, tensor in enumerate(list_args):
list_args[tensor_idx] = self.hook_storage.execute_pre_function_hooks(
op_meta.op_name, port_id, tensor
)
port_id += 1
_args[idx] = tuple(list_args) if is_tuple else list_args
else:
_args[idx] = self.hook_storage.execute_pre_function_hooks(op_meta.op_name, port_id, value)
port_id += 1

for kw_name, value in kwargs.items():
if isinstance(value, (list, tuple)) and all(isinstance(v, torch.Tensor) for v in value):
is_tuple = isinstance(value, tuple)
list_args = cast(List[Tensor], list(value) if is_tuple else value)
for tensor_idx, tensor in enumerate(list_args):
list_args[tensor_idx] = self.hook_storage.execute_pre_function_hooks(
op_meta.op_name, port_id, tensor
)
port_id += 1
kwargs[kw_name] = tuple(list_args) if is_tuple else list_args
else:
kwargs[kw_name] = self.hook_storage.execute_pre_function_hooks(op_meta.op_name, port_id, value)
port_id += 1
return tuple(_args), kwargs

def process_post_function_hooks_for_value(self, value: Any, op_meta: OpMeta, port_id: int) -> Any:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ rankdir=TB;
0 [fillcolor="#adadad", fontcolor="#000000", label="{type: input|name: x|dtype: torch.float32|shape: (1, 1, 3, 3)}", shape=record, style="filled,rounded"];
1 [fillcolor="#ffffff", fontcolor="#000000", label="{type: const|name: conv.weight|dtype: torch.float32|shape: (1, 1, 1, 1)}", shape=record, style="filled,rounded"];
2 [fillcolor="#ffffff", fontcolor="#000000", label="{type: const|name: conv.bias|dtype: torch.float32|shape: (1,)}", shape=record, style="filled,rounded"];
3 [fillcolor="#ffd6a5", fontcolor="#000000", label="{type: function_call|op_name: conv/conv2d/0|fn_name: conv2d|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3), requires_grad=False),\nTensorMeta(dtype=torch.float32, shape=(1, 1, 1, 1), requires_grad=True),\nTensorMeta(dtype=torch.float32, shape=(1,), requires_grad=True),\n[1, 1],\n[0, 0],\n[1, 1],\n1,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
3 [fillcolor="#ffd6a5", fontcolor="#000000", label="{type: function_call|op_name: conv/conv2d/0|fn_name: conv2d|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3), requires_grad=False),\nTensorMeta(dtype=torch.float32, shape=(1, 1, 1, 1), requires_grad=True),\nTensorMeta(dtype=torch.float32, shape=(1,), requires_grad=True),\n(1, 1),\n(0, 0),\n(1, 1),\n1,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
4 [fillcolor="#ffffff", fontcolor="#000000", label="{type: const|name: __nncf_hooks.post_hooks.conv/conv2d/0__0.0.w|dtype: torch.float32|shape: (1,)}", shape=record, style="filled,rounded"];
5 [fillcolor="#caffbf", fontcolor="#000000", label="{type: function_call|op_name: conv/post_hook__conv-conv2d-0__0[0]/add/0|fn_name: add|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3), requires_grad=True),\nTensorMeta(dtype=torch.float32, shape=(1,), requires_grad=True),\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
6 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /relu/0|fn_name: relu|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3), requires_grad=True),\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ rankdir=TB;
0 [fillcolor="#adadad", fontcolor="#000000", label="{type: input|name: query|dtype: torch.float32|shape: (5, 2, 16)}", shape=record, style="filled,rounded"];
1 [fillcolor="#adadad", fontcolor="#000000", label="{type: input|name: key|dtype: torch.float32|shape: (5, 2, 16)}", shape=record, style="filled,rounded"];
2 [fillcolor="#adadad", fontcolor="#000000", label="{type: input|name: value|dtype: torch.float32|shape: (5, 2, 16)}", shape=record, style="filled,rounded"];
3 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /rand/0|fn_name: rand|args: [\n[48, 16],\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
3 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /rand/0|fn_name: rand|args: [\n(48, 16),\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
4 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /rand/1|fn_name: rand|args: [\n48,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
5 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /rand/2|fn_name: rand|args: [\n[16, 16],\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
5 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /rand/2|fn_name: rand|args: [\n(16, 16),\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
6 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /rand/3|fn_name: rand|args: [\n16,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
13 [fillcolor="#caffbf", fontcolor="#000000", label="{type: function_call|op_name: /chunk/0|fn_name: chunk|args: [\nTensorMeta(dtype=torch.float32, shape=(48, 16), requires_grad=False),\n3,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
14 [fillcolor="#caffbf", fontcolor="#000000", label="{type: function_call|op_name: /chunk/1|fn_name: chunk|args: [\nTensorMeta(dtype=torch.float32, shape=(48,), requires_grad=False),\n3,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
Expand Down
Loading