@@ -93,12 +93,13 @@ class FunctionHookMode(TorchFunctionMode):
     This mode wraps the function calls in the model to allow custom hooks to be executed before
     and after the actual function calls.
 
-
     :param model: The PyTorch model to which the hooks will be applied.
     :param hook_storage: Storage for hooks to be executed.
     :param module_call_stack: A stack tracking the modules being called.
     :param nested_enter_count: A counter to track nested context manager entries.
     :param op_calls: A dictionary to track operation calls.
+    :param counter_reusing_shared_weights: A dictionary tracking how many more times each shared weight will be reused.
+    :param cache_parameters: A dictionary to cache modified shared parameters.
     """
 
     def __init__(self, model: nn.Module, hook_storage: HookStorage) -> None:
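For readers unfamiliar with torch function modes, here is a minimal sketch of the generic `TorchFunctionMode` mechanism that `FunctionHookMode` builds on. `LoggingMode` and its plain hook lists are made-up stand-ins for illustration only, not this repository's `HookStorage` API.

```python
import torch
from torch.overrides import TorchFunctionMode


class LoggingMode(TorchFunctionMode):
    """Runs user callables before and after every intercepted torch call."""

    def __init__(self, pre_hooks=None, post_hooks=None):
        super().__init__()
        self.pre_hooks = pre_hooks or []
        self.post_hooks = post_hooks or []

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        for hook in self.pre_hooks:
            hook(func, args, kwargs)      # inspect (or modify) the inputs
        output = func(*args, **kwargs)    # the actual function call
        for hook in self.post_hooks:
            output = hook(func, output)   # inspect (or replace) the output
        return output


# Every torch call inside the context is routed through the mode.
with LoggingMode(pre_hooks=[lambda func, args, kwargs: print("pre:", func)]):
    torch.add(torch.ones(2), torch.ones(2))
```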
@@ -127,6 +128,14 @@ def __init__(self, model: nn.Module, hook_storage: HookStorage) -> None:
         self._get_named_hooks(self.hook_storage.pre_hooks, "pre_hook")
         self._get_named_hooks(self.hook_storage.post_hooks, "post_hook")
 
+        # Count how many times each shared parameter is used
+        counter_shared_weights: Dict[int, int] = defaultdict(int)
+        for name, parameter in chain(self.model.named_parameters(remove_duplicate=False)):
+            counter_shared_weights[id(parameter)] += 1
+
+        self.counter_reusing_shared_weights = {k: v - 1 for k, v in counter_shared_weights.items() if v > 1}
+        self.cache_parameters: Dict[int, Tensor] = {}
+
     def _get_named_hooks(self, storage: nn.ModuleDict, prefix: str) -> None:
         """
         Associates named hooks from the given module storage with a group name, updating
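For context, a self-contained sketch of the counting idea used in the added `__init__` code: `named_parameters(remove_duplicate=False)` yields a tied weight once per use, so counting occurrences by `id()` reveals how many extra times each shared parameter will be encountered. The toy tied-weight model below is made up for illustration.

```python
from collections import defaultdict

import torch.nn as nn

embedding = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
head.weight = embedding.weight  # tie the two weights

model = nn.Sequential(embedding, head)

counter = defaultdict(int)
for _, parameter in model.named_parameters(remove_duplicate=False):
    counter[id(parameter)] += 1

# Parameters seen more than once are shared; keep how many *extra* times
# each one will be encountered (total uses minus one).
extra_uses = {k: v - 1 for k, v in counter.items() if v > 1}
print(extra_uses)  # one entry for the tied weight, with value 1
```

Keeping only the "extra uses" count means parameters that occur once never enter the bookkeeping at all.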
@@ -306,18 +315,41 @@ def execute_hooks_for_parameter(self, value: torch.Tensor) -> torch.Tensor:
         Executes post-hooks for a model parameter if a hook is defined for it.
         If the input is not a `torch.nn.Parameter`, or if no hook is defined, the original tensor is returned unchanged.
 
+        For shared parameters that are used more than once, the function caches the modified parameter,
+        so the hooks do not have to be re-applied on each subsequent use.
+
         :param value: The tensor to which the post-hook will be applied.
         :return: The processed tensor with the applied post-hook, if applicable.
         """
         if not isinstance(value, torch.nn.Parameter):
             return value
 
+        id_param = id(value)
+        if id_param in self.cache_parameters:
+            ret = self.cache_parameters[id_param]
+            self.counter_reusing_shared_weights[id_param] -= 1
+            if self.counter_reusing_shared_weights[id_param] == 0:
+                # Drop the cache entry after the parameter's last use
+                del self.cache_parameters[id_param]
+                del self.counter_reusing_shared_weights[id_param]
+            return ret
+
+        ret_value = value
         name_in_model = self.const_name_map.get(value, None)
         if name_in_model is not None and not self.in_process_const:
             self.in_process_const = True
-            value = self.hook_storage.execute_post_function_hooks(name_in_model.replace(".", ":"), 0, value)
+            ret_value = self.hook_storage.execute_post_function_hooks(name_in_model.replace(".", ":"), 0, value)
             self.in_process_const = False
-        return value
+
+        if self.counter_reusing_shared_weights.get(id_param):
+            if ret_value is value:
+                # Drop the counter for parameters that the hooks leave unchanged
+                del self.counter_reusing_shared_weights[id_param]
+            else:
+                # Cache the modified parameter for the remaining uses
+                self.cache_parameters[id_param] = ret_value
+
+        return ret_value
 
     def process_parameters(self, args: List[Any], kwargs: Dict[str, Any]) -> Tuple[List[Any], Dict[str, Any]]:
         """