
Commit c13177f

edpizzi authored and pytorchmergebot committed
[FSDP] Propagate requires_grad attribute to unsharded params (pytorch#109892)
Summary: This preserves `requires_grad` in the case where all parameters within a `FlatParameter` have the same `requires_grad` value. Currently, unsharded parameters have `requires_grad=True` in some cases where the `FlatParameter` and all original parameters have `requires_grad=False`. This could be extended to support `FlatParameter`s with a mix of `requires_grad` states by extending `ParamInfo` to capture `requires_grad` for each parameter.

Test Plan: test added

Differential Revision: D49517155

Pull Request resolved: pytorch#109892

Approved by: https://github.com/awgu
1 parent ebb30bd commit c13177f
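
To make the case described in the summary concrete, here is a minimal, non-distributed sketch (not FSDP's actual flattening code; the `nn.Linear` module and hand-rolled flat tensor are illustrative assumptions) of the invariant being preserved: when every original parameter shares one `requires_grad` value, a flat parameter built from them can carry that value directly.

import torch
import torch.nn as nn

module = nn.Linear(4, 4)
for p in module.parameters():
    p.requires_grad = False  # freeze all original parameters

# All original parameters share a single requires_grad value...
uniform = len({p.requires_grad for p in module.parameters()}) == 1
# ...so a flattened parameter built from them can carry that value directly,
# which is what the commit preserves when handing back unsharded views.
flat_data = torch.cat([p.detach().reshape(-1) for p in module.parameters()])
flat_param = nn.Parameter(
    flat_data, requires_grad=next(module.parameters()).requires_grad
)

assert uniform and not flat_param.requires_grad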

File tree

2 files changed: +14 −2 lines changed


test/distributed/fsdp/test_fsdp_freezing_weights.py

+4
@@ -165,6 +165,10 @@ def test_freezing_weights(
             msg="FullyShardedDataParallel states didn't match PyTorch DDP states",
         )
 
+        if freezing_method == FreezingMethod.RequiresGrad:
+            for ddp_param, fsdp_param in zip(ddp_state, fsdp_state):
+                self.assertEqual(ddp_param.requires_grad, fsdp_param.requires_grad)
+
 
 instantiate_parametrized_tests(TestFreezingWeights)
 

torch/distributed/fsdp/flat_param.py

+10 −2
@@ -1807,13 +1807,21 @@ def _use_unsharded_views(self, as_params: bool) -> None:
                     # A `DTensor` `view` is not compatible with assigning
                     # `param.data = view`, so we cannot preserve the parameter
                     # variable.
-                    self._setattr_param(module, param_name, nn.Parameter(view))
+                    self._setattr_param(
+                        module,
+                        param_name,
+                        nn.Parameter(view, requires_grad=flat_param.requires_grad),
+                    )
                     continue
                 param = self.flat_param._params[i]
                 self._setattr_param(module, param_name, param)
                 param.data = view
             elif as_params:
-                self._setattr_param(module, param_name, nn.Parameter(view))
+                self._setattr_param(
+                    module,
+                    param_name,
+                    nn.Parameter(view, requires_grad=flat_param.requires_grad),
+                )
             else:  # `as_params=False`
                 param_var: Tensor = view
                 if self._use_orig_params: