
Commit 1b3b3d5

Merge remote-tracking branch 'upstream/main' into longjie/generalize_parallelization_strategy
2 parents: 44a87f4 + 23f8574

File tree: 2 files changed, +16 −5 lines


optimum/fx/parallelization/core.py (+6)

@@ -125,6 +125,11 @@ class ParallelExecutionCtx:
             because we have to make sure we don't initiate new parameters and replace original ones when
             recompilation happens in training process.
 
+        - param_cache (`Dict[str, nn.Parameter]`):
+            Cache which keeps record of newly created parameters. Similar to `parallel_layer_cache`, we
+            need to make sure all the newly created parameters in the first compilation will still be used
+            when recompilation happens.
+
         - weight_map (`Dict[str, str]`):
             Mapping between parameter names and their locations on disk, useful when loading weights
             from disk.
@@ -140,6 +145,7 @@ class ParallelExecutionCtx:
     current_device: torch.device
     example_inputs: List[Any] = field(default_factory=list)
     parallel_layer_cache: Dict[str, nn.Module] = field(default_factory=dict)
+    param_cache: Dict[str, nn.Parameter] = field(default_factory=dict)
     weight_map: Dict[str, str] = field(default_factory=dict)
     last_optimized_graph_module: Optional[GraphModule] = None
     compile_times: int = 0
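
The new field behaves like `parallel_layer_cache`: it defaults to an empty dict and lives on the context object that is shared across compilations, so parameters stored during the first compilation remain visible on every later recompilation. Below is a minimal illustrative sketch of that idea; the reduced `Ctx` class and the "embed.weight" name are stand-ins, not the real `ParallelExecutionCtx`.

from dataclasses import dataclass, field
from typing import Dict

import torch
import torch.nn as nn


@dataclass
class Ctx:  # reduced stand-in for ParallelExecutionCtx, keeping only the caching fields
    parallel_layer_cache: Dict[str, nn.Module] = field(default_factory=dict)
    param_cache: Dict[str, nn.Parameter] = field(default_factory=dict)


ctx = Ctx()  # first compilation starts with an empty cache
ctx.param_cache["embed.weight"] = nn.Parameter(torch.empty(10, 4))

# The same ctx object is reused on recompilation, so the parameter created above
# is looked up in the cache instead of being allocated again.
assert "embed.weight" in ctx.param_cache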

optimum/fx/parallelization/passes.py (+10 −5)

@@ -392,18 +392,21 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
 
 class InitializeOrLoadWeightsPass(PassBase):
     """
-    Make weights loading/initialization a seperate pass for cleaner logic and easier extensibility. This
-    pass will only run once in the very first compilation step.
+    Weights loading and intialization pass, will initialize parameters on current rank and load weights from disk
+    if necessary.
     """
 
-    need_rerun_when_recompile = False
-
     def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule:
         world_size = dist.get_world_size(ctx.tp_group)
         tp_rank = dist.get_rank(ctx.tp_group)
 
-        new_parameters, tied_parameters = [], {}
+        new_parameters, tied_parameters, param_cache = [], {}, ctx.param_cache
         for name, param in sorted(graph_module.named_parameters(remove_duplicate=False)):
+            # skip initializing new params when recompilation happens
+            if name in param_cache:
+                new_parameters.append((name, param_cache[name]))
+                continue
+
             param_meta: ParameterMeta = getattr(param, "meta")
             # skip already initialized/loaded tied parameters
             if param_meta.is_tied and id(param) in tied_parameters:
@@ -481,6 +484,8 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
             else:
                 parent_mod = graph_module
                 field = name
+            if name not in param_cache:
+                param_cache[name] = new_param
             setattr(parent_mod, field, new_param)
 
         return graph_module
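
With the cache in place, the `need_rerun_when_recompile = False` flag is no longer needed: the pass can run on every compilation and simply short-circuit for parameters it has already produced, returning the exact same objects so existing references stay valid. The sketch below illustrates that control flow under the assumption of a persistent `param_cache` dict; the `initialize_or_load` helper is illustrative and elides the real sharding and weight-loading logic of the pass.

from typing import Dict, List, Tuple

import torch.nn as nn


def initialize_or_load(
    named_params: List[Tuple[str, nn.Parameter]],
    param_cache: Dict[str, nn.Parameter],
) -> List[Tuple[str, nn.Parameter]]:
    new_parameters: List[Tuple[str, nn.Parameter]] = []
    for name, param in named_params:
        # Recompilation: reuse the parameter created during the first compilation.
        if name in param_cache:
            new_parameters.append((name, param_cache[name]))
            continue
        # First compilation: build the replacement (real sharding/loading elided).
        new_parameters.append((name, nn.Parameter(param.detach().clone())))
    # Remember newly created parameters so later compilations return the same objects.
    for name, new_param in new_parameters:
        if name not in param_cache:
            param_cache[name] = new_param
    return new_parameters


cache: Dict[str, nn.Parameter] = {}
layer = nn.Linear(2, 2)
first = initialize_or_load(list(layer.named_parameters()), cache)
second = initialize_or_load(list(layer.named_parameters()), cache)
assert first[0][1] is second[0][1]  # same object across "compilations"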
