Skip to content

Commit 61f5886

Browse files
[PT2] Use unique inport_port_id for list of tensors (#3271)
### Changes - Use unique inport_port_id for input of list of tensors - Keep type for arguments - Add init_weights=False for models.inception_v3 to avoid depreceted warning ### Reason for changes Align with current tracing
1 parent 2f3fb1c commit 61f5886

18 files changed

+2555
-2732
lines changed

nncf/experimental/torch2/function_hook/graph/build_graph_mode.py

+27-19
Original file line numberDiff line numberDiff line change
@@ -275,22 +275,13 @@ def register_op_input(self, arg: Any, node_id: int, port_id: int, op_meta: OpMet
275275
:param node_id: Id if operation node.
276276
:param port_id: Port id of input argument.
277277
:param op_meta: Metadata about the operation.
278-
:return: Descriptor of the input. For a Tensor, this is a `TensorMeta` object.
279-
For a collection of Tensors, a collection of `TensorMeta` objects is returned.
280-
For other types, the original input `arg` is returned as-is.
278+
:return: Descriptor of the input.
279+
For a Tensor, this is a `TensorMeta` object.
280+
For other types, the original input `arg` is returned as-is.
281281
"""
282282
if isinstance(arg, torch.Tensor):
283283
self.register_op_input_tensor(arg, node_id, port_id, op_meta)
284284
return TensorMeta.from_tensor(arg)
285-
elif isinstance(arg, (list, tuple, set)):
286-
op_attr = []
287-
for x in arg:
288-
if isinstance(x, torch.Tensor):
289-
self.register_op_input_tensor(x, node_id, port_id, op_meta)
290-
op_attr.append(TensorMeta.from_tensor(x))
291-
else:
292-
op_attr.append(x)
293-
return op_attr
294285
return arg
295286

296287
def register_op_node(self, args: Tuple[Any], kwargs: Dict[str, Any], op_meta: OpMeta) -> None:
@@ -312,13 +303,30 @@ def register_op_node(self, args: Tuple[Any], kwargs: Dict[str, Any], op_meta: Op
312303

313304
op_attrs = []
314305
op_kwargs = {}
315-
for port_id, arg in enumerate(args):
316-
op_attr = self.register_op_input(arg, node_id, port_id, op_meta)
317-
op_attrs.append(op_attr)
318-
319-
for port_id, (name, arg) in enumerate(kwargs.items(), start=len(args)):
320-
op_attr = self.register_op_input(arg, node_id, port_id, op_meta)
321-
op_kwargs[name] = op_attr
306+
port_id = 0
307+
308+
for value in args:
309+
if isinstance(value, (list, tuple)) and all(isinstance(v, torch.Tensor) for v in value):
310+
list_attr = [None] * len(value)
311+
for idx, tensor in enumerate(value):
312+
op_attr = self.register_op_input(tensor, node_id, port_id, op_meta)
313+
list_attr[idx] = op_attr
314+
port_id += 1
315+
op_attrs.append(tuple(list_attr) if isinstance(value, tuple) else list_attr)
316+
else:
317+
op_attr = self.register_op_input(value, node_id, port_id, op_meta)
318+
op_attrs.append(op_attr)
319+
port_id += 1
320+
321+
for kw_name, value in kwargs.items():
322+
if isinstance(value, (list, tuple)) and all(isinstance(v, torch.Tensor) for v in value):
323+
op_kwargs[kw_name] = [None] * len(value)
324+
for tensor_idx, tensor in enumerate(value):
325+
op_kwargs[kw_name][tensor_idx] = self.register_op_input(value, node_id, port_id, op_meta)
326+
port_id += 1
327+
else:
328+
op_kwargs[kw_name] = self.register_op_input(value, node_id, port_id, op_meta)
329+
port_id += 1
322330

323331
self.graph.add_node(
324332
node_id,

nncf/experimental/torch2/function_hook/hook_executor_mode.py

+28-7
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from itertools import chain
2121
from types import MethodType
2222
from types import TracebackType
23-
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type
23+
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, cast
2424
from weakref import ReferenceType
2525
from weakref import ref
2626

@@ -348,13 +348,34 @@ def execute_pre_hooks(
348348
with self:
349349
_args, kwargs = self.process_parameters(_args, kwargs)
350350

351+
port_id = 0
351352
for idx, value in enumerate(_args):
352-
_args[idx] = self.hook_storage.execute_pre_function_hooks(op_meta.op_name, idx, value)
353-
354-
for port_id, kw_name in enumerate(kwargs, start=len(_args)):
355-
kwargs[kw_name] = self.hook_storage.execute_pre_function_hooks(
356-
op_meta.op_name, port_id, kwargs[kw_name]
357-
)
353+
if isinstance(value, (list, tuple)) and all(isinstance(v, torch.Tensor) for v in value):
354+
is_tuple = isinstance(value, tuple)
355+
list_args = cast(List[Tensor], list(value) if is_tuple else value)
356+
for tensor_idx, tensor in enumerate(list_args):
357+
list_args[tensor_idx] = self.hook_storage.execute_pre_function_hooks(
358+
op_meta.op_name, port_id, tensor
359+
)
360+
port_id += 1
361+
_args[idx] = tuple(list_args) if is_tuple else list_args
362+
else:
363+
_args[idx] = self.hook_storage.execute_pre_function_hooks(op_meta.op_name, port_id, value)
364+
port_id += 1
365+
366+
for kw_name, value in kwargs.items():
367+
if isinstance(value, (list, tuple)) and all(isinstance(v, torch.Tensor) for v in value):
368+
is_tuple = isinstance(value, tuple)
369+
list_args = cast(List[Tensor], list(value) if is_tuple else value)
370+
for tensor_idx, tensor in enumerate(list_args):
371+
list_args[tensor_idx] = self.hook_storage.execute_pre_function_hooks(
372+
op_meta.op_name, port_id, tensor
373+
)
374+
port_id += 1
375+
kwargs[kw_name] = tuple(list_args) if is_tuple else list_args
376+
else:
377+
kwargs[kw_name] = self.hook_storage.execute_pre_function_hooks(op_meta.op_name, port_id, value)
378+
port_id += 1
358379
return tuple(_args), kwargs
359380

360381
def process_post_function_hooks_for_value(self, value: Any, op_meta: OpMeta, port_id: int) -> Any:

tests/torch2/data/function_hook/graph_visualization/to_pydot_style_full.dot

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ rankdir=TB;
33
0 [fillcolor="#adadad", fontcolor="#000000", label="{type: input|name: x|dtype: torch.float32|shape: (1, 1, 3, 3)}", shape=record, style="filled,rounded"];
44
1 [fillcolor="#ffffff", fontcolor="#000000", label="{type: const|name: conv.weight|dtype: torch.float32|shape: (1, 1, 1, 1)}", shape=record, style="filled,rounded"];
55
2 [fillcolor="#ffffff", fontcolor="#000000", label="{type: const|name: conv.bias|dtype: torch.float32|shape: (1,)}", shape=record, style="filled,rounded"];
6-
3 [fillcolor="#ffd6a5", fontcolor="#000000", label="{type: function_call|op_name: conv/conv2d/0|fn_name: conv2d|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3), requires_grad=False),\nTensorMeta(dtype=torch.float32, shape=(1, 1, 1, 1), requires_grad=True),\nTensorMeta(dtype=torch.float32, shape=(1,), requires_grad=True),\n[1, 1],\n[0, 0],\n[1, 1],\n1,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
6+
3 [fillcolor="#ffd6a5", fontcolor="#000000", label="{type: function_call|op_name: conv/conv2d/0|fn_name: conv2d|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3), requires_grad=False),\nTensorMeta(dtype=torch.float32, shape=(1, 1, 1, 1), requires_grad=True),\nTensorMeta(dtype=torch.float32, shape=(1,), requires_grad=True),\n(1, 1),\n(0, 0),\n(1, 1),\n1,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
77
4 [fillcolor="#ffffff", fontcolor="#000000", label="{type: const|name: __nncf_hooks.post_hooks.conv/conv2d/0__0.0.w|dtype: torch.float32|shape: (1,)}", shape=record, style="filled,rounded"];
88
5 [fillcolor="#caffbf", fontcolor="#000000", label="{type: function_call|op_name: conv/post_hook__conv-conv2d-0__0[0]/add/0|fn_name: add|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3), requires_grad=True),\nTensorMeta(dtype=torch.float32, shape=(1,), requires_grad=True),\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
99
6 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /relu/0|fn_name: relu|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3), requires_grad=True),\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];

tests/torch2/data/function_hook/handle_inner_functions/inner_functions_MultiHeadAttention.dot

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ rankdir=TB;
33
0 [fillcolor="#adadad", fontcolor="#000000", label="{type: input|name: query|dtype: torch.float32|shape: (5, 2, 16)}", shape=record, style="filled,rounded"];
44
1 [fillcolor="#adadad", fontcolor="#000000", label="{type: input|name: key|dtype: torch.float32|shape: (5, 2, 16)}", shape=record, style="filled,rounded"];
55
2 [fillcolor="#adadad", fontcolor="#000000", label="{type: input|name: value|dtype: torch.float32|shape: (5, 2, 16)}", shape=record, style="filled,rounded"];
6-
3 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /rand/0|fn_name: rand|args: [\n[48, 16],\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
6+
3 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /rand/0|fn_name: rand|args: [\n(48, 16),\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
77
4 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /rand/1|fn_name: rand|args: [\n48,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
8-
5 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /rand/2|fn_name: rand|args: [\n[16, 16],\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
8+
5 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /rand/2|fn_name: rand|args: [\n(16, 16),\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
99
6 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /rand/3|fn_name: rand|args: [\n16,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
1010
13 [fillcolor="#caffbf", fontcolor="#000000", label="{type: function_call|op_name: /chunk/0|fn_name: chunk|args: [\nTensorMeta(dtype=torch.float32, shape=(48, 16), requires_grad=False),\n3,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
1111
14 [fillcolor="#caffbf", fontcolor="#000000", label="{type: function_call|op_name: /chunk/1|fn_name: chunk|args: [\nTensorMeta(dtype=torch.float32, shape=(48,), requires_grad=False),\n3,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];

0 commit comments

Comments
 (0)