@@ -261,6 +261,7 @@ def __init__(self):
             def forward(self, input):
                 return self.x + self.bn(self.fc1(input))
 
+        swap = torch.__future__.get_swap_module_params_on_conversion()
         net = MyModule()
         state_dict = net.state_dict(keep_vars=keep_vars)
         for v in state_dict.values():
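
Note: torch.__future__.get_swap_module_params_on_conversion() reads the global flag that makes nn.Module conversion and load_state_dict use the swap_tensors path instead of copying into existing parameters. A minimal illustrative sketch of toggling it around a load (the Linear modules below are placeholders, not part of the test):

    import torch
    import torch.nn as nn

    # Enable the swap path globally, load, then restore the previous setting.
    prev = torch.__future__.get_swap_module_params_on_conversion()
    torch.__future__.set_swap_module_params_on_conversion(True)
    try:
        m = nn.Linear(3, 5)
        sd = nn.Linear(3, 5).state_dict()
        # With the flag set, load_state_dict swaps the parameter tensors rather
        # than copying values into them, so the loaded tensors' properties survive.
        m.load_state_dict(sd)
    finally:
        torch.__future__.set_swap_module_params_on_conversion(prev)
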
@@ -276,16 +277,21 @@ def forward(self, input):
         net_meta_state_dict = net_meta.state_dict(keep_vars=True)
         for key in state_dict.keys():
             if key in net_meta._parameters:
-                self.assertEqual(net_meta_state_dict_old[key].requires_grad, net_meta_state_dict[key].requires_grad)
-                if keep_vars:
+                if keep_vars and not swap:
                     # state_dict[key] is an nn.Parameter
                     self.assertTrue(state_dict[key] is net_meta_state_dict[key])
                 else:
-                    # state_dict[key] is not an nn.Parameter so it will be detached when wrapping with a Parameter
-                    self.assertTrue(net_meta_state_dict[key] is not net_meta_state_dict_old[key])
-                    self.assertEqual(state_dict[key], net_meta_state_dict[key])
+                    if swap:
+                        self.assertTrue(net_meta_state_dict[key] is net_meta_state_dict_old[key])
+                    else:
+                        # state_dict[key] is not an nn.Parameter so it will be detached when wrapping with a Parameter
+                        self.assertTrue(net_meta_state_dict[key] is not net_meta_state_dict_old[key])
+                        self.assertEqual(net_meta_state_dict_old[key].requires_grad, net_meta_state_dict[key].requires_grad)
+                    self.assertEqual(net_meta_state_dict_old[key].requires_grad, net_meta_state_dict[key].requires_grad)
+                    self.assertEqual(state_dict[key], net_meta_state_dict[key])
             elif key in net_meta._buffers and key not in net_meta._non_persistent_buffers_set:
                 self.assertTrue(state_dict[key] is net_meta_state_dict[key])
+                self.assertEqual(state_dict[key], net_meta_state_dict[key])
 
         # Make sure that ordering of parameters and buffers is preserved
         net_named_parameters = net.named_parameters()
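
The swap-specific assertion added above (net_meta_state_dict[key] is net_meta_state_dict_old[key]) follows from how torch.utils.swap_tensors behaves: it exchanges the payloads of two tensor objects in place, so the module keeps the same Parameter object while observing the loaded values. A small illustrative sketch:

    import torch

    a = torch.zeros(3)
    b = torch.ones(3)
    # swap_tensors exchanges the contents of the two tensor objects in place;
    # any existing reference to `a` (e.g. a module attribute) keeps its identity
    # but now sees the values that used to live in `b`.
    torch.utils.swap_tensors(a, b)
    print(a)  # tensor([1., 1., 1.])
    print(b)  # tensor([0., 0., 0.])
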
@@ -391,19 +397,32 @@ def test_load_state_dict_warn_assign(self):
 def load_torch_function_handler(cls, func, types, args=(), kwargs=None):
     kwargs = {} if kwargs is None else kwargs
 
-    def module_load(dest, src):
-        # always convert src to cls
+    def module_load(dest, src, assign=False):
         if isinstance(dest, cls):
-            if type(src) is torch.Tensor:
-                return cls(src)
-            elif type(src) is cls:
+            if assign:
                 return src.detach()
             else:
-                if isinstance(src, MyWrapperLoadTensor):
-                    return cls(src._data)
-                return cls(src)
+                if type(src) is torch.Tensor:
+                    return cls(src)
+                elif type(src) is cls:
+                    return src.detach()
+                else:
+                    if isinstance(src, MyWrapperLoadTensor):
+                        return cls(src._data)
+                    return cls(src)
         else:
-            return src.detach()
+            assert isinstance(src, cls), f"Expected isinstance(src, {cls}) but got {type(src)}"
+            assert type(dest) == torch.Tensor or type(dest) == torch.nn.Parameter or issubclass(cls, type(dest))
+            if assign:
+                return src.detach()
+            else:
+                if isinstance(src, MyWrapperLoadTensor):
+                    if type(dest) not in {torch.Tensor, torch.nn.Parameter}:
+                        return type(dest)(src._data)
+                    else:
+                        return src._data.detach()
+                else:
+                    return torch.Tensor(src)
 
     if func is torch.Tensor.module_load:
         return module_load(*args, **kwargs)
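
For context, Tensor.module_load is the hook this handler targets: when the swap flag is enabled, load_state_dict computes the tensor to swap in by calling dest.module_load(src, assign=...), where dest is the module's current parameter or buffer and src is the corresponding state_dict value, and tensor subclasses can customize the result through __torch_function__. A minimal sketch of that pattern with a hypothetical subclass (not one of the classes defined in this file):

    import torch

    class MySwapTensor(torch.Tensor):
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = {} if kwargs is None else kwargs
            if func is torch.Tensor.module_load:
                # dest is the module's current param/buffer, src is the state_dict value;
                # whatever is returned here is what ends up swapped into the module.
                dest, src = args[0], args[1]
                assign = kwargs.get("assign", False)
                return src.detach() if assign else cls(src)
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)
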
@@ -478,7 +497,8 @@ class TestLoadStateDictSwap(TestCase):
     @skipIfCrossRef
     @skipIfTorchDynamo("Can't swap with dynamo as dynamo installs weakrefs")
     @swap([True])
-    def test_swap_subclass(self):
+    @parametrize("assign", [True, False])
+    def test_swap_subclass(self, assign):
 
         def _create_model(subclass=None):
             m = torch.nn.Linear(2, 3, bias=False)
@@ -491,24 +511,20 @@ def _create_model(subclass=None):
         def _test(m_subclass=None, sd_subclass=None):
             m = _create_model(m_subclass)
             sd = _create_model(sd_subclass).state_dict()
-            sd = sd
-            m.load_state_dict(sd)
+            m.load_state_dict(sd, assign=assign)
             self.assertEqual(m.weight, sd['weight'])
             self.assertEqual(m.buf, sd['buf'])
             self.assertTrue(isinstance(m.weight, torch.nn.Parameter))
             self.assertTrue(not isinstance(m.buf, torch.nn.Parameter))
 
             weight_type, buf_type = (torch.nn.Parameter, torch.Tensor)
-            if m_subclass is not None and sd_subclass is not None:
-                # handler of subclass takes precedence over superclass
-                if issubclass(sd_subclass, m_subclass):
+            if assign:
+                if sd_subclass is not None:
                     weight_type, buf_type = (sd_subclass, sd_subclass)
-                else:
+            else:
+                if m_subclass is not None:
                     weight_type, buf_type = (m_subclass, m_subclass)
-            elif m_subclass is not None:
-                weight_type, buf_type = (m_subclass, m_subclass)
-            elif sd_subclass is not None:
-                weight_type, buf_type = (sd_subclass, sd_subclass)
+
             self.assertTrue(type(m.weight) is weight_type)
             self.assertTrue(type(m.buf) is buf_type)
 
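
The assign parametrization above matches the documented contract of Module.load_state_dict: with assign=False the state_dict values are loaded into the module's existing parameters, so the module-side tensor type wins, while with assign=True the state_dict tensors themselves are assigned to the module, so the state_dict-side type and properties win, which is what the m_subclass/sd_subclass expectations encode. A small illustrative sketch, independent of the subclasses used in this test:

    import torch
    import torch.nn as nn

    m = nn.Linear(2, 3, bias=False)
    sd = {"weight": torch.ones(3, 2, dtype=torch.float64)}

    # assign=False (default): copy into the existing Parameter, keeping its dtype.
    m.load_state_dict(sd)
    print(m.weight.dtype)   # torch.float32

    # assign=True: the state_dict tensor itself becomes the parameter,
    # so the incoming dtype (float64) is preserved.
    m2 = nn.Linear(2, 3, bias=False)
    m2.load_state_dict(sd, assign=True)
    print(m2.weight.dtype)  # torch.float64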