Commit 2410759

add param cache
1 parent 70bd672 commit 2410759

File tree: 2 files changed (+16, -5 lines changed)

optimum/fx/parallelization/core.py (+6)

@@ -125,6 +125,11 @@ class ParallelExecutionCtx:
         because we have to make sure we don't initiate new parameters and replace original ones when
         recompilation happens in training process.

+    - param_cache (`Dict[str, nn.Parameter]`):
+        Cache which keeps record of newly created parameters. Similar to `parallel_layer_cache`, we
+        need to make sure all the newly created parameters in the first compilation will still be used
+        when recompilation happens.
+
     - weight_map (`Dict[str, str]`):
         Mapping between parameter names and their locations on disk, useful when loading weights
         from disk.
@@ -140,6 +145,7 @@ class ParallelExecutionCtx:
     current_device: torch.device
     example_inputs: List[Any] = field(default_factory=list)
     parallel_layer_cache: Dict[str, nn.Module] = field(default_factory=dict)
+    param_cache: Dict[str, nn.Parameter] = field(default_factory=dict)
     weight_map: Dict[str, str] = field(default_factory=dict)
     last_optimized_graph_module: Optional[GraphModule] = None
     compile_times: int = 0
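
The new field mirrors `parallel_layer_cache`: the same `ParallelExecutionCtx` instance is handed to every compilation, so anything stored in `param_cache` outlives recompilations that happen during training. A minimal sketch of that contract, assuming a simplified stand-in context (the real dataclass carries many more fields):

# Sketch only: MiniCtx is a hypothetical stand-in for ParallelExecutionCtx.
from dataclasses import dataclass, field
from typing import Callable, Dict

import torch.nn as nn


@dataclass
class MiniCtx:
    # survives across compilations because the same ctx object is reused every time
    param_cache: Dict[str, nn.Parameter] = field(default_factory=dict)

    def get_or_create(self, name: str, make: Callable[[], nn.Parameter]) -> nn.Parameter:
        # the first compilation creates the parameter, later recompilations reuse it
        if name not in self.param_cache:
            self.param_cache[name] = make()
        return self.param_cache[name]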

optimum/fx/parallelization/passes.py (+10, -5)

@@ -480,18 +480,21 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule:

 class InitializeOrLoadWeightsPass(PassBase):
     """
-    Make weights loading/initialization a seperate pass for cleaner logic and easier extensibility. This
-    pass will only run once in the very first compilation step.
+    Weights loading and initialization pass, which initializes parameters on the current rank and loads weights from disk
+    if necessary.
     """

-    need_rerun_when_recompile = False
-
     def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule:
         world_size = dist.get_world_size(ctx.tp_group)
         tp_rank = dist.get_rank(ctx.tp_group)

-        new_parameters, tied_parameters = [], {}
+        new_parameters, tied_parameters, param_cache = [], {}, ctx.param_cache
         for name, param in sorted(graph_module.named_parameters(remove_duplicate=False)):
+            # skip initializing new params when recompilation happens
+            if name in param_cache:
+                new_parameters.append((name, param_cache[name]))
+                continue
+
             param_meta: ParameterMeta = getattr(param, "meta")
             # skip already initialized/loaded tied parameters
             if param_meta.is_tied and id(param) in tied_parameters:
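
The check at the top of the parameter loop is what makes this pass safe to re-run: on recompilation, every name already present in `ctx.param_cache` is taken from the cache instead of being re-initialized. A standalone sketch of that reuse-or-create pattern follows; the helper is hypothetical, not the real pass, which additionally shards parameters across `tp_rank`/`world_size`, handles tied weights, and loads weights from disk.

# Hypothetical sketch of the reuse-or-create loop; the real pass also shards
# parameters across ranks and loads weights from disk when available.
import torch.nn as nn


def collect_parameters(module: nn.Module, param_cache: dict) -> list:
    new_parameters = []
    for name, param in sorted(module.named_parameters(remove_duplicate=False)):
        # skip initializing new params when recompilation happens
        if name in param_cache:
            new_parameters.append((name, param_cache[name]))
            continue
        # first compilation: build the replacement parameter and remember it
        new_param = nn.Parameter(param.detach().clone())
        param_cache[name] = new_param
        new_parameters.append((name, new_param))
    return new_parameters

The second hunk below is where the actual pass fills the cache, right before the new parameter is installed on its parent module.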

@@ -569,6 +572,8 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule:
             else:
                 parent_mod = graph_module
                 field = name
+            if name not in param_cache:
+                param_cache[name] = new_param
             setattr(parent_mod, field, new_param)

         return graph_module
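
Filling the cache right before `setattr` means a second run of the pass installs the very same `nn.Parameter` objects rather than fresh copies, so anything holding onto them during training (an optimizer, for instance) presumably keeps pointing at the same tensors. A toy, self-contained check of that invariant, not the real pass:

# Toy check of the reuse invariant: running the initialization twice with the
# same cache must hand back identical Parameter objects.
import torch.nn as nn


def init_once(model: nn.Module, cache: dict) -> dict:
    out = {}
    for name, param in model.named_parameters():
        if name not in cache:  # only the first "compilation" creates new parameters
            cache[name] = nn.Parameter(param.detach().clone())
        out[name] = cache[name]
    return out


model = nn.Linear(4, 4)
cache: dict = {}
first, second = init_once(model, cache), init_once(model, cache)
assert all(first[k] is second[k] for k in first)  # same objects reused on recompilation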

0 commit comments
