Skip to content

Commit a777f54

Browse files
uniq port_id for list input
1 parent 99f0c44 commit a777f54

18 files changed

+2552
-2729
lines changed

nncf/experimental/torch2/function_hook/graph/build_graph_mode.py

+24-16
Original file line numberDiff line numberDiff line change
@@ -281,15 +281,6 @@ def register_op_input(self, arg: Any, node_id: int, port_id: int, op_meta: OpMet
281281
if isinstance(arg, torch.Tensor):
282282
self.register_op_input_tensor(arg, node_id, port_id, op_meta)
283283
return TensorMeta.from_tensor(arg)
284-
elif isinstance(arg, (list, tuple, set)):
285-
op_attr = []
286-
for x in arg:
287-
if isinstance(x, torch.Tensor):
288-
self.register_op_input_tensor(x, node_id, port_id, op_meta)
289-
op_attr.append(TensorMeta.from_tensor(x))
290-
else:
291-
op_attr.append(x)
292-
return op_attr
293284
return arg
294285

295286
def register_op_node(self, args: Tuple[Any], kwargs: Dict[str, Any], op_meta: OpMeta) -> None:
@@ -311,13 +302,30 @@ def register_op_node(self, args: Tuple[Any], kwargs: Dict[str, Any], op_meta: Op
311302

312303
op_attrs = []
313304
op_kwargs = {}
314-
for port_id, arg in enumerate(args):
315-
op_attr = self.register_op_input(arg, node_id, port_id, op_meta)
316-
op_attrs.append(op_attr)
317-
318-
for port_id, (name, arg) in enumerate(kwargs.items(), start=len(args)):
319-
op_attr = self.register_op_input(arg, node_id, port_id, op_meta)
320-
op_kwargs[name] = op_attr
305+
port_id = 0
306+
307+
for value in args:
308+
if isinstance(value, (list, tuple)) and all(isinstance(v, torch.Tensor) for v in value):
309+
list_attr = [None] * len(value)
310+
for idx, tensor in enumerate(value):
311+
op_attr = self.register_op_input(tensor, node_id, port_id, op_meta)
312+
list_attr[idx] = op_attr
313+
port_id += 1
314+
op_attrs.append(tuple(list_attr) if isinstance(value, tuple) else list_attr)
315+
else:
316+
op_attr = self.register_op_input(value, node_id, port_id, op_meta)
317+
op_attrs.append(op_attr)
318+
port_id += 1
319+
320+
for kw_name, value in kwargs.items():
321+
if isinstance(value, (list, tuple)) and all(isinstance(v, torch.Tensor) for v in value):
322+
op_kwargs[kw_name] = [None] * len(value)
323+
for tensor_idx, tensor in enumerate(value):
324+
op_kwargs[kw_name][tensor_idx] = self.register_op_input(value, node_id, port_id, op_meta)
325+
port_id += 1
326+
else:
327+
op_kwargs[kw_name] = self.register_op_input(value, node_id, port_id, op_meta)
328+
port_id += 1
321329

322330
self.graph.add_node(
323331
node_id,

nncf/experimental/torch2/function_hook/hook_executor_mode.py

+28-7
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from itertools import chain
2121
from types import MethodType
2222
from types import TracebackType
23-
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type
23+
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, cast
2424
from weakref import ReferenceType
2525
from weakref import ref
2626

@@ -348,13 +348,34 @@ def execute_pre_hooks(
348348
with self:
349349
_args, kwargs = self.process_parameters(_args, kwargs)
350350

351+
port_id = 0
351352
for idx, value in enumerate(_args):
352-
_args[idx] = self.hook_storage.execute_pre_function_hooks(op_meta.op_name, idx, value)
353-
354-
for port_id, kw_name in enumerate(kwargs, start=len(_args)):
355-
kwargs[kw_name] = self.hook_storage.execute_pre_function_hooks(
356-
op_meta.op_name, port_id, kwargs[kw_name]
357-
)
353+
if isinstance(value, (list, tuple)) and all(isinstance(v, torch.Tensor) for v in value):
354+
is_tuple = isinstance(value, tuple)
355+
list_args = cast(List[Tensor], list(value) if is_tuple else value)
356+
for tensor_idx, tensor in enumerate(list_args):
357+
list_args[tensor_idx] = self.hook_storage.execute_pre_function_hooks(
358+
op_meta.op_name, port_id, tensor
359+
)
360+
port_id += 1
361+
_args[idx] = tuple(list_args) if is_tuple else list_args
362+
else:
363+
_args[idx] = self.hook_storage.execute_pre_function_hooks(op_meta.op_name, port_id, value)
364+
port_id += 1
365+
366+
for kw_name, value in kwargs.items():
367+
if isinstance(value, (list, tuple)) and all(isinstance(v, torch.Tensor) for v in value):
368+
is_tuple = isinstance(value, tuple)
369+
list_args = cast(List[Tensor], list(value) if is_tuple else value)
370+
for tensor_idx, tensor in enumerate(list_args):
371+
list_args[tensor_idx] = self.hook_storage.execute_pre_function_hooks(
372+
op_meta.op_name, port_id, tensor
373+
)
374+
port_id += 1
375+
kwargs[kw_name] = tuple(list_args) if is_tuple else list_args
376+
else:
377+
kwargs[kw_name] = self.hook_storage.execute_pre_function_hooks(op_meta.op_name, port_id, value)
378+
port_id += 1
358379
return tuple(_args), kwargs
359380

360381
def process_post_function_hooks_for_value(self, value: Any, op_meta: OpMeta, port_id: int) -> Any:

tests/torch2/data/function_hook/graph_visualization/to_pydot_style_full.dot

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ rankdir=TB;
33
0 [fillcolor="#adadad", fontcolor="#000000", label="{type: input|name: x|dtype: torch.float32|shape: (1, 1, 3, 3)}", shape=record, style="filled,rounded"];
44
1 [fillcolor="#ffffff", fontcolor="#000000", label="{type: const|name: conv.weight|dtype: torch.float32|shape: (1, 1, 1, 1)}", shape=record, style="filled,rounded"];
55
2 [fillcolor="#ffffff", fontcolor="#000000", label="{type: const|name: conv.bias|dtype: torch.float32|shape: (1,)}", shape=record, style="filled,rounded"];
6-
3 [fillcolor="#ffd6a5", fontcolor="#000000", label="{type: function_call|op_name: conv/conv2d/0|fn_name: conv2d|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3), requires_grad=False),\nTensorMeta(dtype=torch.float32, shape=(1, 1, 1, 1), requires_grad=True),\nTensorMeta(dtype=torch.float32, shape=(1,), requires_grad=True),\n[1, 1],\n[0, 0],\n[1, 1],\n1,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
6+
3 [fillcolor="#ffd6a5", fontcolor="#000000", label="{type: function_call|op_name: conv/conv2d/0|fn_name: conv2d|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3), requires_grad=False),\nTensorMeta(dtype=torch.float32, shape=(1, 1, 1, 1), requires_grad=True),\nTensorMeta(dtype=torch.float32, shape=(1,), requires_grad=True),\n(1, 1),\n(0, 0),\n(1, 1),\n1,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
77
4 [fillcolor="#ffffff", fontcolor="#000000", label="{type: const|name: __nncf_hooks.post_hooks.conv/conv2d/0__0.0.w|dtype: torch.float32|shape: (1,)}", shape=record, style="filled,rounded"];
88
5 [fillcolor="#caffbf", fontcolor="#000000", label="{type: function_call|op_name: conv/post_hook__conv-conv2d-0__0[0]/add/0|fn_name: add|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3), requires_grad=True),\nTensorMeta(dtype=torch.float32, shape=(1,), requires_grad=True),\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
99
6 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /relu/0|fn_name: relu|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3), requires_grad=True),\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];

tests/torch2/data/function_hook/handle_inner_functions/inner_functions_MultiHeadAttention.dot

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ rankdir=TB;
33
0 [fillcolor="#adadad", fontcolor="#000000", label="{type: input|name: query|dtype: torch.float32|shape: (5, 2, 16)}", shape=record, style="filled,rounded"];
44
1 [fillcolor="#adadad", fontcolor="#000000", label="{type: input|name: key|dtype: torch.float32|shape: (5, 2, 16)}", shape=record, style="filled,rounded"];
55
2 [fillcolor="#adadad", fontcolor="#000000", label="{type: input|name: value|dtype: torch.float32|shape: (5, 2, 16)}", shape=record, style="filled,rounded"];
6-
3 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /rand/0|fn_name: rand|args: [\n[48, 16],\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
6+
3 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /rand/0|fn_name: rand|args: [\n(48, 16),\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
77
4 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /rand/1|fn_name: rand|args: [\n48,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
8-
5 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /rand/2|fn_name: rand|args: [\n[16, 16],\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
8+
5 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /rand/2|fn_name: rand|args: [\n(16, 16),\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
99
6 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /rand/3|fn_name: rand|args: [\n16,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
1010
13 [fillcolor="#caffbf", fontcolor="#000000", label="{type: function_call|op_name: /chunk/0|fn_name: chunk|args: [\nTensorMeta(dtype=torch.float32, shape=(48, 16), requires_grad=False),\n3,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
1111
14 [fillcolor="#caffbf", fontcolor="#000000", label="{type: function_call|op_name: /chunk/1|fn_name: chunk|args: [\nTensorMeta(dtype=torch.float32, shape=(48,), requires_grad=False),\n3,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];

0 commit comments

Comments
 (0)