Skip to content

Commit 3ecc5a5

Browse files
committed
resolve issues
1 parent 37e9aaa commit 3ecc5a5

8 files changed

+45
-45
lines changed

configs/classification/efficientnet_v2_b0.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ train:
8282
enable: True
8383
ema_decay: 0.999
8484
mix_precision: True
85-
85+
8686
test:
8787
batch_size: 128
8888
evaluate: False

configs/multilabel_classification/efficientnet_v1_b0.yml

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ model:
1919
p: 0.1
2020

2121
custom_datasets:
22-
roots: ['/ssd/datasets/nus_wide/train.json', '/ssd/datasets/nus_wide/val.json']
22+
roots: ['data/coco/train.json', 'data/coco/val.json']
2323
types: ['multilabel_classification', 'multilabel_classification']
24-
names: ['nus_wide_train', 'nus_wide_val']
24+
names: ['coco_train', 'coco_val']
2525

2626
data:
2727
root: './'
28-
sources: ['train_data']
29-
targets: ['val_data']
28+
sources: ['coco_data']
29+
targets: ['coco_data']
3030
height: 448
3131
width: 448
3232
norm_mean: [0.485, 0.456, 0.406]

configs/multilabel_classification/efficientnet_v2_b0.yml

+5-5
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ model:
1919
p: 0.1
2020

2121
custom_datasets:
22-
roots: ['/ssd/datasets/nus_wide/train.json', '/ssd/datasets/nus_wide/val.json']
22+
roots: ['data/coco/train.json', 'data/coco/val.json']
2323
types: ['multilabel_classification', 'multilabel_classification']
24-
names: ['nus_wide_train', 'nus_wide_val']
24+
names: ['coco_train', 'coco_val']
2525

2626
data:
2727
root: './'
28-
sources: ['train_data']
29-
targets: ['val_data']
28+
sources: ['coco_data']
29+
targets: ['coco_data']
3030
height: 448
3131
width: 448
3232
norm_mean: [0.485, 0.456, 0.406]
@@ -80,7 +80,7 @@ train:
8080
enable: True
8181
ema_decay: 0.9995
8282
mix_precision: True
83-
83+
8484
test:
8585
batch_size: 128
8686
evaluate: False

configs/multilabel_classification/efficientnet_v2_small.yml

+5-5
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ model:
1919
p: 0.1
2020

2121
custom_datasets:
22-
roots: ['/ssd/datasets/nus_wide/train.json', '/ssd/datasets/nus_wide/val.json']
22+
roots: ['data/coco/train.json', 'data/coco/val.json']
2323
types: ['multilabel_classification', 'multilabel_classification']
24-
names: ['nus_wide_train', 'nus_wide_val']
24+
names: ['coco_train', 'coco_val']
2525

2626
data:
2727
root: './'
28-
sources: ['train_data']
29-
targets: ['val_data']
28+
sources: ['coco_data']
29+
targets: ['coco_data']
3030
height: 448
3131
width: 448
3232
norm_mean: [0.5, 0.5, 0.5]
@@ -80,7 +80,7 @@ train:
8080
enable: True
8181
ema_decay: 0.9995
8282
mix_precision: True
83-
83+
8484
test:
8585
batch_size: 128
8686
evaluate: False

configs/multilabel_classification/mobilenetv3_large.yml

+5-5
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ model:
1919
p: 0.1
2020

2121
custom_datasets:
22-
roots: ['/data/train.json', '/data/val.json']
22+
roots: ['data/coco/train.json', 'data/coco/val.json']
2323
types: ['multilabel_classification', 'multilabel_classification']
24-
names: ['train_data', 'val_data']
24+
names: ['coco_train', 'coco_val']
2525

