
Commit 93192d6

[PT2] Do not use gradients when building graph (#3278)
### Changes

Use `torch.no_grad()` when building the graph.

Affects:
- unknown type for the `tensor.T` and `tensor.H` operations, but the graph is still connected
- removed the `requires_grad` attribute from `TensorMeta`; under `no_grad` it is always `False`, and it is not used anywhere

### Reason for changes

High memory usage
1 parent 9df265a commit 93192d6
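
A back-of-the-envelope illustration of the "high memory usage" motivation (plain PyTorch, not the NNCF API; the toy model and input below are made up for the example):

```python
import torch
import torch.nn as nn

# Hypothetical toy model and input, only to show the effect of no_grad on the tracing pass.
model = nn.Sequential(nn.Conv2d(1, 8, kernel_size=3), nn.ReLU()).eval()
example_input = torch.randn(1, 1, 32, 32)

# Previously the graph-building forward pass ran under torch.enable_grad() purely to
# expose grad_fn information for __get__ attributes such as tensor.T / tensor.mT.
# Running it under torch.no_grad() records no autograd graph, so activations are not
# kept alive for a backward pass that never happens.
with torch.no_grad():
    output = model(example_input)

print(output.requires_grad)  # False
print(output.grad_fn)        # None: nothing was recorded for backward
```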

File tree

8 files changed: +52 -77 lines changed


nncf/experimental/torch2/function_hook/graph/build_graph_mode.py (+9 -9)

@@ -212,14 +212,15 @@ def execute_post_hooks(self, outputs: Any, op_meta: OpMeta) -> Any:
         Overload execute_post_hooks to correct registered node for operation.
         Process __get__ function, to detect permute and transpose operation
         and remove node if operation return not tensor.
+
+        :param output: The output of the function.
+        :param op_meta: Metadata for the operation.
+        :return: The modified output after post-hooks.
         """
-        if op_meta.func.__name__ == "__get__":
-            if isinstance(outputs, torch.Tensor):
-                self.process_tensor_attributes(outputs, op_meta)
-            else:
-                # Remove the node corresponding to this operation from the graph, as non-tensor
-                # outputs (like `tensor.shape` or similar) are not relevant for further algorithmic use.
-                self.graph.remove_node(op_meta.extra_info["node_id"])
+        if op_meta.func.__name__ == "__get__" and not isinstance(outputs, torch.Tensor):
+            # Remove the node corresponding to this operation from the graph, as non-tensor
+            # outputs (like `tensor.shape` or similar) are not relevant for further algorithmic use.
+            self.graph.remove_node(op_meta.extra_info["node_id"])
         outputs = super().execute_post_hooks(outputs, op_meta)
         return outputs

@@ -368,8 +369,7 @@ def build_graph(model: nn.Module, *args: Any, **kwargs: Any) -> nx.MultiDiGraph:
     :return: A nx.MultiDiGraph where nodes represent operations of model.
     """
     with training_mode_switcher(model, is_training=False):
-        with torch.enable_grad():  # type: ignore
-            # Gradient use to get information about __get__ functions to detect tensor.(T, mT) attributes
+        with torch.no_grad():
             with GraphBuilderMode(model=model, hook_storage=get_hook_storage(model)) as ctx:
                 args, kwargs = ctx.process_model_inputs(args, kwargs)
                 wrapped_forward = cast(ForwardWithHooks, model.forward)
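
Why `tensor.T` / `tensor.mT` now have an unknown type: the removed `torch.enable_grad()` context existed only so that the `__get__` interception could read autograd metadata from the produced tensor, and under `torch.no_grad()` that metadata is simply not created. A standalone sketch in plain PyTorch (the exact `grad_fn` class name varies between torch versions):

```python
import torch

x = torch.ones(2, 3, requires_grad=True)

with torch.enable_grad():
    print(x.T.grad_fn)  # e.g. <PermuteBackward0 ...>: identifies the underlying permute/transpose

with torch.no_grad():
    print(x.T.grad_fn)  # None: no autograd record, so the op type cannot be recovered this way
```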

nncf/experimental/torch2/function_hook/graph/graph_utils.py (+1 -2)

@@ -43,11 +43,10 @@ def __str__(self) -> str:
 class TensorMeta:
     dtype: torch.dtype
     shape: Tuple[int, ...]
-    requires_grad: bool

     @staticmethod
     def from_tensor(tensor: torch.Tensor) -> TensorMeta:
-        return TensorMeta(tensor.dtype, tuple(tensor.shape), tensor.requires_grad)
+        return TensorMeta(tensor.dtype, tuple(tensor.shape))


 @dataclass
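
For reference, a minimal self-contained approximation of `TensorMeta` after this change (the real class lives in graph_utils.py; decorator options and the rest of the module are omitted here), showing the repr that the updated .dot reference files below now expect:

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple

import torch


@dataclass  # the real class may use different dataclass options
class TensorMeta:
    dtype: torch.dtype
    shape: Tuple[int, ...]

    @staticmethod
    def from_tensor(tensor: torch.Tensor) -> TensorMeta:
        # requires_grad is no longer captured: under no_grad it is always False.
        return TensorMeta(tensor.dtype, tuple(tensor.shape))


print(TensorMeta.from_tensor(torch.zeros(1, 1, 3, 3)))
# TensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3))
```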

tests/torch2/data/function_hook/graph_visualization/to_pydot_style_full.dot (+3 -3)

@@ -3,10 +3,10 @@ rankdir=TB;
 0 [fillcolor="#adadad", fontcolor="#000000", label="{type: input|name: x|dtype: torch.float32|shape: (1, 1, 3, 3)}", shape=record, style="filled,rounded"];
 1 [fillcolor="#ffffff", fontcolor="#000000", label="{type: const|name: conv.weight|dtype: torch.float32|shape: (1, 1, 1, 1)}", shape=record, style="filled,rounded"];
 2 [fillcolor="#ffffff", fontcolor="#000000", label="{type: const|name: conv.bias|dtype: torch.float32|shape: (1,)}", shape=record, style="filled,rounded"];
-3 [fillcolor="#ffd6a5", fontcolor="#000000", label="{type: function_call|op_name: conv/conv2d/0|fn_name: conv2d|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3), requires_grad=False),\nTensorMeta(dtype=torch.float32, shape=(1, 1, 1, 1), requires_grad=True),\nTensorMeta(dtype=torch.float32, shape=(1,), requires_grad=True),\n(1, 1),\n(0, 0),\n(1, 1),\n1,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
+3 [fillcolor="#ffd6a5", fontcolor="#000000", label="{type: function_call|op_name: conv/conv2d/0|fn_name: conv2d|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3)),\nTensorMeta(dtype=torch.float32, shape=(1, 1, 1, 1)),\nTensorMeta(dtype=torch.float32, shape=(1,)),\n(1, 1),\n(0, 0),\n(1, 1),\n1,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
 4 [fillcolor="#ffffff", fontcolor="#000000", label="{type: const|name: __nncf_hooks.post_hooks.conv/conv2d/0__0.0.w|dtype: torch.float32|shape: (1,)}", shape=record, style="filled,rounded"];
-5 [fillcolor="#caffbf", fontcolor="#000000", label="{type: function_call|op_name: conv/post_hook__conv-conv2d-0__0[0]/add/0|fn_name: add|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3), requires_grad=True),\nTensorMeta(dtype=torch.float32, shape=(1,), requires_grad=True),\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
-6 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /relu/0|fn_name: relu|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3), requires_grad=True),\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
+5 [fillcolor="#caffbf", fontcolor="#000000", label="{type: function_call|op_name: conv/post_hook__conv-conv2d-0__0[0]/add/0|fn_name: add|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3)),\nTensorMeta(dtype=torch.float32, shape=(1,)),\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
+6 [fillcolor="#a0c4ff", fontcolor="#000000", label="{type: function_call|op_name: /relu/0|fn_name: relu|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 3, 3)),\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
 7 [fillcolor="#adadad", fontcolor="#000000", label="{type: output|name: output|dtype: torch.float32|shape: (1, 1, 3, 3)}", shape=record, style="filled,rounded"];
 0 -> 3 [label="(1, 1, 3, 3)\n0 → 0"];
 1 -> 3 [label="(1, 1, 1, 1)\n0 → 1"];

tests/torch2/data/function_hook/handle_inner_functions/inner_functions_BatchNormModel.dot (+1 -1)

@@ -5,7 +5,7 @@ rankdir=TB;
 2 [fillcolor="#ffffff", fontcolor="#000000", label="{type: const|name: bn.bias|dtype: torch.float32|shape: (1,)}", shape=record, style="filled,rounded"];
 3 [fillcolor="#ffffff", fontcolor="#000000", label="{type: const|name: bn.running_mean|dtype: torch.float32|shape: (1,)}", shape=record, style="filled,rounded"];
 4 [fillcolor="#ffffff", fontcolor="#000000", label="{type: const|name: bn.running_var|dtype: torch.float32|shape: (1,)}", shape=record, style="filled,rounded"];
-5 [fillcolor="#ffadad", fontcolor="#000000", label="{type: function_call|op_name: bn/batch_norm/0|fn_name: batch_norm|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 1), requires_grad=False),\nTensorMeta(dtype=torch.float32, shape=(1,), requires_grad=True),\nTensorMeta(dtype=torch.float32, shape=(1,), requires_grad=True),\nTensorMeta(dtype=torch.float32, shape=(1,), requires_grad=False),\nTensorMeta(dtype=torch.float32, shape=(1,), requires_grad=False),\nFalse,\n0.1,\n1e-05,\nTrue,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
+5 [fillcolor="#ffadad", fontcolor="#000000", label="{type: function_call|op_name: bn/batch_norm/0|fn_name: batch_norm|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 1)),\nTensorMeta(dtype=torch.float32, shape=(1,)),\nTensorMeta(dtype=torch.float32, shape=(1,)),\nTensorMeta(dtype=torch.float32, shape=(1,)),\nTensorMeta(dtype=torch.float32, shape=(1,)),\nFalse,\n0.1,\n1e-05,\nTrue,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
 6 [fillcolor="#adadad", fontcolor="#000000", label="{type: output|name: output|dtype: torch.float32|shape: (1, 1, 1)}", shape=record, style="filled,rounded"];
 0 -> 5 [label="(1, 1, 1)\n0 → 0"];
 1 -> 5 [label="(1,)\n0 → 1"];
