|
14 | 14 | import inspect
|
15 | 15 | import types
|
16 | 16 | from types import MethodType
|
17 |
| -from typing import Any, Callable, Dict, Tuple, TypeVar, cast |
| 17 | +from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, cast |
18 | 18 |
|
19 | 19 | from torch import nn
|
20 | 20 |
|
|
31 | 31 | class ForwardWithHooks:
|
32 | 32 | """Class to wrap forward function of nn.Module, to forward function of the model with enabled FunctionHookMode"""
|
33 | 33 |
|
34 |
| - __slots__ = "_func", "__dict__", "__weakref__" |
35 |
| - _func: Callable[..., Any] |
36 |
| - |
37 |
| - def __new__(cls, orig_forward: Callable[..., Any]) -> ForwardWithHooks: |
38 |
| - if not callable(orig_forward): |
39 |
| - msg = "the first argument must be callable" |
40 |
| - raise TypeError(msg) |
| 34 | + __slots__ = "_orig_forward", "_model", "__dict__", "__weakref__" |
| 35 | + _orig_forward: Callable[..., Any] |
| 36 | + _model: nn.Module |
41 | 37 |
|
42 |
| - if isinstance(orig_forward, ForwardWithHooks): |
| 38 | + def __new__(cls, model: nn.Module, orig_forward: Optional[Callable[..., Any]] = None) -> ForwardWithHooks: |
| 39 | + if isinstance(model.forward, ForwardWithHooks): |
43 | 40 | msg = "Func already wrapped"
|
44 | 41 | raise TypeError(msg)
|
45 | 42 |
|
46 | 43 | self = super().__new__(cls)
|
47 | 44 |
|
48 |
| - self._func = orig_forward |
| 45 | + self._orig_forward = model.forward if orig_forward is None else orig_forward |
| 46 | + self._model = model |
49 | 47 | return self
|
50 | 48 |
|
51 | 49 | def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
52 |
| - model = cast(nn.Module, self._func.__self__) # type: ignore[attr-defined] |
53 |
| - with FunctionHookMode(model=model, hook_storage=get_hook_storage(model)) as ctx: |
| 50 | + with FunctionHookMode(model=self.model, hook_storage=get_hook_storage(self.model)) as ctx: |
54 | 51 | args, kwargs = ctx.process_model_inputs(args, kwargs)
|
55 |
| - outputs = self._func(*args, **kwargs) |
| 52 | + outputs = self.orig_forward(*args, **kwargs) |
56 | 53 | outputs = ctx.process_model_outputs(outputs)
|
57 | 54 | return outputs
|
58 | 55 |
|
59 | 56 | def __repr__(self) -> str:
|
60 |
| - return f"ForwardWithHooks.{repr(self._func)}" |
| 57 | + return f"ForwardWithHooks.{repr(self.orig_forward)}" |
61 | 58 |
|
62 |
| - def __reduce__(self) -> Tuple[Callable[..., Any], Tuple[Any, ...], Tuple[Any, ...]]: |
63 |
| - return type(self), (self._func,), (self._func, self.__dict__ or None) |
| 59 | + def __reduce__(self) -> Any: |
| 60 | + return type(self), (self.model, self.orig_forward), (self.model, self.orig_forward, self.__dict__ or None) |
64 | 61 |
|
65 |
| - def __setstate__(self, state: Tuple[Any, Any]) -> None: |
| 62 | + def __setstate__(self, state: Any) -> None: |
66 | 63 | if not isinstance(state, tuple):
|
67 | 64 | msg = "argument to __setstate__ must be a tuple"
|
68 | 65 | raise TypeError(msg)
|
69 |
| - if len(state) != 2: |
70 |
| - msg = f"expected 2 items in state, got {len(state)}" |
| 66 | + if len(state) != 3: |
| 67 | + msg = f"expected 3 items in state, got {len(state)}" |
71 | 68 | raise TypeError(msg)
|
72 |
| - func, namespace = state |
73 |
| - if not callable(func) or (namespace is not None and not isinstance(namespace, dict)): |
| 69 | + model, orig_forward, namespace = state |
| 70 | + if not callable(orig_forward) or (namespace is not None and not isinstance(namespace, dict)): |
74 | 71 | msg = "invalid partial state"
|
75 | 72 | raise TypeError(msg)
|
76 | 73 |
|
77 | 74 | if namespace is None:
|
78 | 75 | namespace = {}
|
79 | 76 |
|
| 77 | + self._model = model |
| 78 | + self._orig_forward = orig_forward |
80 | 79 | self.__dict__ = namespace
|
81 |
| - self._func = func |
82 | 80 |
|
83 | 81 | @property
|
84 | 82 | def __code__(self) -> types.CodeType:
|
85 | 83 | return self.__call__.__code__
|
86 | 84 |
|
87 | 85 | @property
|
88 | 86 | def __globals__(self) -> Dict[str, Any]:
|
89 |
| - return self._func.__globals__ |
| 87 | + return self.orig_forward.__globals__ |
90 | 88 |
|
91 | 89 | @property
|
92 | 90 | def __name__(self) -> str:
|
93 |
| - return self._func.__name__ |
| 91 | + return self.orig_forward.__name__ |
94 | 92 |
|
95 | 93 | @property
|
96 | 94 | def __signature__(self) -> inspect.Signature:
|
97 |
| - return inspect.signature(self._func) |
| 95 | + return inspect.signature(self.orig_forward) |
| 96 | + |
| 97 | + @property |
| 98 | + def orig_forward(self) -> Callable[..., Any]: |
| 99 | + return self._orig_forward |
| 100 | + |
| 101 | + @property |
| 102 | + def model(self) -> nn.Module: |
| 103 | + return self._model |
98 | 104 |
|
99 | 105 |
|
100 | 106 | class ReplicateForDataParallel:
|
@@ -135,16 +141,16 @@ def __call__(self, *args: Any, **kwargs: Any) -> nn.Module:
|
135 | 141 | raise nncf.InternalError(msg)
|
136 | 142 |
|
137 | 143 | if not (
|
138 |
| - isinstance(saved_forward_with_hooks._func, types.MethodType) |
139 |
| - and saved_forward_with_hooks._func.__func__ is module.__class__.forward |
| 144 | + isinstance(saved_forward_with_hooks.orig_forward, types.MethodType) |
| 145 | + and saved_forward_with_hooks.orig_forward.__func__ is module.__class__.forward |
140 | 146 | ):
|
141 | 147 | msg = "Not supported overridden forward method of original module"
|
142 | 148 | raise nncf.InternalError(msg)
|
143 | 149 |
|
144 | 150 | module.__dict__.pop("forward")
|
145 | 151 |
|
146 |
| - replica: nn.Module = self._func(*args, **kwargs) |
147 |
| - replica.forward = ForwardWithHooks(replica.forward) |
| 152 | + replica: nn.Module = self.func(*args, **kwargs) |
| 153 | + replica.forward = ForwardWithHooks(replica) |
148 | 154 | module.forward = saved_forward_with_hooks
|
149 | 155 |
|
150 | 156 | return replica
|
@@ -193,7 +199,7 @@ def wrap_model(model: TModel) -> TModel:
|
193 | 199 | :param model: The nn.Module to be wrapped.
|
194 | 200 | :return: The modified model with the custom behavior injected.
|
195 | 201 | """
|
196 |
| - model.forward = ForwardWithHooks(model.forward) |
| 202 | + model.forward = ForwardWithHooks(model) |
197 | 203 | model._replicate_for_data_parallel = ReplicateForDataParallel(model._replicate_for_data_parallel) # type: ignore
|
198 | 204 | model.add_module(ATR_HOOK_STORAGE, HookStorage())
|
199 | 205 | return model
|
@@ -236,7 +242,6 @@ def register_pre_function_hook(model: nn.Module, op_name: str, port_id: int, hoo
|
236 | 242 | :param op_name: The name of the operation associated with the hook.
|
237 | 243 | :param port_id: The port ID associated with the hook.
|
238 | 244 | :param hook: The pre-function hook module to be executed.
|
239 |
| -
|
240 | 245 | :return: A handle that can be used to remove the hook later.
|
241 | 246 | """
|
242 | 247 | hook_storage = get_hook_storage(model)
|
|
0 commit comments