2626
data:
2727
root: './'
28-
sources: ['train_data']
29-
targets: ['val_data']
28+
sources: ['coco_data']
29+
targets: ['coco_data']
3030
height: 448
3131
width: 448
3232
norm_mean: [0, 0, 0]
@@ -80,7 +80,7 @@ train:
8080
enable: True
8181
ema_decay: 0.9997
8282
mix_precision: True
83-
83+
8484
test:
8585
batch_size: 128
8686
evaluate: False

configs/multilabel_classification/mobilenetv3_large_75.yml

+5-5
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ model:
1919
p: 0.1
2020

2121
custom_datasets:
22-
roots: ['/data/train.json', '/data/val.json']
22+
roots: ['data/coco/train.json', 'data/coco/val.json']
2323
types: ['multilabel_classification', 'multilabel_classification']
24-
names: ['train_data', 'val_data']
24+
names: ['coco_train', 'coco_val']
2525

2626
data:
2727
root: './'
28-
sources: ['train_data']
29-
targets: ['val_data']
28+
sources: ['coco_data']
29+
targets: ['coco_data']
3030
height: 448
3131
width: 448
3232
norm_mean: [0.485, 0.456, 0.406]
@@ -80,7 +80,7 @@ train:
8080
enable: True
8181
ema_decay: 0.9995
8282
mix_precision: True
83-
83+
8484
test:
8585
batch_size: 128
8686
evaluate: False

configs/multilabel_classification/mobilenetv3_small.yml

+5-5
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ model:
1717
save_all_chkpts: False
1818

1919
custom_datasets:
20-
roots: ['/data/train.json', '/data/val.json']
20+
roots: ['data/coco/train.json', 'data/coco/val.json']
2121
types: ['multilabel_classification', 'multilabel_classification']
22-
names: ['train_data', 'val_data']
22+
names: ['coco_train', 'coco_val']
2323

2424
data:
2525
root: './'
26-
sources: ['train_data']
27-
targets: ['val_data']
26+
sources: ['coco_data']
27+
targets: ['coco_data']
2828
height: 448
2929
width: 448
3030
norm_mean: [0.485, 0.456, 0.406]
@@ -78,7 +78,7 @@ train:
7878
enable: True
7979
ema_decay: 0.9995
8080
mix_precision: True
81-
81+
8282
test:
8383
batch_size: 128
8484
evaluate: False

torchreid/utils/torchtools.py

+15-15
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,26 @@
2121
'load_pretrained_weights', 'ModelEmaV2'
2222
]
2323

24+
25+
def params_to_device(param, device):
26+
def tensor_to_device(param, device):
27+
param.data = param.data.to(device)
28+
if param._grad is not None:
29+
param._grad.data = param._grad.data.to(device)
30+
31+
if isinstance(param, torch.Tensor):
32+
tensor_to_device(param, device)
33+
elif isinstance(param, dict):
34+
for subparam in param.values():
35+
tensor_to_device(subparam, device)
36+
2437
def optimizer_to(optim, device):
2538
for param in optim.state.values():
26-
# Not sure there are any global tensors in the state dict
27-
if isinstance(param, torch.Tensor):
28-
param.data = param.data.to(device)
29-
if param._grad is not None:
30-
param._grad.data = param._grad.data.to(device)
31-
elif isinstance(param, dict):
32-
for subparam in param.values():
33-
if isinstance(subparam, torch.Tensor):
34-
subparam.data = subparam.data.to(device)
35-
if subparam._grad is not None:
36-
subparam._grad.data = subparam._grad.data.to(device)
39+
params_to_device(param, device)
3740

3841
def scheduler_to(sched, device):
3942
for param in sched.__dict__.values():
40-
if isinstance(param, torch.Tensor):
41-
param.data = param.data.to(device)
42-
if param._grad is not None:
43-
param._grad.data = param._grad.data.to(device)
43+
params_to_device(param, device)
4444

4545
def save_checkpoint(
4646
state, save_dir, is_best=False, remove_module_from_keys=False, name='model'

0 commit comments

Comments
 (0)