
Commit 197601e

ErezYosef and janeyx99 authored and committed
Add Support for Tracking Parameter Names (named_parameters) in Optimizer State Dict (pytorch#134107)
A proposal addressing Issue pytorch#1489: **Optimizer should track parameter names and not id.** (also mentioned in here: [[RFC] Introducing FQNs/clarity eyeglasses to optim state_dict](https://dev-discuss.pytorch.org/t/rfc-introducing-fqns-clarity-to-optim-state-dict/1552) ## Summary This PR introduces a backward-compatible enhancement where optimizers track parameter names instead of just their id. Optimizers can be initialized with `named_parameters()` as: ```python optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9) ``` This allows for greater clarity and ease when handling optimizers, as the parameters' names are preserved within the optimizer’s `state_dict` as: ``` state_dict = { 'state': { 0: {'momentum_buffer': tensor(...), ...}, 1: {'momentum_buffer': tensor(...), ...}, }, 'param_groups': [ { 'lr': 0.01, 'weight_decay': 0, ... 'params': [0,1] 'param_names' ['layer.weight', 'layer.bias'] (optional) } ] } ``` Loading `state_dict` is not changed (backward-compatible) and the `param_names` key will be ignored. ## Key Features #### Named Parameters in Optimizer Initialization: Optimizers can accept the output of `model.named_parameters()` during initialization, allowing them to store parameter names directly. #### Parameter Names in `state_dict`: The parameter names are saved as a list in the optimizer’s `state_dict` with key `param_names`, alongside the `params` indices, ensuring seamless tracking of both names and parameters. ## Backward Compatibility #### No Breaking Changes: This change is fully backward-compatible. The added `param_names` key in the optimizer's `state_dict` is ignored when loading a state to the optimizer. #### Customization with Hooks: For more control, the loaded state_dict can be modified using a custom `register_load_state_dict_pre_hook`, providing flexibility for different design needs. ## Documentation Updates Please refer to the documentation changes for more details on how this feature is implemented and how it can be used effectively. ## Solution Example: A suggested solution to the problem mentioned in pytorch#1489, for the same parameters but in a different order. The following `register_load_state_dict_pre_hook` should be added to the optimizer before loading to enable loading the state dict : ```python def adapt_state_dict_ids(optimizer, state_dict): # assuming a single param group. current_state_group = optimizer.state_dict()['param_groups'][0] loaded_state_group = state_dict['param_groups'][0] # same number of params, same names, only different ordering current_state_name_to_id_mapping = {} # mapping -- param_name: id for i, name in enumerate(current_state_group['param_names']): current_state_name_to_id_mapping[name] = current_state_group['params'][i] # changing the ids of the loaded state dict to match the order of the given state dict. for i, name in enumerate(current_state_group['param_names']): loaded_state_group['params'][i] = current_state_name_to_id_mapping[name] return state_dict ``` In this code, the loaded `state_dict` ids are adapted to match the order of the current optimizer `state_dict`. Both the previous and the current optimizers are required to be initiated with `named_parameters()` to have the 'param_names' key in the dict. ### Note This is my first contribution to PyTorch, and I wish to receive feedback or suggestions for improvement. Pull Request resolved: pytorch#134107 Approved by: https://github.com/janeyx99 Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
1 parent 4470339 commit 197601e

17 files changed: +395 −39 lines changed

docs/source/optim.rst

+188 −1

@@ -13,14 +13,20 @@ Constructing it
 ^^^^^^^^^^^^^^^
 
 To construct an :class:`Optimizer` you have to give it an iterable containing the
-parameters (all should be :class:`~torch.autograd.Variable` s) to optimize. Then,
+parameters (all should be :class:`~torch.nn.Parameter` s) or named parameters
+(tuples of (str, :class:`~torch.nn.Parameter`)) to optimize. Then,
 you can specify optimizer-specific options such as the learning rate, weight decay, etc.
 
 Example::
 
     optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
     optimizer = optim.Adam([var1, var2], lr=0.0001)
 
+Named parameters example::
+
+    optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
+    optimizer = optim.Adam([('layer0', var1), ('layer1', var2)], lr=0.0001)
+
 Per-parameter options
 ^^^^^^^^^^^^^^^^^^^^^
 
@@ -38,6 +44,11 @@ For example, this is very useful when one wants to specify per-layer learning ra
                 {'params': model.classifier.parameters()}
             ], lr=1e-3, momentum=0.9)
 
+    optim.SGD([
+                {'params': model.base.named_parameters(), 'lr': 1e-2},
+                {'params': model.classifier.named_parameters()}
+            ], lr=1e-3, momentum=0.9)
+
 This means that ``model.base``'s parameters will use a learning rate of ``1e-2``, whereas
 ``model.classifier``'s parameters will stick to the default learning rate of ``1e-3``.
 Finally a momentum of ``0.9`` will be used for all parameters.
@@ -303,6 +314,182 @@ algorithms.
     lr_scheduler.OneCycleLR
     lr_scheduler.CosineAnnealingWarmRestarts
 
+How to utilize named parameters to load optimizer state dict
+------------------------------------------------------------
+
+The function :func:`~Optimizer.load_state_dict` stores the optional ``param_names`` content from the
+loaded state dict if present. However, the process of loading the optimizer state is not affected,
+as the order of the parameters is what is used to maintain compatibility (in case of different ordering).
+To utilize the parameter names from the loaded state dict, a custom ``register_load_state_dict_pre_hook``
+needs to be implemented according to the desired behavior.
+
+This can be useful, for instance, when the model architecture changes, but the weights and optimizer states need to
+remain unchanged. The following example demonstrates how to implement this customization.
+
+Example::
+
+    class OneLayerModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc = nn.Linear(3, 4)
+
+        def forward(self, x):
+            return self.fc(x)
+
+    model = OneLayerModel()
+    optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
+    # training..
+    torch.save(optimizer.state_dict(), PATH)
+
+Let's say that ``model`` implements an expert (MoE), and we want to duplicate it and resume training
+for two experts, both initialized the same way as the ``fc`` layer. For the following ``model2`` we create
+two layers identical to ``fc`` and resume training by loading the model weights and optimizer states from
+``model`` into both ``fc1`` and ``fc2`` of ``model2`` (and adjust them accordingly)::
+
+    class TwoLayerModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc1 = nn.Linear(3, 4)
+            self.fc2 = nn.Linear(3, 4)
+
+        def forward(self, x):
+            return (self.fc1(x) + self.fc2(x)) / 2
+
+    model2 = TwoLayerModel()
+    # adapt and load model weights..
+    optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)
+
+To load the state dict for ``optimizer2`` with the state dict of the previous optimizer such that both
+``fc1`` and ``fc2`` will be initialized with a copy of ``fc``'s optimizer states
+(to resume training for each layer from ``fc``), we can use the following hook::
+
+    def adapt_state_dict_ids(optimizer, state_dict):
+        adapted_state_dict = deepcopy(optimizer.state_dict())
+        # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
+        for k, v in state_dict['param_groups'][0].items():
+            if k not in ['params', 'param_names']:
+                adapted_state_dict['param_groups'][0][k] = v
+
+        lookup_dict = {
+            'fc1.weight': 'fc.weight',
+            'fc1.bias': 'fc.bias',
+            'fc2.weight': 'fc.weight',
+            'fc2.bias': 'fc.bias'
+        }
+        clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
+        for param_id, param_name in zip(
+                optimizer.state_dict()['param_groups'][0]['params'],
+                optimizer.state_dict()['param_groups'][0]['param_names']):
+            name_in_loaded = lookup_dict[param_name]
+            index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
+            id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
+            # Copy the state of the corresponding parameter
+            if id_in_loaded in state_dict['state']:
+                adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])
+
+        return adapted_state_dict
+
+    optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids)
+    optimizer2.load_state_dict(torch.load(PATH))  # the state dict saved from the previous optimizer
+
+This ensures that the adapted state_dict with the correct states for the layers of ``model2`` will be used
+during model loading.
+Note that this code is designed specifically for this example (e.g., assuming a single parameter group),
+and other cases might require different adaptations.
+
+The following example shows how to handle missing parameters in a loaded
+``state_dict`` when the model structure changes.
+The ``Model_bypass`` adds a new ``bypass`` layer, which is not present in the original ``Model1``.
+To resume training, a custom ``adapt_state_dict_missing_param`` hook is used to adapt the optimizer's ``state_dict``,
+ensuring existing parameters are mapped correctly, while missing ones (like the bypass layer) remain unchanged
+(as initialized in this example).
+This approach enables smooth loading and resuming of the optimizer state despite model changes.
+The new bypass layer will be trained from scratch::
+
+    class Model1(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc = nn.Linear(5, 5)
+
+        def forward(self, x):
+            return self.fc(x) + x
+
+
+    model = Model1()
+    optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
+    # training..
+    torch.save(optimizer.state_dict(), PATH)
+
+    class Model_bypass(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc = nn.Linear(5, 5)
+            self.bypass = nn.Linear(5, 5, bias=False)
+            torch.nn.init.eye_(self.bypass.weight)
+
+        def forward(self, x):
+            return self.fc(x) + self.bypass(x)
+
+    model2 = Model_bypass()
+    optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)
+
+    def adapt_state_dict_missing_param(optimizer, state_dict):
+        adapted_state_dict = deepcopy(optimizer.state_dict())
+        # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
+        for k, v in state_dict['param_groups'][0].items():
+            if k not in ['params', 'param_names']:
+                adapted_state_dict['param_groups'][0][k] = v
+
+        lookup_dict = {
+            'fc.weight': 'fc.weight',
+            'fc.bias': 'fc.bias',
+            'bypass.weight': None,
+        }
+
+        clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
+        for param_id, param_name in zip(
+                optimizer.state_dict()['param_groups'][0]['params'],
+                optimizer.state_dict()['param_groups'][0]['param_names']):
+            name_in_loaded = lookup_dict[param_name]
+            if name_in_loaded in state_dict['param_groups'][0]['param_names']:
+                index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
+                id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
+                # Copy the state of the corresponding parameter
+                if id_in_loaded in state_dict['state']:
+                    adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])
+
+        return adapted_state_dict
+
+    optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_missing_param)
+    optimizer2.load_state_dict(torch.load(PATH))  # the state dict saved from the previous optimizer
+
+
+As a third example, instead of loading a state according to the order of parameters (the default approach),
+this hook can be used to load according to the parameters' names::
+
+    def names_matching(optimizer, state_dict):
+        assert len(state_dict['param_groups']) == len(optimizer.state_dict()['param_groups'])
+        adapted_state_dict = deepcopy(optimizer.state_dict())
+        for g_ind in range(len(state_dict['param_groups'])):
+            assert len(state_dict['param_groups'][g_ind]['params']) == len(
+                optimizer.state_dict()['param_groups'][g_ind]['params'])
+
+            for k, v in state_dict['param_groups'][g_ind].items():
+                if k not in ['params', 'param_names']:
+                    adapted_state_dict['param_groups'][g_ind][k] = v
+
+            for param_id, param_name in zip(
+                    optimizer.state_dict()['param_groups'][g_ind]['params'],
+                    optimizer.state_dict()['param_groups'][g_ind]['param_names']):
+                index_in_loaded_list = state_dict['param_groups'][g_ind]['param_names'].index(param_name)
+                id_in_loaded = state_dict['param_groups'][g_ind]['params'][index_in_loaded_list]
+                # Copy the state of the corresponding parameter
+                if id_in_loaded in state_dict['state']:
+                    adapted_state_dict['state'][param_id] = deepcopy(state_dict['state'][id_in_loaded])
+
+        return adapted_state_dict
+
+
 Weight Averaging (SWA and EMA)
 ------------------------------
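Editorial usage note for the third documentation example above (not part of the diff): the `names_matching` hook is registered like the other hooks before calling `load_state_dict`. A minimal sketch, assuming `optimizer2` was built from `model2.named_parameters()`, the hook above is in scope, and the checkpoint path is a placeholder:

```python
PATH = "optimizer_state.pt"  # placeholder path; checkpoint saved from a named-parameter optimizer

# Match loaded states to current parameters by name rather than by position.
optimizer2.register_load_state_dict_pre_hook(names_matching)
optimizer2.load_state_dict(torch.load(PATH))
```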

test/test_optim.py

+104 −9

@@ -1341,8 +1341,12 @@ def test_optimizer_can_be_printed(self, device, dtype, optim_info):
             optimizer = optim_cls(params, **optim_input.kwargs)
             optimizer.__repr__()
 
+    @parametrize("is_named_optim0", [True, False])
+    @parametrize("is_named_optim1", [True, False])
     @optims(optim_db, dtypes=[torch.float32])
-    def test_state_dict_deterministic(self, device, dtype, optim_info):
+    def test_state_dict_deterministic(
+        self, device, dtype, optim_info, is_named_optim0, is_named_optim1
+    ):
         optim_cls = optim_info.optim_cls
 
         # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
@@ -1356,6 +1360,17 @@ def test_state_dict_deterministic(self, device, dtype, optim_info):
         input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
         params = [weight, bias]
 
+        def make_named_param(param, is_named):
+            if not is_named:
+                return param
+            return [(f"name{i}", p) for i, p in enumerate(param)]
+
+        def without_param_names(state_dict):
+            new_state_dict = deepcopy(state_dict)
+            for pg in new_state_dict["param_groups"]:
+                pg.pop("param_names", None)
+            return new_state_dict
+
         def fwd_bwd(optim, w, b, i):
             optim.zero_grad()
             loss = (w.mv(i) + b).pow(2).sum()
@@ -1368,7 +1383,8 @@ def fwd_bwd(optim, w, b, i):
             return loss
 
         for optim_input in all_optim_inputs:
-            optimizer = optim_cls(params, **optim_input.kwargs)
+            params_in = make_named_param(params, is_named=is_named_optim0)
+            optimizer = optim_cls(params_in, **optim_input.kwargs)
             closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)
 
             # Prime the optimizer
@@ -1383,8 +1399,8 @@ def fwd_bwd(optim, w, b, i):
             with torch.no_grad():
                 weight_c = Parameter(weight.clone())
                 bias_c = Parameter(bias.clone())
-
-            optimizer_c = optim_cls([weight_c, bias_c], **optim_input.kwargs)
+            params_c = make_named_param([weight_c, bias_c], is_named=is_named_optim1)
+            optimizer_c = optim_cls(params_c, **optim_input.kwargs)
             closure_c = functools.partial(fwd_bwd, optimizer_c, weight_c, bias_c, input)
 
             # Load the state dict from the original optimizer into the new one
@@ -1405,13 +1421,17 @@ def fwd_bwd(optim, w, b, i):
             self.assertEqual(bias, bias_c)
 
             # Make sure state dict is deterministic with equal (not identical) parameters
-            self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
+            # Param names are optional and do not need to be consistent.
+            self.assertEqual(
+                without_param_names(optimizer.state_dict()),
+                without_param_names(optimizer_c.state_dict()),
+            )
 
             # Make sure repeated parameters have identical representation (see #36831)
             optimizer_c.param_groups.extend(optimizer_c.param_groups)
             self.assertEqual(
-                optimizer.state_dict()["param_groups"][-1],
-                optimizer_c.state_dict()["param_groups"][-1],
+                without_param_names(optimizer.state_dict())["param_groups"][-1],
+                without_param_names(optimizer_c.state_dict())["param_groups"][-1],
             )
 
     @optims(optim_db, dtypes=[torch.float32])
@@ -1462,8 +1482,77 @@ def fwd_bwd(optim, mod, i):
                 fwd_bwd(optimizer, model, input)
                 optimizer.step()
 
+    @parametrize("is_named_optim0", [True, False])
+    @parametrize("is_named_optim1", [True, False])
+    @optims(
+        [o for o in optim_db if not o.only_supports_sparse_grads],
+        dtypes=[torch.float32],
+    )
+    def test_can_load_from_to_named_state_dict(
+        self, device, dtype, optim_info, is_named_optim0, is_named_optim1
+    ):
+        optim_cls = optim_info.optim_cls
+
+        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
+        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
+            device, dtype, optim_info, skip=("differentiable",)
+        )
+        for optim_input in all_optim_inputs:
+            torch.manual_seed(1)
+            model = torch.nn.Sequential(
+                torch.nn.Conv2d(4, 2, 1, stride=2),
+                torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1),
+            )
+            model.to(dtype=dtype, device=device)
+            input = torch.rand(1, 4, 16, 16, device=device, dtype=dtype)
+
+            def fwd_bwd(optim, mod, i):
+                optim.zero_grad()
+                loss = mod(i).sum()
+                loss.backward()
+                return loss
+
+            # test for parameters, named_parameters, and 2 groups:
+            params_to_optimizer = (
+                model.named_parameters() if is_named_optim0 else model.parameters()
+            )
+            optimizer = optim_cls(params_to_optimizer, **optim_input.kwargs)
+
+            for _ in range(3):
+                if optim_info.step_requires_closure:
+                    optimizer.step(functools.partial(fwd_bwd, optimizer, model, input))
+                else:
+                    fwd_bwd(optimizer, model, input)
+                    optimizer.step()
+
+            # old_state_dict has all new flags del'd
+            old_state_dict = deepcopy(optimizer.state_dict())
+
+            params_to_optimizer2 = (
+                model.named_parameters() if is_named_optim1 else model.parameters()
+            )
+            optimizer2 = optim_cls(params_to_optimizer2, **optim_input.kwargs)
+            optimizer2.load_state_dict(old_state_dict)
+
+            # Make sure we can still step
+            if optim_info.step_requires_closure:
+                optimizer2.step(functools.partial(fwd_bwd, optimizer2, model, input))
+            else:
+                fwd_bwd(optimizer2, model, input)
+                optimizer2.step()
+
+            # Make sure that param_names are preserved when provided to at least one of the optimizers
+            if is_named_optim0 or is_named_optim1:
+                self.assertEqual(
                    optimizer2.state_dict()["param_groups"][0]["param_names"],
+                    ["0.weight", "0.bias", "1.weight", "1.bias"],
+                )
+
+    @parametrize("is_named_optim", [True, False])
     @optims(optim_db, dtypes=[torch.float32])
-    def test_save_load_equality_with_weights_only(self, device, dtype, optim_info):
+    def test_save_load_equality_with_weights_only(
+        self, device, dtype, optim_info, is_named_optim
+    ):
         optim_cls = optim_info.optim_cls
 
         # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
@@ -1477,6 +1566,11 @@ def test_save_load_equality_with_weights_only(self, device, dtype, optim_info):
         input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
         params = [weight, bias]
 
+        def make_named_param(param, is_named):
+            if not is_named:
+                return param
+            return [(f"name{i}", p) for i, p in enumerate(param)]
+
        def fwd_bwd(optim, w, b, i):
             optim.zero_grad()
             loss = (w.mv(i) + b).pow(2).sum()
@@ -1487,7 +1581,8 @@ def fwd_bwd(optim, w, b, i):
             return loss
 
         for optim_input in all_optim_inputs:
-            optimizer = optim_cls(params, **optim_input.kwargs)
+            params_in = make_named_param(params, is_named=is_named_optim)
+            optimizer = optim_cls(params_in, **optim_input.kwargs)
             closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)
 
             # Prime the optimizer
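Editor's illustration (not part of the commit): a condensed sketch of the behavior these tests exercise. The layer names follow `nn.Sequential`'s default numbering, and the `param_names` entry appears only because the first optimizer was built from `named_parameters()`:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(nn.Linear(4, 2), nn.ReLU(), nn.Linear(2, 1))

# Constructing from named_parameters() records parameter names in the state dict.
named_opt = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
sd = named_opt.state_dict()
print(sd["param_groups"][0]["param_names"])  # ['0.weight', '0.bias', '2.weight', '2.bias']

# Loading is unchanged: a plain (unnamed) optimizer accepts the named state dict,
# and the 'param_names' entry is not used to match parameters.
plain_opt = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
plain_opt.load_state_dict(sd)
```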
