
Commit 64b01ea

AlexanderDokuchaev authored and shumaari committed on Mar 8, 2025
[PT2] cache for shared parameters (openvinotoolkit#3297)
### Changes

- Add a cache for the results of post-hooks applied to shared parameters (sketched below).
- Use a `TypeVar` for `wrap_model` so the wrapped model keeps its concrete type.
- Modify `to_comparable_nx_graph` to dump nodes whose names contain `:`.
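The shared-parameter case this cache targets can be pictured with a small model in which two submodules reuse the same `nn.Linear`. The sketch below mirrors the `SharedParamModel` helper added in the tests; the class name here is only illustrative.

```python
# Minimal sketch of a shared parameter: one nn.Linear registered under two submodules.
# The class name is illustrative; the tests below add an equivalent SharedParamModel helper.
import torch
from torch import nn


class TwoBranchesOneLinear(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        shared = nn.Linear(3, 1, bias=False)
        self.module1 = nn.Sequential(shared)
        self.module2 = nn.Sequential(shared)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The shared weight is consumed twice per forward pass, so a post-hook
        # attached to it would also run twice without the cache added in this PR.
        return self.module1(x) + self.module2(x)


model = TwoBranchesOneLinear()
assert model.module1[0].weight is model.module2[0].weight  # a single Parameter object
```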
1 parent 861553e commit 64b01ea

File tree

7 files changed, +129 −7 lines changed

nncf/experimental/torch2/function_hook/hook_executor_mode.py (+35 −3)

```diff
@@ -93,12 +93,13 @@ class FunctionHookMode(TorchFunctionMode):
     This mode wraps the function calls in the model to allow custom hooks to be executed before
     and after the actual function calls.
 
-
     :param model: The PyTorch model to which the hooks will be applied.
     :param hook_storage: Storage for hooks to be executed.
     :param module_call_stack: A stack tracking the modules being called.
     :param nested_enter_count: A counter to track nested context manager entries.
     :param op_calls: A dictionary to track operation calls.
+    :param counter_reusing_shared_weights: A dictionary to track shared weights.
+    :param cache_parameters: A dictionary to cache modified parameters.
     """
 
     def __init__(self, model: nn.Module, hook_storage: HookStorage) -> None:
@@ -127,6 +128,14 @@ def __init__(self, model: nn.Module, hook_storage: HookStorage) -> None:
         self._get_named_hooks(self.hook_storage.pre_hooks, "pre_hook")
         self._get_named_hooks(self.hook_storage.post_hooks, "post_hook")
 
+        # Collect how many times each shared parameter is used
+        counter_shared_weights: Dict[int, int] = defaultdict(int)
+        for name, parameter in chain(self.model.named_parameters(remove_duplicate=False)):
+            counter_shared_weights[id(parameter)] += 1
+
+        self.counter_reusing_shared_weights = {k: v - 1 for k, v in counter_shared_weights.items() if v > 1}
+        self.cache_parameters: Dict[int, Tensor] = {}
+
     def _get_named_hooks(self, storage: nn.ModuleDict, prefix: str) -> None:
         """
         Associates named hooks from the given module storage with a group name, updating
@@ -306,18 +315,41 @@ def execute_hooks_for_parameter(self, value: torch.Tensor) -> torch.Tensor:
         Executes post-hooks for a model parameter if a hook is defined for it.
         If the input is not a `torch.nn.Parameter`, or if no hook is defined, the original tensor is returned unchanged.
 
+        For shared parameters that are used more than once, the function caches the modified parameters.
+        The caching mechanism allows the function to avoid redundant computations for shared parameters.
+
         :param value: The tensor to which the post-hook will be applied.
         :return: The processed tensor with the applied post-hook, if applicable.
         """
         if not isinstance(value, torch.nn.Parameter):
             return value
 
+        id_param = id(value)
+        if id_param in self.cache_parameters:
+            ret = self.cache_parameters[id_param]
+            self.counter_reusing_shared_weights[id_param] -= 1
+            if self.counter_reusing_shared_weights[id_param] == 0:
+                # Clean the cache entry on the last use of the parameter
+                del self.cache_parameters[id_param]
+                del self.counter_reusing_shared_weights[id_param]
+            return ret
+
+        ret_value = value
         name_in_model = self.const_name_map.get(value, None)
         if name_in_model is not None and not self.in_process_const:
             self.in_process_const = True
-            value = self.hook_storage.execute_post_function_hooks(name_in_model.replace(".", ":"), 0, value)
+            ret_value = self.hook_storage.execute_post_function_hooks(name_in_model.replace(".", ":"), 0, value)
            self.in_process_const = False
-        return value
+
+        if self.counter_reusing_shared_weights.get(id_param):
+            if ret_value is value:
+                # Remove the counter for parameters whose hooks do not modify the parameter
+                del self.counter_reusing_shared_weights[id_param]
+            else:
+                # Save the modified parameter
+                self.cache_parameters[id_param] = ret_value
+
+        return ret_value
 
     def process_parameters(self, args: List[Any], kwargs: Dict[str, Any]) -> Tuple[List[Any], Dict[str, Any]]:
         """
```
nncf/experimental/torch2/function_hook/wrapper.py (+4 −2)

```diff
@@ -14,7 +14,7 @@
 import inspect
 import types
 from types import MethodType
-from typing import Any, Callable, Dict, Tuple, cast
+from typing import Any, Callable, Dict, Tuple, TypeVar, cast
 
 from torch import nn
 
@@ -25,6 +25,8 @@
 
 ATR_HOOK_STORAGE = "__nncf_hooks"
 
+TModel = TypeVar("TModel", bound=nn.Module)
+
 
 class ForwardWithHooks:
     """Class to wrap forward function of nn.Module, to forward function of the model with enabled FunctionHookMode"""
@@ -149,7 +151,7 @@ def func(self) -> MethodType:
         return cast(MethodType, self._func)
 
 
-def wrap_model(model: nn.Module) -> nn.Module:
+def wrap_model(model: TModel) -> TModel:
     """
     Wraps a nn.Module to inject custom behavior into the forward pass and replication process.
```
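The `TypeVar` change is mainly about static typing: `wrap_model` now returns the same concrete type it receives instead of widening it to `nn.Module`. A hedged sketch, with `MyModel` as an illustrative class:

```python
# Illustrative only: with wrap_model(model: TModel) -> TModel, a type checker infers
# the concrete model class for the return value rather than plain nn.Module.
import torch
from torch import nn

from nncf.experimental.torch2.function_hook.wrapper import wrap_model


class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


wrapped = wrap_model(MyModel())
# Attribute access on the wrapped model now type-checks without a cast,
# because the inferred type of `wrapped` is MyModel, not nn.Module.
_ = wrapped.linear
```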
model_graph_with_shared_parameters.dot (+17; new reference graph, referenced in the tests as REF_DIR / "model_graph_with_shared_parameters.dot")

```diff
@@ -0,0 +1,17 @@
+strict digraph {
+x [id=0, type="nncf_model_input", metatype=PTInputNoopMetatype];
+"module1.0.weight" [id=1, type="nncf_model_const", metatype=PTConstNoopMetatype];
+"module1/0/post_hook__module1:0:weight__0[0]/add/0" [id=2, type=add, metatype=PTAddMetatype];
+"module1/0/linear/0" [id=3, type=linear, metatype=PTLinearMetatype];
+"module2/0/linear/0" [id=4, type=linear, metatype=PTLinearMetatype];
+"/add/0" [id=5, type=add, metatype=PTAddMetatype];
+output [id=6, type="nncf_model_output", metatype=PTOutputNoopMetatype];
+x -> "module1/0/linear/0" [dtype=float, shape="(1, 3)", out_port_id=0, in_port_id=0];
+x -> "module2/0/linear/0" [dtype=float, shape="(1, 3)", out_port_id=0, in_port_id=0];
+"module1.0.weight" -> "module1/0/post_hook__module1:0:weight__0[0]/add/0" [dtype=float, shape="(1, 3)", out_port_id=0, in_port_id=0];
+"module1/0/post_hook__module1:0:weight__0[0]/add/0" -> "module1/0/linear/0" [dtype=float, shape="(1, 3)", out_port_id=0, in_port_id=1];
+"module1/0/post_hook__module1:0:weight__0[0]/add/0" -> "module2/0/linear/0" [dtype=float, shape="(1, 3)", out_port_id=0, in_port_id=1];
+"module1/0/linear/0" -> "/add/0" [dtype=float, shape="(1, 1)", out_port_id=0, in_port_id=0];
+"module2/0/linear/0" -> "/add/0" [dtype=float, shape="(1, 1)", out_port_id=0, in_port_id=1];
+"/add/0" -> output [dtype=float, shape="(1, 1)", out_port_id=0, in_port_id=0];
+}
```
tests/torch2/function_hook/helpers.py (+26)

```diff
@@ -136,3 +136,29 @@ def forward(self, x: torch.Tensor):
         x = self.conv(x)
         x = torch.relu(x)
         return x
+
+
+class SharedParamModel(nn.Module):
+
+    @staticmethod
+    def get_example_inputs():
+        return torch.ones([1, 3])
+
+    def __init__(self):
+        super().__init__()
+        shared_linear = nn.Linear(3, 1, bias=False)
+        self.module1 = nn.Sequential(shared_linear)
+        self.module2 = nn.Sequential(shared_linear)
+
+    def forward(self, x):
+        return self.module1(x) + self.module2(x)
+
+
+class CounterHook(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.counter = 0
+
+    def forward(self, x):
+        self.counter += 1
+        return x + 1
```
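A possible way to combine these helpers with the wrapper API from this PR, mirroring the tests below; assuming the cache behaves as described above, one forward pass should trigger the post-hook on the shared weight exactly once.

```python
# Usage sketch for the helpers above; the expected counter value mirrors the
# assertion in test_shared_parameters further down.
from nncf.experimental.torch2.function_hook.wrapper import register_post_function_hook
from nncf.experimental.torch2.function_hook.wrapper import wrap_model
from tests.torch2.function_hook.helpers import CounterHook, SharedParamModel

model = wrap_model(SharedParamModel())
hook = CounterHook()
register_post_function_hook(model, "module1:0:weight", 0, hook)

output = model(SharedParamModel.get_example_inputs())
# The shared weight feeds both module1 and module2, but with the cache the
# post-hook is expected to run once per forward pass.
assert hook.counter == 1
```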
tests/torch2/function_hook/nncf_graph/test_nncf_graph.py (+11)

```diff
@@ -29,6 +29,7 @@
 from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import get_dtype
 from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import get_name_of_node
 from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import get_node_type
+from nncf.experimental.torch2.function_hook.wrapper import register_post_function_hook
 from nncf.experimental.torch2.function_hook.wrapper import wrap_model
 from tests.cross_fw.shared.paths import TEST_ROOT
 from tests.torch2.function_hook import helpers
@@ -144,3 +145,13 @@ def test_model_graph(desc: ModelDesc, regen_ref_data: bool):
     nx_nncf_graph = nx.nx_pydot.to_pydot(graph)
     ref_file = REF_DIR / f"model_graph_{desc}.dot"
     compare_with_reference_file(str(nx_nncf_graph), ref_file, regen_ref_data)
+
+
+def test_model_graph_with_shared_parameters(regen_ref_data):
+    model = wrap_model(helpers.SharedParamModel())
+    register_post_function_hook(model, "module1:0:weight", 0, helpers.CounterHook())
+    nncf_graph = build_nncf_graph(model, model.get_example_inputs())
+    graph = to_comparable_nx_graph(nncf_graph)
+    nx_nncf_graph = nx.nx_pydot.to_pydot(graph)
+    ref_file = REF_DIR / "model_graph_with_shared_parameters.dot"
+    compare_with_reference_file(str(nx_nncf_graph), ref_file, regen_ref_data)
```
tests/torch2/function_hook/test_function_hook_mode.py (+25)

```diff
@@ -27,6 +27,8 @@
 from nncf.experimental.torch2.function_hook.wrapper import wrap_model
 from tests.torch2.function_hook import helpers
 from tests.torch2.function_hook.helpers import CallCount
+from tests.torch2.function_hook.helpers import CounterHook
+from tests.torch2.function_hook.helpers import SharedParamModel
 
 
 @dataclass
@@ -139,3 +141,26 @@ def test_execute_pre_hooks_for_concat():
     register_pre_function_hook(model, op_name, 1, AddModule(2))
     ret_val = model(torch.zeros(2))
     assert torch.allclose(ret_val, torch.tensor([1.0, 1.0, 2.0, 2.0])), ret_val
+
+
+def test_shared_parameters():
+    model = SharedParamModel()
+    hook_storage = HookStorage()
+    hook = CounterHook()
+    hook_storage.register_post_function_hook("module1:0:weight", 0, hook)
+
+    args = (model.get_example_inputs(),)
+    kwargs = {}
+    with FunctionHookMode(model, hook_storage) as ctx:
+        assert hook.counter == 0
+        assert ctx.cache_parameters == {}
+        assert ctx.counter_reusing_shared_weights == {id(model.module1[0].weight): 1}
+
+        args, kwargs = ctx.process_model_inputs(args, kwargs)
+        outputs = model.forward(*args, **kwargs)
+        outputs = ctx.process_model_outputs(outputs)
+
+        assert hook.counter == 1
+        # Check that the cache is cleared in the end
+        assert ctx.cache_parameters == {}
+        assert ctx.counter_reusing_shared_weights == {}
```
tests/torch2/utils.py (+11 −2)

```diff
@@ -43,6 +43,15 @@ def compare_with_reference_file(text_data: str, ref_path: Path, regen_ref_data:
     )
 
 
+def _quote_str(s: str) -> str:
+    """
+    Add quotes to a string if it contains a colon.
+    """
+    if ":" in s:
+        return f'"{s}"'
+    return s
+
+
 def to_comparable_nx_graph(graph: NNCFGraph) -> nx.DiGraph:
     """
     Convert NNCFGraph to nx.DiGraph for comparison with references.
@@ -70,7 +79,7 @@ def to_comparable_nx_graph(graph: NNCFGraph) -> nx.DiGraph:
             "type": node.node_type,
             "metatype": node.metatype.__name__,
         }
-        out_graph.add_node(node.node_name, **attrs_node)
+        out_graph.add_node(_quote_str(node.node_name), **attrs_node)
 
     for edge in graph.get_all_edges():
         attrs_edge = {
@@ -82,5 +91,5 @@
         if edge.parallel_input_port_ids:
             attrs_edge["parallel_input_port_ids"] = edge.parallel_input_port_ids
 
-        out_graph.add_edge(edge.from_node.node_name, edge.to_node.node_name, **attrs_edge)
+        out_graph.add_edge(_quote_str(edge.from_node.node_name), _quote_str(edge.to_node.node_name), **attrs_edge)
     return out_graph
```
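The quoting helper matters because, in the DOT language, an unquoted colon inside a node ID is parsed as a node/port separator, so hook node names such as `module1/0/post_hook__module1:0:weight__0[0]/add/0` would break the pydot dump. A small hedged sketch (the graph contents are illustrative):

```python
# Illustrative: node names containing ":" are pre-quoted so pydot emits them as
# single DOT identifiers instead of node:port references.
import networkx as nx


def _quote_str(s: str) -> str:
    """Add quotes to a string if it contains a colon (same idea as in tests/torch2/utils.py)."""
    return f'"{s}"' if ":" in s else s


graph = nx.DiGraph()
hook_node = "module1/0/post_hook__module1:0:weight__0[0]/add/0"
graph.add_node(_quote_str(hook_node), type="add")
graph.add_edge(_quote_str(hook_node), "output")

# Requires pydot; the already-quoted IDs pass through to the DOT text unchanged.
dot = nx.nx_pydot.to_pydot(graph)
print(dot.to_string())
```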
