
Commit dd0ef13

[PT2] rework ForwardWithHooks (#3362)
### Changes

Rework `ForwardWithHooks` to be more flexible with respect to patched forward functions. The model is now stored inside `ForwardWithHooks` instead of being recovered from the original forward. Add `_has_compatible_shallow_copy_type` and `_parse_to` to the ignored function names. Fix tests on CUDA.

### Reason for changes

```python
model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model = wrap_model(model)
# model = cast(nn.Module, self._func.__self__)  # type: ignore[attr-defined]
# AttributeError: 'functools.partial' object has no attribute '__self__'. Did you mean: '__call__'?
```
1 parent cd3c678 commit dd0ef13
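
The `AttributeError` above comes from forward functions patched the way `accelerate` does it: `model.forward` becomes a `functools.partial`, which exposes no `__self__`, so the model could not be recovered from the forward callable alone. A standalone sketch of that pattern (illustrative names, no nncf imports):

```python
from functools import partial

from torch import nn

model = nn.Linear(4, 4)

# accelerate-style patching: save the bound forward, then replace it with
# a functools.partial that carries the model explicitly.
model._old_forward = model.forward

def new_forward(self, x):
    return self._old_forward(x)

model.forward = partial(new_forward, model)

# A partial has no __self__, so the old model lookup breaks; the
# reworked ForwardWithHooks stores the model itself instead.
assert not hasattr(model.forward, "__self__")
```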

File tree

5 files changed: +73 −37 lines changed


nncf/experimental/torch2/function_hook/graph/build_graph_mode.py

+1 −1

```diff
@@ -373,6 +373,6 @@ def build_graph(model: nn.Module, *args: Any, **kwargs: Any) -> nx.MultiDiGraph:
     with GraphBuilderMode(model=model, hook_storage=get_hook_storage(model)) as ctx:
         args, kwargs = ctx.process_model_inputs(args, kwargs)
         wrapped_forward = cast(ForwardWithHooks, model.forward)
-        outputs = wrapped_forward._func(*args, **kwargs)
+        outputs = wrapped_forward.orig_forward(*args, **kwargs)
         outputs = ctx.process_model_outputs(outputs)
     return ctx.graph
```
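
A minimal usage sketch of this path, with imports assumed from this PR's layout (`wrap_model` in `wrapper.py`, `build_graph` in this file):

```python
import torch
from torch import nn

from nncf.experimental.torch2.function_hook.graph.build_graph_mode import build_graph
from nncf.experimental.torch2.function_hook.wrapper import wrap_model

# wrap_model installs ForwardWithHooks; build_graph then bypasses the
# wrapper and runs orig_forward under GraphBuilderMode to record the graph.
model = wrap_model(nn.Linear(4, 4))
graph = build_graph(model, torch.randn(1, 4))
```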

nncf/experimental/torch2/function_hook/hook_executor_mode.py

+2

```diff
@@ -41,6 +41,8 @@
     "size",
     "is_floating_point",
     "_set_grad_enabled",
+    "_parse_to",
+    "_has_compatible_shallow_copy_type",
 ]
```
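
Both names surface when a model is cast, as in the `bfloat16` example in the commit message: `nn.Module.to()` parses its arguments via `torch._C._nn._parse_to`, and `Module._apply()` checks `torch._has_compatible_shallow_copy_type` when swapping parameters; neither is a graph operation worth tracing. A small sketch of the two torch internals, shown directly rather than through the hook mode:

```python
import torch

# nn.Module.to() uses _parse_to for argument parsing; it returns
# (device, dtype, non_blocking, memory_format).
device, dtype, non_blocking, fmt = torch._C._nn._parse_to(torch.bfloat16)

# Module._apply() uses this check to decide whether tensors can be
# swapped in place; True for two same-type dense tensors.
a, b = torch.zeros(2), torch.ones(2)
print(torch._has_compatible_shallow_copy_type(a, b))
```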

nncf/experimental/torch2/function_hook/wrapper.py

+36 −31

```diff
@@ -14,7 +14,7 @@
 import inspect
 import types
 from types import MethodType
-from typing import Any, Callable, Dict, Tuple, TypeVar, cast
+from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, cast
 
 from torch import nn
@@ -31,70 +31,76 @@
 class ForwardWithHooks:
     """Class to wrap forward function of nn.Module, to forward function of the model with enabled FunctionHookMode"""
 
-    __slots__ = "_func", "__dict__", "__weakref__"
-    _func: Callable[..., Any]
-
-    def __new__(cls, orig_forward: Callable[..., Any]) -> ForwardWithHooks:
-        if not callable(orig_forward):
-            msg = "the first argument must be callable"
-            raise TypeError(msg)
+    __slots__ = "_orig_forward", "_model", "__dict__", "__weakref__"
+    _orig_forward: Callable[..., Any]
+    _model: nn.Module
 
-        if isinstance(orig_forward, ForwardWithHooks):
+    def __new__(cls, model: nn.Module, orig_forward: Optional[Callable[..., Any]] = None) -> ForwardWithHooks:
+        if isinstance(model.forward, ForwardWithHooks):
             msg = "Func already wrapped"
             raise TypeError(msg)
 
         self = super().__new__(cls)
 
-        self._func = orig_forward
+        self._orig_forward = model.forward if orig_forward is None else orig_forward
+        self._model = model
         return self
 
     def __call__(self, *args: Any, **kwargs: Any) -> Any:
-        model = cast(nn.Module, self._func.__self__)  # type: ignore[attr-defined]
-        with FunctionHookMode(model=model, hook_storage=get_hook_storage(model)) as ctx:
+        with FunctionHookMode(model=self.model, hook_storage=get_hook_storage(self.model)) as ctx:
             args, kwargs = ctx.process_model_inputs(args, kwargs)
-            outputs = self._func(*args, **kwargs)
+            outputs = self.orig_forward(*args, **kwargs)
             outputs = ctx.process_model_outputs(outputs)
         return outputs
 
     def __repr__(self) -> str:
-        return f"ForwardWithHooks.{repr(self._func)}"
+        return f"ForwardWithHooks.{repr(self.orig_forward)}"
 
-    def __reduce__(self) -> Tuple[Callable[..., Any], Tuple[Any, ...], Tuple[Any, ...]]:
-        return type(self), (self._func,), (self._func, self.__dict__ or None)
+    def __reduce__(self) -> Any:
+        return type(self), (self.model, self.orig_forward), (self.model, self.orig_forward, self.__dict__ or None)
 
-    def __setstate__(self, state: Tuple[Any, Any]) -> None:
+    def __setstate__(self, state: Any) -> None:
         if not isinstance(state, tuple):
             msg = "argument to __setstate__ must be a tuple"
             raise TypeError(msg)
-        if len(state) != 2:
-            msg = f"expected 2 items in state, got {len(state)}"
+        if len(state) != 3:
+            msg = f"expected 3 items in state, got {len(state)}"
             raise TypeError(msg)
-        func, namespace = state
-        if not callable(func) or (namespace is not None and not isinstance(namespace, dict)):
+        model, orig_forward, namespace = state
+        if not callable(orig_forward) or (namespace is not None and not isinstance(namespace, dict)):
             msg = "invalid partial state"
             raise TypeError(msg)
 
         if namespace is None:
             namespace = {}
 
+        self._model = model
+        self._orig_forward = orig_forward
         self.__dict__ = namespace
-        self._func = func
 
     @property
     def __code__(self) -> types.CodeType:
         return self.__call__.__code__
 
     @property
     def __globals__(self) -> Dict[str, Any]:
-        return self._func.__globals__
+        return self.orig_forward.__globals__
 
     @property
     def __name__(self) -> str:
-        return self._func.__name__
+        return self.orig_forward.__name__
 
     @property
     def __signature__(self) -> inspect.Signature:
-        return inspect.signature(self._func)
+        return inspect.signature(self.orig_forward)
+
+    @property
+    def orig_forward(self) -> Callable[..., Any]:
+        return self._orig_forward
+
+    @property
+    def model(self) -> nn.Module:
+        return self._model
 
 
 class ReplicateForDataParallel:
@@ -135,16 +141,16 @@ def __call__(self, *args: Any, **kwargs: Any) -> nn.Module:
             raise nncf.InternalError(msg)
 
         if not (
-            isinstance(saved_forward_with_hooks._func, types.MethodType)
-            and saved_forward_with_hooks._func.__func__ is module.__class__.forward
+            isinstance(saved_forward_with_hooks.orig_forward, types.MethodType)
+            and saved_forward_with_hooks.orig_forward.__func__ is module.__class__.forward
         ):
             msg = "Not supported overridden forward method of original module"
             raise nncf.InternalError(msg)
 
         module.__dict__.pop("forward")
 
-        replica: nn.Module = self._func(*args, **kwargs)
-        replica.forward = ForwardWithHooks(replica.forward)
+        replica: nn.Module = self.func(*args, **kwargs)
+        replica.forward = ForwardWithHooks(replica)
         module.forward = saved_forward_with_hooks
 
         return replica
@@ -193,7 +199,7 @@ def wrap_model(model: TModel) -> TModel:
     :param model: The nn.Module to be wrapped.
     :return: The modified model with the custom behavior injected.
     """
-    model.forward = ForwardWithHooks(model.forward)
+    model.forward = ForwardWithHooks(model)
     model._replicate_for_data_parallel = ReplicateForDataParallel(model._replicate_for_data_parallel)  # type: ignore
     model.add_module(ATR_HOOK_STORAGE, HookStorage())
     return model
@@ -236,7 +242,6 @@ def register_pre_function_hook(model: nn.Module, op_name: str, port_id: int, hoo
     :param op_name: The name of the operation associated with the hook.
     :param port_id: The port ID associated with the hook.
     :param hook: The pre-function hook module to be executed.
-
     :return: A handle that can be used to remove the hook later.
     """
     hook_storage = get_hook_storage(model)
```
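
The reworked constructor takes the model itself, with the callable to wrap as an optional second argument. A construction-only sketch (actually calling the wrapper additionally requires the hook storage that `wrap_model` installs):

```python
from torch import nn

from nncf.experimental.torch2.function_hook.wrapper import ForwardWithHooks

model = nn.Linear(4, 4)

# Default: wrap whatever model.forward currently is, whether a bound
# method, a functools.partial, or a plain function.
model.forward = ForwardWithHooks(model)

# Explicit alternative: pass the callable to wrap, e.g. a forward saved
# before patching (saved_forward is hypothetical):
# model.forward = ForwardWithHooks(model, saved_forward)
```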

tests/torch2/function_hook/test_train.py

+1 −1

```diff
@@ -87,7 +87,7 @@ def patched_forward(*args, **kwargs):
     wrapped_model.forward = patched_forward
     optimizer = torch.optim.Adam(wrapped_model.parameters(), lr=0.1)
     parallel_model = torch.nn.DataParallel(wrapped_model)
-    with pytest.raises(nncf.InternalError, match="Not supported overridden forward method"):
+    with pytest.raises(nncf.InternalError, match="Not supported overridden forward"):
         run_one_epoch(parallel_model, optimizer, use_cuda=True)
```

tests/torch2/function_hook/test_wrapper.py

+33 −4

```diff
@@ -9,7 +9,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import types
 from copy import deepcopy
+from functools import partial
 from pathlib import Path
 
 import onnxruntime as ort
@@ -25,13 +27,40 @@
 ADD_VALUE = 2.0
 
 
-def test_wrapper():
+@pytest.mark.parametrize("forward_type", ["origin", "partial", "bound", "fn"])
+def test_wrapper(forward_type: str):
     example_input = helpers.ConvModel.get_example_inputs()
     model = helpers.ConvModel()
     model.eval()
-    ret = model(example_input)
-    wrapped = wrap_model(model)
-    wrapped_ret = wrapped(example_input)
+
+    model._old_forward = model.forward
+
+    if forward_type == "partial":
+        # Like in accelerate module
+        def new_forward(self, x):
+            return self._old_forward(x)
+
+        model.forward = partial(new_forward, model)
+
+    elif forward_type == "bound":
+
+        def new_forward(self, x):
+            return model._old_forward(x)
+
+        model.forward = types.MethodType(new_forward, model)
+    elif forward_type == "fn":
+        old_forward = model.forward
+
+        def new_forward(x):
+            return old_forward(x)
+
+        model.forward = new_forward
+
+    with torch.no_grad():
+        ret = model(example_input)
+        wrapped = wrap_model(model)
+        wrapped_ret = wrapped(example_input)
+
     torch.testing.assert_close(ret, wrapped_ret)
```
