-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathobject_place_net.py
56 lines (40 loc) · 1.73 KB
/
object_place_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
import torch
import torch.nn.functional as F
from torch import nn
from config import opt
from resnet_4ch import resnet
class ObjectPlaceNet(nn.Module):
    """ResNet-based classifier producing ``opt.class_num`` scores per image.

    The model is a ResNet backbone (built by ``resnet_4ch.resnet``) with its
    last two children removed, followed by global average pooling and a
    bias-free linear prediction head. All configuration (backbone name,
    pretrained-weight directory, number of classes, mask usage) is read from
    the global ``opt`` object.
    """

    def __init__(self, backbone_pretrained=True):
        """Build the backbone, pooling layers and prediction head.

        Args:
            backbone_pretrained: forwarded to ``resnet`` as its second
                positional argument, together with the weight path
                ``<opt.pretrained_model_path>/<opt.backbone>.pth``
                (presumably controls whether those weights are loaded).
        """
        super().__init__()
        # Backbone: only ResNet variants are supported; the depth is parsed
        # from names such as 'resnet18' or 'resnet50'.
        resnet_layers = int(opt.backbone.split('resnet')[-1])
        full_backbone = resnet(
            resnet_layers,
            backbone_pretrained,
            os.path.join(opt.pretrained_model_path, opt.backbone + '.pth'),
        )
        # Drop the last two children (pool layer and fc layer) so the
        # backbone returns the spatial feature map instead of class scores.
        self.backbone = nn.Sequential(*list(full_backbone.children())[:-2])

        # Channel width of the backbone output; assumes standard ResNet
        # widths (512 for resnet18/34, 2048 for the deeper variants).
        self.global_feature_dim = 512 if opt.backbone in ('resnet18', 'resnet34') else 2048
        # NOTE(review): avgpool3x3 is not used in forward(); it is kept so any
        # external code referencing this attribute keeps working.
        self.avgpool3x3 = nn.AdaptiveAvgPool2d(3)
        self.avgpool1x1 = nn.AdaptiveAvgPool2d(1)
        self.prediction_head = nn.Linear(self.global_feature_dim, opt.class_num, bias=False)

    def forward(self, img_cat):
        """Return raw class scores for a batch of images.

        Args:
            img_cat: input tensor, documented by the original code as
                (b, 4, 256, 256). When ``opt.without_mask`` is set, only the
                first three channels are passed to the backbone (presumably
                dropping a mask channel).
                NOTE(review): confirm the backbone built by ``resnet_4ch``
                accepts 3-channel input in that mode.

        Returns:
            Tensor of shape (b, opt.class_num) straight from the linear head
            (no softmax/sigmoid is applied here).
        """
        if opt.without_mask:
            img_cat = img_cat[:, 0:3]
        # Spatial feature map, e.g. (b, 512, 8, 8) for resnet18 on 256x256 input.
        feature_map = self.backbone(img_cat)
        # Global average pool to (b, C, 1, 1), then flatten to (b, C).
        global_feature = self.avgpool1x1(feature_map).flatten(1)
        return self.prediction_head(global_feature)
if __name__ == '__main__':
    # Smoke test: push a random 4-channel batch through an untrained model.
    # Fall back to CPU so the check also runs on machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    b = 4
    img_cat = torch.randn(b, 4, 256, 256, device=device)
    model = ObjectPlaceNet(backbone_pretrained=False).to(device)
    # Inference only; no gradients are needed for this check.
    with torch.no_grad():
        prediction = model(img_cat)
    print(prediction)