
Commit 5ba6bb7

mikaylagawarecki authored and pytorchmergebot committed
Add swap_tensors path to nn parametrizations (pytorch#124130)
Fixes pytorch#123859
Pull Request resolved: pytorch#124130
Approved by: https://github.com/albanD
1 parent 87f651c commit 5ba6bb7
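
In short, the patch routes the in-place updates that parametrizations make to the stored original tensor through torch.utils.swap_tensors whenever torch.__future__.get_swap_module_params_on_conversion() is true (or the target is a traceable wrapper subclass), instead of Tensor.set_(). A minimal sketch of the opt-in workflow; the Double parametrization and module m below are hypothetical, invented for illustration:

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

# Opt into the swap_tensors path (the future flag this patch consults).
torch.__future__.set_swap_module_params_on_conversion(True)

class Double(nn.Module):
    # Hypothetical parametrization: weight is stored halved, exposed doubled.
    def forward(self, X):
        return 2 * X

    def right_inverse(self, X):
        return X / 2

m = nn.Linear(3, 3)
parametrize.register_parametrization(m, "weight", Double())
print(m.weight)  # recomputed as Double()(m.parametrizations.weight.original)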

15 files changed: +226 -15 lines

test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False (whitespace-only changes)
test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_True (whitespace-only changes)
test/dynamo_expected_failures/TestNNParametrization.test_initialization_parametrization_swap_True (whitespace-only changes)
test/dynamo_expected_failures/TestNNParametrization.test_register_and_remove_buffer_parametrization_swap_True (whitespace-only changes)
test/dynamo_expected_failures/TestNNParametrization.test_register_and_remove_nested_parametrization_swap_True (whitespace-only changes)
test/dynamo_expected_failures/TestNNParametrization.test_register_and_remove_parametrization_swap_True (whitespace-only changes)
test/dynamo_expected_failures/TestNNParametrization.test_serialization_parametrization_swap_True (whitespace-only changes)
test/dynamo_expected_failures/TestNNParametrization.test_transfer_parametrizations_and_params_right_inverse_swap_True (whitespace-only changes)
test/dynamo_expected_failures/TestNNParametrization.test_transfer_parametrizations_and_params_swap_True (whitespace-only changes)
test/dynamo_expected_failures/TestNNParametrization.test_wrapper_subclass_parametrization_swap_True (whitespace-only changes)
test/dynamo_skips/TestNNParametrization.test_new_spectral_norm_dim_swap_True (whitespace-only changes)

test/nn/test_parametrization.py  +206 -7 (large diff not rendered)

torch/nn/utils/parametrize.py  +20 -8
@@ -1,6 +1,8 @@
 import torch
+from torch.__future__ import get_swap_module_params_on_conversion
 from torch.nn.modules.container import ModuleList, ModuleDict, Module
 from torch.nn.parameter import Parameter
+from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 from torch import Tensor

 import collections
@@ -64,6 +66,14 @@ def _register_parameter_or_buffer(module, name, X):
     else:
         module.register_buffer(name, X)

+def _maybe_set(dest: Tensor, src: Tensor) -> None:
+    should_swap = get_swap_module_params_on_conversion() or is_traceable_wrapper_subclass(dest)
+    if should_swap:
+        if isinstance(dest, Parameter) and not isinstance(src, Parameter):
+            src = Parameter(src, requires_grad=dest.requires_grad)
+        torch.utils.swap_tensors(dest, src)
+    else:
+        dest.set_(src)  # type: ignore[call-overload]

 class ParametrizationList(ModuleList):
     r"""A sequential container that holds and manages the original parameters or buffers of a parametrized :class:`torch.nn.Module`.
@@ -157,7 +167,7 @@ def __init__(
             # Set the original to original so that the user does not need to re-register the parameter
             # manually in the optimiser
             with torch.no_grad():
-                original.set_(new)  # type: ignore[call-overload]
+                _maybe_set(original, new)
             _register_parameter_or_buffer(self, "original", original)
         else:
             for i, originali in enumerate(new):
@@ -231,7 +241,7 @@ def right_inverse(self, value: Tensor) -> None:
                        f"while `original` has dtype {self.original.dtype}"
                    )
                # We know that the result is going to have the same dtype
-               self.original.set_(value)  # type: ignore[call-overload]
+               _maybe_set(self.original, value)
            else:
                if not isinstance(value, collections.abc.Sequence):
                    raise ValueError(
@@ -255,7 +265,7 @@ def right_inverse(self, value: Tensor) -> None:
                            f"Tensor {i} returned by `right_inverse` has dtype {tensor.dtype} "
                            f"while `original{i}` has dtype {original_i.dtype}"
                        )
-                   original_i.set_(tensor)
+                   _maybe_set(original_i, tensor)

    def forward(self) -> Tensor:
        if torch.jit.is_scripting():
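
The two right_inverse hunks above cover assignment to a parametrized attribute: the assigned value passes through each parametrization's right_inverse and is then written into the stored original, now via _maybe_set. Continuing the hypothetical Double example from the first sketch:

# Assignment routes through Double.right_inverse before the stored
# original is updated (swapped when the future flag is on).
m.weight = torch.full((3, 3), 4.0)
print(m.parametrizations.weight.original)  # filled with 2.0 (= 4.0 / 2)
print(m.weight)                            # filled with 4.0 again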
@@ -645,18 +655,20 @@ def remove_parametrizations(
        # This way the user does not need to update the optimizer
        with torch.no_grad():
            if type(original) is torch.Tensor:
-               original.set_(t)
+               _maybe_set(original, t)
            else:
                try:
-                   original.set_(t)
+                   _maybe_set(original, t)
                except RuntimeError as e:
                    # TODO: Fix this for tensor subclasses that are parameters:
                    # RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach().
                    raise RuntimeError("Calling remove_parametrizations() with leave_parametrized=True "
                                       "for a parameter that is an instance of a tensor subclass requires "
-                                      "set_() to be implemented correctly for the tensor subclass. Either "
-                                      "set leave_parametrized=False or provide a working implementation for "
-                                      "set_() in the tensor subclass.") from e
+                                      "set_() to be implemented correctly for the tensor subclass. "
+                                      "Alternatively, one can opt into the swap_tensors path. "
+                                      "Either set leave_parametrized=False, provide a working implementation "
+                                      "for set_() in the tensor subclass, or set "
+                                      "torch.__future__.set_swap_module_params_on_conversion(True).") from e
    else:
        if leave_parametrized:
            # We cannot use no_grad because we need to know whether one or more
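
The reworked error message points at the same escape hatch the rest of the patch implements: with the future flag set, remove_parametrizations(..., leave_parametrized=True) no longer requires the tensor subclass to implement set_(). Finishing the hypothetical example:

# With the swap flag on, the computed value is baked back into a plain
# Parameter via swap_tensors rather than set_().
parametrize.remove_parametrizations(m, "weight", leave_parametrized=True)
print(type(m.weight))  # back to a regular Parameter holding the final value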
