
Commit 4b39033

mikaylagawarecki authored and pytorchmergebot committed
Add assign argument to torch.Tensor.module_load (pytorch#121158)
Make `torch.__future__.get_swap_module_params_on_conversion() == True` account for the `assign` argument to `nn.Module.load_state_dict`.

As in the case where `torch.__future__.get_swap_module_params_on_conversion()` is `False`, `assign=True` means that we do not incur a `self.copy_(other)`, so the properties of `other` are preserved.

Pull Request resolved: pytorch#121158
Approved by: https://github.com/albanD
ghstack dependencies: pytorch#121157
1 parent 27389e0 commit 4b39033

File tree

4 files changed: +74 −50 lines changed

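As context for the per-file diffs below, here is a minimal sketch (not part of the commit) of the behavior the commit message describes, using dtype as a stand-in for "the properties of ``other``":

import torch

# Opt in to swapping module params on load (the torch.__future__ flag this PR extends).
torch.__future__.set_swap_module_params_on_conversion(True)

m = torch.nn.Linear(3, 3, dtype=torch.float32)
sd = {k: v.to(torch.float64) for k, v in m.state_dict().items()}

# assign=True: no self.copy_(other); the loaded params keep the state dict's dtype.
m.load_state_dict(sd, assign=True)
print(m.weight.dtype)   # torch.float64

# assign=False: values are copied into the existing float32 parameters.
m2 = torch.nn.Linear(3, 3, dtype=torch.float32)
m2.load_state_dict(sd, assign=False)
print(m2.weight.dtype)  # torch.float32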

test/nn/test_load_state_dict.py

+41 −25
@@ -261,6 +261,7 @@ def __init__(self):
             def forward(self, input):
                 return self.x + self.bn(self.fc1(input))
 
+        swap = torch.__future__.get_swap_module_params_on_conversion()
         net = MyModule()
         state_dict = net.state_dict(keep_vars=keep_vars)
         for v in state_dict.values():
@@ -276,16 +277,21 @@ def forward(self, input):
         net_meta_state_dict = net_meta.state_dict(keep_vars=True)
         for key in state_dict.keys():
             if key in net_meta._parameters:
-                self.assertEqual(net_meta_state_dict_old[key].requires_grad, net_meta_state_dict[key].requires_grad)
-                if keep_vars:
+                if keep_vars and not swap:
                     # state_dict[key] is an nn.Parameter
                     self.assertTrue(state_dict[key] is net_meta_state_dict[key])
                 else:
-                    # state_dict[key] is not an nn.Parameter so it will be detached when wrapping with a Parameter
-                    self.assertTrue(net_meta_state_dict[key] is not net_meta_state_dict_old[key])
-                    self.assertEqual(state_dict[key], net_meta_state_dict[key])
+                    if swap:
+                        self.assertTrue(net_meta_state_dict[key] is net_meta_state_dict_old[key])
+                    else:
+                        # state_dict[key] is not an nn.Parameter so it will be detached when wrapping with a Parameter
+                        self.assertTrue(net_meta_state_dict[key] is not net_meta_state_dict_old[key])
+                        self.assertEqual(net_meta_state_dict_old[key].requires_grad, net_meta_state_dict[key].requires_grad)
+                    self.assertEqual(net_meta_state_dict_old[key].requires_grad, net_meta_state_dict[key].requires_grad)
+                    self.assertEqual(state_dict[key], net_meta_state_dict[key])
             elif key in net_meta._buffers and key not in net_meta._non_persistent_buffers_set:
                 self.assertTrue(state_dict[key] is net_meta_state_dict[key])
+                self.assertEqual(state_dict[key], net_meta_state_dict[key])
 
         # Make sure that ordering of parameters and buffers is preserved
         net_named_parameters = net.named_parameters()
@@ -391,19 +397,32 @@ def test_load_state_dict_warn_assign(self):
 def load_torch_function_handler(cls, func, types, args=(), kwargs=None):
     kwargs = {} if kwargs is None else kwargs
 
-    def module_load(dest, src):
-        # always convert src to cls
+    def module_load(dest, src, assign=False):
         if isinstance(dest, cls):
-            if type(src) is torch.Tensor:
-                return cls(src)
-            elif type(src) is cls:
+            if assign:
                 return src.detach()
             else:
-                if isinstance(src, MyWrapperLoadTensor):
-                    return cls(src._data)
-                return cls(src)
+                if type(src) is torch.Tensor:
+                    return cls(src)
+                elif type(src) is cls:
+                    return src.detach()
+                else:
+                    if isinstance(src, MyWrapperLoadTensor):
+                        return cls(src._data)
+                    return cls(src)
         else:
-            return src.detach()
+            assert isinstance(src, cls), f"Expected isinstance(src, {cls}) but got {type(src)}"
+            assert type(dest) == torch.Tensor or type(dest) == torch.nn.Parameter or issubclass(cls, type(dest))
+            if assign:
+                return src.detach()
+            else:
+                if isinstance(src, MyWrapperLoadTensor):
+                    if type(dest) not in {torch.Tensor, torch.nn.Parameter}:
+                        return type(dest)(src._data)
+                    else:
+                        return src._data.detach()
+                else:
+                    return torch.Tensor(src)
 
     if func is torch.Tensor.module_load:
         return module_load(*args, **kwargs)
@@ -478,7 +497,8 @@ class TestLoadStateDictSwap(TestCase):
     @skipIfCrossRef
     @skipIfTorchDynamo("Can't swap with dynamo as dynamo installs weakrefs")
     @swap([True])
-    def test_swap_subclass(self):
+    @parametrize("assign", [True, False])
+    def test_swap_subclass(self, assign):
 
         def _create_model(subclass=None):
             m = torch.nn.Linear(2, 3, bias=False)
@@ -491,24 +511,20 @@ def _create_model(subclass=None):
         def _test(m_subclass=None, sd_subclass=None):
             m = _create_model(m_subclass)
             sd = _create_model(sd_subclass).state_dict()
-            sd = sd
-            m.load_state_dict(sd)
+            m.load_state_dict(sd, assign=assign)
             self.assertEqual(m.weight, sd['weight'])
             self.assertEqual(m.buf, sd['buf'])
             self.assertTrue(isinstance(m.weight, torch.nn.Parameter))
             self.assertTrue(not isinstance(m.buf, torch.nn.Parameter))
 
             weight_type, buf_type = (torch.nn.Parameter, torch.Tensor)
-            if m_subclass is not None and sd_subclass is not None:
-                # handler of subclass takes precedence over superclass
-                if issubclass(sd_subclass, m_subclass):
+            if assign:
+                if sd_subclass is not None:
                     weight_type, buf_type = (sd_subclass, sd_subclass)
-                else:
+            else:
+                if m_subclass is not None:
                     weight_type, buf_type = (m_subclass, m_subclass)
-            elif m_subclass is not None:
-                weight_type, buf_type = (m_subclass, m_subclass)
-            elif sd_subclass is not None:
-                weight_type, buf_type = (sd_subclass, sd_subclass)
+
             self.assertTrue(type(m.weight) is weight_type)
             self.assertTrue(type(m.buf) is buf_type)
 
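The expectation encoded by the updated test_swap_subclass can be restated with a hypothetical helper (expected_types is illustrative only, not part of the test file): with swapping enabled, assign=True lets the state-dict subclass win, while assign=False keeps the module's subclass.

import torch

def expected_types(m_subclass, sd_subclass, assign):
    # Mirrors the decision table added to test_swap_subclass above.
    weight_type, buf_type = (torch.nn.Parameter, torch.Tensor)
    if assign:
        if sd_subclass is not None:
            weight_type, buf_type = (sd_subclass, sd_subclass)
    else:
        if m_subclass is not None:
            weight_type, buf_type = (m_subclass, m_subclass)
    return weight_type, buf_type

# With no subclasses involved, the plain Parameter/Tensor pair is expected.
print(expected_types(None, None, assign=True))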
torch/_tensor.py

+12 −5
@@ -711,7 +711,7 @@ def share_memory_(self):
         self._typed_storage()._share_memory_()
         return self
 
-    def module_load(self, other):
+    def module_load(self, other, assign=False):
         r"""Defines how to transform ``other`` when loading it into ``self`` in :meth:`~nn.Module.load_state_dict`.
 
         Used when :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.
@@ -723,16 +723,23 @@ def module_load(self, other):
 
         .. note::
             This method should always return a new object that is not ``self`` or ``other``.
-            For example, the default implementation returns ``self.copy_(other).detach()``.
+            For example, the default implementation returns ``self.copy_(other).detach()``
+            if ``assign`` is ``False`` or ``other.detach()`` if ``assign`` is ``True``.
 
         Args:
             other (Tensor): value in state dict with key corresponding to ``self``
+            assign (bool): the assign argument passed to :meth:`nn.Module.load_state_dict`
 
         """
         if has_torch_function_variadic(self, other):
-            return handle_torch_function(Tensor.module_load, (self, other), self, other)
-        # In the default case, swap_tensors becomes a no-op
-        return self.copy_(other).detach()
+            return handle_torch_function(
+                Tensor.module_load, (self, other), self, other, assign=assign
+            )
+
+        if assign:
+            return other.detach()
+        else:
+            return self.copy_(other).detach()
 
     def __reversed__(self):
         r"""Reverses the tensor along dimension 0."""

torch/nn/modules/module.py

+20 −19
@@ -2046,25 +2046,26 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
 
                 try:
                     with torch.no_grad():
-                        if assign_to_params_buffers:
+                        if use_swap_tensors:
+                            new_input_param = param.module_load(input_param, assign=assign_to_params_buffers)
+                            if id(new_input_param) == id(input_param) or id(new_input_param) == id(param):
+                                raise RuntimeError("module_load returned one of self or other, please .detach() "
+                                                   "the result if returning one of the inputs in module_load")
+                            if (isinstance(param, torch.nn.Parameter)):
+                                if not isinstance(new_input_param, torch.nn.Parameter):
+                                    new_input_param = torch.nn.Parameter(new_input_param, requires_grad=param.requires_grad)
+                                else:
+                                    new_input_param.requires_grad_(param.requires_grad)
+                            torch.utils.swap_tensors(param, new_input_param)
+                            del new_input_param
+                        elif assign_to_params_buffers:
                             # Shape checks are already done above
                             if (isinstance(param, torch.nn.Parameter)):
                                 if not isinstance(input_param, torch.nn.Parameter):
                                     input_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad)
                                 else:
                                     input_param.requires_grad_(param.requires_grad)
                             setattr(self, name, input_param)
-                        elif use_swap_tensors:
-                            param_requires_grad = param.requires_grad
-                            new_input_param = param.module_load(input_param)
-                            if id(new_input_param) == id(input_param) or id(new_input_param) == id(param):
-                                raise RuntimeError("module_load returned one of self or other, please .detach() "
-                                                   "the result if returning one of the inputs in module_load")
-                            if (isinstance(param, torch.nn.Parameter) and
-                                    not isinstance(new_input_param, torch.nn.Parameter)):
-                                new_input_param = torch.nn.Parameter(new_input_param, requires_grad=param_requires_grad)
-                            torch.utils.swap_tensors(param, new_input_param)
-                            del new_input_param
                         else:
                             param.copy_(input_param)
                 except Exception as ex:
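A minimal sketch (not from the PR) of what the reworked swap branch above implies for requires_grad: even with assign=True, a Parameter's requires_grad comes from the module, because the result of module_load is re-wrapped (or requires_grad_-ed) with param.requires_grad before swap_tensors is called.

import torch

torch.__future__.set_swap_module_params_on_conversion(True)

m = torch.nn.Linear(2, 2)                                          # weight.requires_grad is True
sd = {k: torch.randn_like(v) for k, v in m.state_dict().items()}   # plain tensors, requires_grad=False

m.load_state_dict(sd, assign=True)
print(type(m.weight))           # <class 'torch.nn.parameter.Parameter'>
print(m.weight.requires_grad)   # True: taken from the module, not from the state dict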
@@ -2104,20 +2105,20 @@ def load_state_dict(self, state_dict: Mapping[str, Any],
 
         .. warning::
             If :attr:`assign` is ``True`` the optimizer must be created after
-            the call to :attr:`load_state_dict`.
+            the call to :attr:`load_state_dict` unless
+            :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.
 
         Args:
             state_dict (dict): a dict containing parameters and
                 persistent buffers.
             strict (bool, optional): whether to strictly enforce that the keys
                 in :attr:`state_dict` match the keys returned by this module's
                 :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
-            assign (bool, optional): whether to assign items in the state
-                dictionary to their corresponding keys in the module instead
-                of copying them inplace into the module's current parameters and buffers.
-                When ``False``, the properties of the tensors in the current
-                module are preserved while when ``True``, the properties of the
-                Tensors in the state dict are preserved.
+            assign (bool, optional): When ``False``, the properties of the tensors
+                in the current module are preserved while when ``True``, the
+                properties of the Tensors in the state dict are preserved. The only
+                exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s
+                for which the value from the module is preserved.
                 Default: ``False``
 
         Returns:
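The relaxed warning can be illustrated with a small sketch (assumptions: the swap flag is on and a plain SGD optimizer): because torch.utils.swap_tensors mutates the existing Parameter object in place, an optimizer created before load_state_dict still points at the loaded values.

import torch

torch.__future__.set_swap_module_params_on_conversion(True)

m = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(m.parameters(), lr=0.1)    # created BEFORE loading

sd = {k: torch.randn_like(v) for k, v in m.state_dict().items()}
m.load_state_dict(sd, assign=True)

# The optimizer still holds the very same Parameter objects the module uses.
print(opt.param_groups[0]["params"][0] is m.weight)  # True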

torch/overrides.py

+1 −1
@@ -1367,7 +1367,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         Tensor.map_: lambda self, tensor, callable: -1,
         Tensor.map2_: lambda self, x, y, callable: -1,
         Tensor.mm: lambda self, mat2: -1,
-        Tensor.module_load: lambda self, other: -1,
+        Tensor.module_load: lambda self, other, assign=False: -1,
         Tensor.narrow_copy: lambda self, dimension, start, length: -1,
         Tensor.ndimension: lambda self: -1,
         Tensor.nelement: lambda self: -1,
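For completeness, the testing-overrides table can be inspected directly; a small sketch of checking that the dummy lambda now carries the new keyword:

import inspect
import torch
from torch.overrides import get_testing_overrides

dummy = get_testing_overrides()[torch.Tensor.module_load]
print(inspect.signature(dummy))   # (self, other, assign=False)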
