@@ -480,18 +480,21 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
480
480
481
481
class InitializeOrLoadWeightsPass (PassBase ):
482
482
"""
483
- Make weights loading/initialization a seperate pass for cleaner logic and easier extensibility. This
484
- pass will only run once in the very first compilation step .
483
+ Weights loading and initialization pass, will initialize parameters on current rank and load weights from disk
484
+ if necessary .
485
485
"""
486
486
487
- need_rerun_when_recompile = False
488
-
489
487
def run (self , graph_module : GraphModule , ctx : ParallelExecutionCtx , config : Config ) -> GraphModule :
490
488
world_size = dist .get_world_size (ctx .tp_group )
491
489
tp_rank = dist .get_rank (ctx .tp_group )
492
490
493
- new_parameters , tied_parameters = [], {}
491
+ new_parameters , tied_parameters , param_cache = [], {}, ctx . param_cache
494
492
for name , param in sorted (graph_module .named_parameters (remove_duplicate = False )):
493
+ # skip initializing new params when recompilation happens
494
+ if name in param_cache :
495
+ new_parameters .append ((name , param_cache [name ]))
496
+ continue
497
+
495
498
param_meta : ParameterMeta = getattr (param , "meta" )
496
499
# skip already initialized/loaded tied parameters
497
500
if param_meta .is_tied and id (param ) in tied_parameters :
@@ -569,6 +572,8 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
569
572
else :
570
573
parent_mod = graph_module
571
574
field = name
575
+ if name not in param_cache :
576
+ param_cache [name ] = new_param
572
577
setattr (parent_mod , field , new_param )
573
578
574
579
return graph_module
0 commit comments