-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathmodel.py
executable file
·71 lines (54 loc) · 2.27 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from __future__ import print_function
import argparse
import numpy as np
import torch
import torch.nn as nn
from src import resnet_models
# Selectable detector backbones.
# Maps a MODEL_SELECT id -> (label printed on selection, factory function name
# in resnet_models). Factory names are resolved with getattr only when a model
# is actually requested, so building this table does not touch resnet_models.
_DETECTOR_REGISTRY = {
    0: ('ResNet34', 'resnet34'),
    1: ('SEResNet34', 'se_resnet34'),
    2: ('ResNet50', 'resnet50'),
    3: ('SEResNet50', 'se_resnet50'),
    4: ('Res2Net50_26w_4s', 'res2net50_v1b'),
    5: ('SERes2Net50_26w_4s', 'se_res2net50_v1b'),
    6: ('Res2Net50_14w_8s', 'res2net50_v1b_14w_8s'),
    7: ('SERes2Net50_14w_8s', 'se_res2net50_v1b_14w_8s'),
    8: ('Res2Net50_26w_8s', 'res2net50_v1b_26w_8s'),
    9: ('SERes2Net50_26w_8s', 'se_res2net50_v1b_26w_8s'),
}


def Detector(MODEL_SELECT, NUM_SPOOF_CLASS):
    """Build the detector backbone selected by ``MODEL_SELECT``.

    Prints ``'using <label>.'`` for the chosen architecture. It then calls the
    matching factory in ``resnet_models`` with ``num_classes=NUM_SPOOF_CLASS``
    and ``KaimingInit=True``.

    Args:
        MODEL_SELECT: architecture id, one of the keys of
            ``_DETECTOR_REGISTRY`` (0-9).
        NUM_SPOOF_CLASS: number of output classes, forwarded to the factory
            as ``num_classes``.

    Returns:
        The model object returned by the selected ``resnet_models`` factory.

    Raises:
        ValueError: if ``MODEL_SELECT`` is not a known architecture id.
    """
    # Validate first so an unknown id fails with a clear message instead of
    # reaching the return with no model bound.
    if MODEL_SELECT not in _DETECTOR_REGISTRY:
        raise ValueError(
            'Unknown MODEL_SELECT {!r}; expected one of {}'.format(
                MODEL_SELECT, sorted(_DETECTOR_REGISTRY)))
    label, factory_name = _DETECTOR_REGISTRY[MODEL_SELECT]
    print('using {}.'.format(label))
    factory = getattr(resnet_models, factory_name)
    return factory(num_classes=NUM_SPOOF_CLASS, KaimingInit=True)
def test_Detector(model_id=0, input_shape=(2, 1, 257, 400)):
    """Smoke-test one detector: build it, report its size, run one forward pass.

    Builds ``Detector(model_id, 2)``. It then prints the number of trainable
    parameters, followed by the input and output tensor sizes.

    Args:
        model_id: architecture id passed to ``Detector`` as ``MODEL_SELECT``.
        input_shape: shape of the random input fed to the model. Defaults to
            the original hard-coded (2, 1, 257, 400).
            NOTE(review): presumably (batch, channel, freq_bins, frames) of a
            spectrogram — confirm against the data pipeline.
    """
    model = Detector(MODEL_SELECT=model_id, NUM_SPOOF_CLASS=2)
    num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('model contains {} parameters'.format(num_trainable))

    x = torch.randn(*input_shape)
    # Only the output shape is inspected, so skip building the autograd graph.
    with torch.no_grad():
        output = model(x)
    print(x.size())
    print(output.size())
if __name__ == '__main__':
    # Smoke-test every selectable architecture id (0-9).
    # The loop variable is named model_id so it does not shadow the builtin id().
    for model_id in range(10):
        test_Detector(model_id)