Skip to content

Commit cb9e9ac

Browse files
[PT2] Context manager to disable tracing (#3361)
### Changes Add context manager to disable tracing for some part of code in a model
1 parent d614c1e commit cb9e9ac

File tree

6 files changed

+111
-10
lines changed

6 files changed

+111
-10
lines changed

nncf/experimental/torch2/function_hook/hook_executor_mode.py

+14
Original file line numberDiff line numberDiff line change
@@ -517,3 +517,17 @@ def disable(self) -> Iterator[None]:
517517
self.enabled = False
518518
yield
519519
self.enabled = ret
520+
521+
522+
@contextmanager
523+
def disable_function_hook_mode() -> Iterator[None]:
524+
"""
525+
Temporarily disables the function tracing and execution hooks within a context.
526+
"""
527+
enabled_modes = torch.overrides._get_current_function_mode_stack() # type: ignore[no-untyped-call]
528+
state = {(mode, mode.enabled) for mode in enabled_modes if isinstance(mode, FunctionHookMode)}
529+
for mode, _ in state:
530+
mode.enabled = False
531+
yield
532+
for mode, enabled in state:
533+
mode.enabled = enabled

nncf/torch/dynamic_graph/context.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from nncf.common.utils.api_marker import api
2626
from nncf.common.utils.debug import is_debug
2727
from nncf.common.utils.patcher import PATCHER
28+
from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
29+
from nncf.experimental.torch2.function_hook.hook_executor_mode import disable_function_hook_mode
2830
from nncf.torch.dynamic_graph.graph import DynamicGraph
2931
from nncf.torch.dynamic_graph.graph import DynamicGraphNode
3032
from nncf.torch.dynamic_graph.graph import DynamicGraphNodeParameters
@@ -505,9 +507,15 @@ def disable_tracing(method):
505507
Patch a method so that it will be executed within no_nncf_trace context
506508
:param method: A method to patch.
507509
"""
510+
if is_experimental_torch_tracing_enabled():
508511

509-
def no_nncf_trace_wrapper(self, fn, *args, **kwargs):
510-
with no_nncf_trace():
511-
return fn(*args, **kwargs)
512+
def no_nncf_trace_wrapper(self, fn, *args, **kwargs):
513+
with disable_function_hook_mode():
514+
return fn(*args, **kwargs)
515+
else:
516+
517+
def no_nncf_trace_wrapper(self, fn, *args, **kwargs):
518+
with no_nncf_trace():
519+
return fn(*args, **kwargs)
512520

513521
PATCHER.patch(method, no_nncf_trace_wrapper)

nncf/torch/dynamic_graph/patch_pytorch.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,18 @@ class MagicFunctionsToPatch:
181181

182182
@api(canonical_alias="nncf.torch.register_operator")
183183
def register_operator(name=None):
184-
def wrap(operator):
185-
op_name = name
186-
if op_name is None:
187-
op_name = operator.__name__
188-
return wrap_operator(operator, PatchedOperatorInfo(op_name, NamespaceTarget.EXTERNAL))
184+
if is_experimental_torch_tracing_enabled():
185+
186+
def wrap(operator):
187+
# Skip wrapping operator for tracing by TorchFunctionMode
188+
return operator
189+
else:
190+
191+
def wrap(operator):
192+
op_name = name
193+
if op_name is None:
194+
op_name = operator.__name__
195+
return wrap_operator(operator, PatchedOperatorInfo(op_name, NamespaceTarget.EXTERNAL))
189196

190197
return wrap
191198

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
digraph {
2+
rankdir=TB;
3+
0 [label="{type: input|name: x|dtype: torch.float32|shape: (1, 1)}", fillcolor="#adadad", fontcolor="#000000", shape=record, style="filled,rounded"];
4+
1 [label="{type: function_call|op_name: /add/0|fn_name: add|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1)),\n1,\n]|kwargs: \{\}}", fillcolor="#caffbf", fontcolor="#000000", shape=record, style="filled,rounded"];
5+
2 [label="{type: function_call|op_name: /sub/0|fn_name: sub|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1)),\n1,\n]|kwargs: \{\}}", fillcolor="#ffadad", fontcolor="#000000", shape=record, style="filled,rounded"];
6+
3 [label="{type: output|name: output|dtype: torch.float32|shape: (1, 1)}", fillcolor="#adadad", fontcolor="#000000", shape=record, style="filled,rounded"];
7+
0 -> 1 [label="(1, 1)\n0 → 0"];
8+
2 -> 3 [label="(1, 1)\n0 → 0"];
9+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright (c) 2025 Intel Corporation
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import torch
13+
from torch import nn
14+
from torch.overrides import _get_current_function_mode_stack
15+
16+
from nncf.experimental.torch2.function_hook.graph.build_graph_mode import build_graph
17+
from nncf.experimental.torch2.function_hook.graph.graph_visualization import to_pydot
18+
from nncf.experimental.torch2.function_hook.hook_executor_mode import FunctionHookMode
19+
from nncf.experimental.torch2.function_hook.hook_executor_mode import disable_function_hook_mode
20+
from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage
21+
from nncf.experimental.torch2.function_hook.wrapper import wrap_model
22+
from nncf.torch import disable_tracing
23+
from tests.cross_fw.shared.paths import TEST_ROOT
24+
from tests.torch2.utils import compare_with_reference_file
25+
26+
REF_DIR = TEST_ROOT / "torch2" / "data" / "function_hook" / "disable_tracing"
27+
28+
29+
def test_disable_function_hook_mode():
30+
model = wrap_model(nn.Conv2d(1, 1, 1))
31+
with FunctionHookMode(model=model, hook_storage=get_hook_storage(model)) as ctx:
32+
assert ctx.enabled
33+
with disable_function_hook_mode():
34+
assert not ctx.enabled
35+
assert ctx.enabled
36+
37+
38+
class ModelNoTrace(nn.Module):
39+
def __init__(self):
40+
super().__init__()
41+
42+
def forward(self, x):
43+
x = x + 1
44+
x = self.foo(x)
45+
x = x - 1
46+
return x
47+
48+
def foo(self, x):
49+
mode = _get_current_function_mode_stack()
50+
assert len(mode) == 1
51+
assert isinstance(mode[0], FunctionHookMode)
52+
assert not mode[0].enabled
53+
return x - 1
54+
55+
56+
disable_tracing(ModelNoTrace.foo)
57+
58+
59+
def test_build_graph_with_disable_tracing(regen_ref_data):
60+
model = wrap_model(ModelNoTrace())
61+
graph = build_graph(model, torch.randn(1, 1))
62+
dot_graph = to_pydot(graph)
63+
compare_with_reference_file(str(dot_graph), REF_DIR / "graph.dot", regen_ref_data)

tests/torch2/function_hook/test_train.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def patched_forward(*args, **kwargs):
7070
# In overridden bound method __self__ links to original model for all replicas
7171
optimizer = torch.optim.Adam(wrapped_model.parameters(), lr=0.1)
7272
parallel_model = torch.nn.DataParallel(wrapped_model)
73-
with pytest.raises(nncf.InternalError, match="Not supported overwriting forward method, expected ForwardWithHooks"):
73+
with pytest.raises(nncf.InternalError, match="Not supported overridden forward"):
7474
run_one_epoch(parallel_model, optimizer, use_cuda=True)
7575

7676

@@ -87,7 +87,7 @@ def patched_forward(*args, **kwargs):
8787
wrapped_model.forward = patched_forward
8888
optimizer = torch.optim.Adam(wrapped_model.parameters(), lr=0.1)
8989
parallel_model = torch.nn.DataParallel(wrapped_model)
90-
with pytest.raises(nncf.InternalError, match="Not supported overwriting forward method of original module"):
90+
with pytest.raises(nncf.InternalError, match="Not supported overridden forward method"):
9191
run_one_epoch(parallel_model, optimizer, use_cuda=True)
9292

9393

0 commit comments

Comments
 (0)