@@ -392,18 +392,21 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
392
392
393
393
class InitializeOrLoadWeightsPass (PassBase ):
394
394
"""
395
- Make weights loading/initialization a seperate pass for cleaner logic and easier extensibility. This
396
- pass will only run once in the very first compilation step .
395
+ Weights loading and intialization pass, will initialize parameters on current rank and load weights from disk
396
+ if necessary .
397
397
"""
398
398
399
- need_rerun_when_recompile = False
400
-
401
399
def run (self , graph_module : GraphModule , ctx : ParallelExecutionCtx , config : Config ) -> GraphModule :
402
400
world_size = dist .get_world_size (ctx .tp_group )
403
401
tp_rank = dist .get_rank (ctx .tp_group )
404
402
405
- new_parameters , tied_parameters = [], {}
403
+ new_parameters , tied_parameters , param_cache = [], {}, ctx . param_cache
406
404
for name , param in sorted (graph_module .named_parameters (remove_duplicate = False )):
405
+ # skip initializing new params when recompilation happens
406
+ if name in param_cache :
407
+ new_parameters .append ((name , param_cache [name ]))
408
+ continue
409
+
407
410
param_meta : ParameterMeta = getattr (param , "meta" )
408
411
# skip already initialized/loaded tied parameters
409
412
if param_meta .is_tied and id (param ) in tied_parameters :
@@ -481,6 +484,8 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
481
484
else :
482
485
parent_mod = graph_module
483
486
field = name
487
+ if name not in param_cache :
488
+ param_cache [name ] = new_param
484
489
setattr (parent_mod , field , new_param )
485
490
486
491
return graph_module
0 commit comments