     'load_pretrained_weights', 'ModelEmaV2'
 ]

+
+def params_to_device(param, device):
+    def tensor_to_device(param, device):
+        param.data = param.data.to(device)
+        if param._grad is not None:
+            param._grad.data = param._grad.data.to(device)
+
+    if isinstance(param, torch.Tensor):
+        tensor_to_device(param, device)
+    elif isinstance(param, dict):
+        for subparam in param.values():
+            # Optimizer state dicts may hold non-tensor values (e.g. Adam's
+            # integer 'step'), so keep the type check from the original code.
+            if isinstance(subparam, torch.Tensor):
+                tensor_to_device(subparam, device)
+
 def optimizer_to(optim, device):
     for param in optim.state.values():
-        # Not sure there are any global tensors in the state dict
-        if isinstance(param, torch.Tensor):
-            param.data = param.data.to(device)
-            if param._grad is not None:
-                param._grad.data = param._grad.data.to(device)
-        elif isinstance(param, dict):
-            for subparam in param.values():
-                if isinstance(subparam, torch.Tensor):
-                    subparam.data = subparam.data.to(device)
-                    if subparam._grad is not None:
-                        subparam._grad.data = subparam._grad.data.to(device)
+        params_to_device(param, device)

 def scheduler_to(sched, device):
     for param in sched.__dict__.values():
-        if isinstance(param, torch.Tensor):
-            param.data = param.data.to(device)
-            if param._grad is not None:
-                param._grad.data = param._grad.data.to(device)
+        params_to_device(param, device)

 def save_checkpoint(
     state, save_dir, is_best=False, remove_module_from_keys=False, name='model'
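
A minimal usage sketch (not part of the diff above), assuming `optimizer_to` and `scheduler_to` from this module are in scope; the toy model, optimizer, and scheduler below are placeholders chosen only so the optimizer holds some state to move:

```python
import torch

# Placeholder model/optimizer/scheduler for illustration only.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# One training step so the optimizer actually holds per-parameter state
# tensors (exp_avg, exp_avg_sq, ...) that need to be moved.
loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
optimizer_to(optimizer, device)   # moves every tensor in optim.state to `device`
scheduler_to(scheduler, device)   # moves any tensors stored on the scheduler
```

Consolidating the tensor-moving logic in `params_to_device` keeps `optimizer_to` and `scheduler_to` in sync, since both previously duplicated the same `.data` / `._grad` handling.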