Skip to content

Commit 8115aed

Browse files
[PT2] Serialize and load transformations (#3329)
### Changes Implement `get_config` and `load_from_config` function for new tracing Analog #2531 ### Related tickets 152996 ### Tests tests/torch2/function_hook/test_serialization.py
1 parent d5de30d commit 8115aed

File tree

3 files changed

+225
-0
lines changed

3 files changed

+225
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright (c) 2025 Intel Corporation
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from typing import Any, Dict, List, TypedDict, TypeVar, cast
13+
from weakref import WeakKeyDictionary
14+
15+
from torch import nn
16+
17+
import nncf
18+
from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage
19+
from nncf.experimental.torch2.function_hook.wrapper import wrap_model
20+
from nncf.torch.layer_utils import COMPRESSION_MODULES
21+
from nncf.torch.layer_utils import StatefullModuleInterface
22+
23+
COMPRESSION_STATE_ATTR = "compression_state"
24+
TModel = TypeVar("TModel", bound=nn.Module)
25+
26+
27+
class S_COMMAND(TypedDict):
28+
hook_names_in_model: List[str]
29+
module_cls_name: str
30+
module_config: Dict[str, Any]
31+
32+
33+
def get_config(model: nn.Module) -> Dict[str, Any]:
34+
"""
35+
Returns serializable config which contains all information required to recover all additional modules placement.
36+
37+
:param model: The model to serialize.
38+
:return: Serializable config.
39+
"""
40+
hook_storage = get_hook_storage(model)
41+
42+
# Find shared modules
43+
modules_map: WeakKeyDictionary[nn.Module, List[str]] = WeakKeyDictionary()
44+
for name, module in hook_storage.named_modules(remove_duplicate=False):
45+
splitted_name = name.split(".")
46+
if len(splitted_name) != 3:
47+
# Expected depths of target hook module is 3
48+
# <3 - ModuleDicts in HookStorage, >3 - submodules of hooks
49+
continue
50+
if module not in modules_map:
51+
modules_map[module] = []
52+
modules_map[module].append(name)
53+
54+
# Generate serialized transformation commands
55+
serialized_transformations: List[S_COMMAND] = []
56+
for module, names in modules_map.items():
57+
compression_module_name = module.__class__.__name__
58+
if compression_module_name not in COMPRESSION_MODULES.registry_dict:
59+
msg = (
60+
f"Could not serialize compression module with name {compression_module_name}. "
61+
"Please register your module in the COMPRESSION_MODULES registry."
62+
)
63+
raise nncf.InternalError(msg)
64+
if not isinstance(module, StatefullModuleInterface):
65+
msg = "Support only StatefullModuleInterface modules"
66+
raise nncf.InternalError(msg)
67+
68+
serialized_transformations.append(
69+
{
70+
"hook_names_in_model": names,
71+
"module_cls_name": compression_module_name,
72+
"module_config": module.get_config(),
73+
}
74+
)
75+
76+
return {COMPRESSION_STATE_ATTR: serialized_transformations}
77+
78+
79+
def load_from_config(model: TModel, config: Dict[str, Any]) -> TModel:
80+
"""
81+
Initialize model with compressed modules from config file.
82+
83+
.. code-block:: python
84+
85+
model = MyModel()
86+
qmodel = nncf.quantize(model, ...)
87+
torch.save(
88+
{
89+
"state_dict": qmodel.state_dict(),
90+
"config": get_config(qmodel),
91+
},
92+
"ckpt.pth",
93+
)
94+
...
95+
ckpt = torch.load("ckpt.pth")
96+
restored_model = load_from_config(MyModel(), ckpt["config"])
97+
restored_model.load_state_dict(ckpt["state_dict"])
98+
99+
:param model: The original uncompressed model.
100+
:param config: The configuration dictionary containing the compressed model information.
101+
:return: The compressed model.
102+
"""
103+
wrapped_model = wrap_model(model)
104+
hook_storage = get_hook_storage(wrapped_model)
105+
transformation_commands = cast(List[S_COMMAND], config[COMPRESSION_STATE_ATTR])
106+
for command in transformation_commands:
107+
module_cls = COMPRESSION_MODULES.get(command["module_cls_name"])
108+
module = module_cls.from_config(command["module_config"])
109+
for target_name in command["hook_names_in_model"]:
110+
hook_type, hook_key, hook_id = target_name.split(".")
111+
storage_dict = getattr(hook_storage, hook_type)
112+
if hook_key not in storage_dict:
113+
storage_dict[hook_key] = nn.ModuleDict()
114+
if hook_id in storage_dict[hook_key]:
115+
msg = f"{hook_id=} for {hook_type}.{hook_key} already registered"
116+
raise nncf.InternalError(msg)
117+
storage_dict[hook_key][hook_id] = module
118+
return wrapped_model

tests/torch2/function_hook/helpers.py

+20
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from nncf.experimental.torch2.function_hook.wrapper import register_post_function_hook
1616
from nncf.experimental.torch2.function_hook.wrapper import wrap_model
17+
from nncf.torch.layer_utils import COMPRESSION_MODULES
18+
from nncf.torch.layer_utils import StatefullModuleInterface
1719

1820

1921
class CallCount(torch.nn.Module):
@@ -160,3 +162,21 @@ def __init__(self):
160162
def forward(self, x):
161163
self.counter += 1
162164
return x + 1
165+
166+
167+
@COMPRESSION_MODULES.register()
168+
class HookWithState(torch.nn.Module, StatefullModuleInterface):
169+
def __init__(self, state: str):
170+
super().__init__()
171+
self._state = state
172+
self._dummy_param = torch.nn.Parameter(torch.tensor(1.0))
173+
174+
def forward(self, x):
175+
return x + self._dummy_param
176+
177+
def get_config(self):
178+
return self._state
179+
180+
@classmethod
181+
def from_config(cls, state: str):
182+
return cls(state)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright (c) 2025 Intel Corporation
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
from pathlib import Path
12+
13+
import pytest
14+
import torch
15+
from torch import nn
16+
17+
import nncf
18+
from nncf.experimental.torch2.function_hook import get_hook_storage
19+
from nncf.experimental.torch2.function_hook import register_post_function_hook
20+
from nncf.experimental.torch2.function_hook import register_pre_function_hook
21+
from nncf.experimental.torch2.function_hook import wrap_model
22+
from nncf.experimental.torch2.function_hook.serialization import get_config
23+
from nncf.experimental.torch2.function_hook.serialization import load_from_config
24+
from tests.torch2.function_hook.helpers import HookWithState
25+
from tests.torch2.function_hook.helpers import SimpleModel
26+
27+
28+
@pytest.mark.parametrize("is_shared_hook", [True, False], ids=["shared_hook", "not_shared_hook"])
29+
def test_save_load(tmp_path: Path, is_shared_hook: bool):
30+
model = wrap_model(SimpleModel())
31+
32+
hook1 = HookWithState("hook1")
33+
hook2 = hook1 if is_shared_hook else HookWithState("hook2")
34+
35+
register_pre_function_hook(model, "conv1/conv2d/0", 0, hook1)
36+
register_post_function_hook(model, "simple/conv/conv2d/0", 0, hook2)
37+
38+
state_dict = model.state_dict()
39+
compression_config = get_config(model)
40+
41+
torch.save(
42+
{
43+
"model_state_dict": state_dict,
44+
"compression_config": compression_config,
45+
},
46+
tmp_path / "checkpoint.pth",
47+
)
48+
49+
ckpt = torch.load(tmp_path / "checkpoint.pth")
50+
config = ckpt["compression_config"]
51+
restored_model = load_from_config(SimpleModel(), config)
52+
restored_model.load_state_dict(ckpt["model_state_dict"])
53+
54+
assert state_dict == restored_model.state_dict()
55+
56+
tensor = model.get_example_inputs()
57+
ret_1 = model(tensor)
58+
ret_2 = restored_model(tensor)
59+
assert torch.allclose(ret_1[0], ret_2[0])
60+
assert torch.allclose(ret_1[1], ret_2[1])
61+
62+
hook_storage = get_hook_storage(restored_model)
63+
hook1 = hook_storage.get_submodule("pre_hooks.conv1/conv2d/0__0.0")
64+
hook2 = hook_storage.get_submodule("post_hooks.simple/conv/conv2d/0__0.0")
65+
assert (hook1 is hook2) == is_shared_hook
66+
67+
68+
def test_error_dublicate_names():
69+
config = {
70+
"compression_state": [
71+
{
72+
"hook_names_in_model": ["pre_hooks.conv1/conv2d/0__0.0", "pre_hooks.conv1/conv2d/0__0.0"],
73+
"module_cls_name": "HookWithState",
74+
"module_config": "hook1",
75+
}
76+
]
77+
}
78+
with pytest.raises(nncf.InternalError, match="already registered"):
79+
load_from_config(SimpleModel(), config)
80+
81+
82+
def test_error_not_registered_compression_modules():
83+
model = wrap_model(SimpleModel())
84+
register_pre_function_hook(model, "conv1/conv2d/0", 0, nn.ReLU())
85+
86+
with pytest.raises(nncf.InternalError, match="Please register your module in the COMPRESSION_MODULES registry."):
87+
get_config(model)

0 commit comments

Comments
 (0)