Skip to content
This repository was archived by the owner on Apr 17, 2023. It is now read-only.

Commit cc87788

Browse files
Change model to make it prunable
1 parent 54ba726 commit cc87788

File tree

3 files changed

+31
-7
lines changed

3 files changed

+31
-7
lines changed

torchreid/engine/engine.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,18 @@ def run(
326326
self.fixbase_epoch = fixbase_epoch
327327
test_acc = AverageMeter()
328328
print('=> Start training')
329-
329+
res = self.test(
330+
self.epoch,
331+
dist_metric=dist_metric,
332+
normalize_feature=normalize_feature,
333+
visrank=visrank,
334+
visrank_topk=visrank_topk,
335+
save_dir=save_dir,
336+
use_metric_cuhk03=use_metric_cuhk03,
337+
ranks=ranks,
338+
lr_finder=lr_finder,
339+
)
340+
print(res)
330341
if perf_monitor and not lr_finder: perf_monitor.on_train_begin()
331342
for self.epoch in range(self.start_epoch, self.max_epoch):
332343
# change the NumPy’s seed at every epoch

torchreid/models/mobilenetv3.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import torch
44
import torch.nn as nn
5+
from torch.nn import functional as F
56

67
from torchreid.losses import AngleSimpleLinear
78
from torchreid.ops import Dropout, EvalModeSetter, rsc
@@ -33,19 +34,17 @@
3334
class SELayer(nn.Module):
3435
def __init__(self, channel, reduction=4):
3536
super(SELayer, self).__init__()
36-
self.avg_pool = nn.AdaptiveAvgPool2d(1)
3737
self.fc = nn.Sequential(
38-
nn.Linear(channel, make_divisible(channel // reduction, 8)),
38+
nn.Conv2d(channel, make_divisible(channel // reduction, 8), 1),
3939
nn.ReLU(inplace=True),
40-
nn.Linear(make_divisible(channel // reduction, 8), channel),
40+
nn.Conv2d(make_divisible(channel // reduction, 8), channel, 1),
4141
HSigmoid()
4242
)
4343

4444
def forward(self, x):
4545
with no_nncf_se_layer_context():
46-
b, c, _, _ = x.size()
47-
y = self.avg_pool(x).view(b, c)
48-
y = self.fc(y).view(b, c, 1, 1)
46+
y = F.adaptive_avg_pool2d(x, 1)
47+
y = self.fc(y)
4948
return x * y
5049

5150

torchreid/utils/torchtools.py

+14
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,19 @@ def _print_loading_weights_inconsistencies(discarded_layers, unmatched_layers):
285285
)
286286

287287

288+
def update_checkpoint_mobilenet_v3(checkpoint):
289+
fc = []
290+
for k in checkpoint:
291+
if 'fc' in k and not 'bias' in k:
292+
fc.append(k)
293+
for name in fc:
294+
w = checkpoint[name]
295+
shape = w.shape
296+
w_new = w.view(shape + (1, 1))
297+
print(name, ': ', checkpoint[name].shape, '->', w_new.shape)
298+
checkpoint[name] = w_new
299+
300+
288301
def load_pretrained_weights(model, file_path='', pretrained_dict=None):
289302
r"""Loads pretrianed weights to model.
290303
Features::
@@ -317,6 +330,7 @@ def _remove_prefix(key, prefix):
317330
else:
318331
state_dict = checkpoint
319332

333+
update_checkpoint_mobilenet_v3(state_dict)
320334
model_dict = model.state_dict()
321335
new_state_dict = OrderedDict()
322336
matched_layers, discarded_layers = [], []

0 commit comments

Comments
 (0)