Skip to content

Commit 7c88cfa

Browse files
suppport override forward by partial
1 parent 340db48 commit 7c88cfa

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

nncf/experimental/torch2/function_hook/wrapper.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
import functools
1415
import inspect
1516
import types
1617
from types import MethodType
@@ -49,7 +50,12 @@ def __new__(cls, orig_forward: Callable[..., Any]) -> ForwardWithHooks:
4950
return self
5051

5152
def __call__(self, *args: Any, **kwargs: Any) -> Any:
52-
model = cast(nn.Module, self._func.__self__) # type: ignore[attr-defined]
53+
if hasattr(self._func, "__self__"):
54+
# For bound method module is stored in the __self__
55+
model = cast(nn.Module, self._func.__self__) # type: ignore[attr-defined]
56+
elif isinstance(self._func, functools.partial):
57+
# For using partial to override the forward method, module is stored in the args[0] and mapped to self arg.
58+
model = cast(nn.Module, self._func.args[0]) # type: ignore[attr-defined]
5359
with FunctionHookMode(model=model, hook_storage=get_hook_storage(model)) as ctx:
5460
args, kwargs = ctx.process_model_inputs(args, kwargs)
5561
outputs = self._func(*args, **kwargs)

tests/torch2/function_hook/test_wrapper.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
import types
1213
from copy import deepcopy
14+
from functools import partial
1315
from pathlib import Path
1416

1517
import onnxruntime as ort
@@ -25,11 +27,29 @@
2527
ADD_VALUE = 2.0
2628

2729

28-
def test_wrapper():
30+
@pytest.mark.parametrize("forward_type", ["origin", "partial", "methodtype"])
31+
def test_wrapper(forward_type: str):
2932
example_input = helpers.ConvModel.get_example_inputs()
3033
model = helpers.ConvModel()
3134
model.eval()
3235
ret = model(example_input)
36+
37+
model._old_forward = model.forward
38+
39+
if forward_type == "partial":
40+
# Like in accelerate module
41+
def new_forward(self, x):
42+
return self._old_forward(x)
43+
44+
model.forward = partial(new_forward, model)
45+
46+
elif forward_type == "methodtype":
47+
48+
def new_forward(self, x):
49+
return model._old_forward(x)
50+
51+
model.forward = types.MethodType(new_forward, model)
52+
3353
wrapped = wrap_model(model)
3454
wrapped_ret = wrapped(example_input)
3555
torch.testing.assert_close(ret, wrapped_ret)

0 commit comments

Comments
 (0)