
Commit b498c31

Merge pull request #112 from joshanderson-kw/dev/add-uho-activity-detector
Add UHO activity detector
2 parents ad87ee4 + 7200655 commit b498c31

File tree

8 files changed: +916 -6 lines changed


angel_system/uho/__init__.py

Whitespace-only changes.

angel_system/uho/src/models/components/__init__.py

Whitespace-only changes.
@@ -0,0 +1,323 @@
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import constant, normal
import time
import math
from functools import reduce
import torchvision.models as models
import sklearn.utils.class_weight as class_weight
from collections import OrderedDict
#from .vidswin import VideoSwinTransformerBackbone
import pdb


configs = {
    'video_swin_t_p4w7':
        dict(patch_size=(2, 4, 4),
             embed_dim=96,
             depths=[2, 2, 6, 2],
             num_heads=[3, 6, 12, 24],
             window_size=(8, 7, 7),
             mlp_ratio=4.,
             qkv_bias=True,
             qk_scale=None,
             drop_rate=0.,
             attn_drop_rate=0.,
             drop_path_rate=0.2,
             patch_norm=True,
             use_checkpoint=False
             ),
    'video_swin_s_p4w7':
        dict(patch_size=(2, 4, 4),
             embed_dim=96,
             depths=[2, 2, 18, 2],
             num_heads=[3, 6, 12, 24],
             window_size=(8, 7, 7),
             mlp_ratio=4.,
             qkv_bias=True,
             qk_scale=None,
             drop_rate=0.,
             attn_drop_rate=0.,
             drop_path_rate=0.2,
             patch_norm=True,
             use_checkpoint=False
             ),
    'video_swin_b_p4w7':
        dict(patch_size=(2, 4, 4),
             embed_dim=128,
             depths=[2, 2, 18, 2],
             num_heads=[4, 8, 16, 32],
             window_size=(8, 7, 7),
             mlp_ratio=4.,
             qkv_bias=True,
             qk_scale=None,
             drop_rate=0.,
             attn_drop_rate=0.,
             drop_path_rate=0.2,
             patch_norm=True,
             use_checkpoint=False
             )
}

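The three presets mirror the standard Video Swin Transformer tiny/small/base configurations and appear intended for the commented-out VideoSwinTransformerBackbone (see the commented-out lines in TemTRANSModule.__init__ below). A minimal, illustrative selection of one preset:

# Illustrative sketch, not part of the commit: the presets differ only in
# embed_dim, depths, and num_heads.
cfgs = configs['video_swin_t_p4w7']
print(cfgs['embed_dim'], cfgs['depths'])   # 96 [2, 2, 6, 2]
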
class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)

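This LayerNorm computes the same normalization as torch.nn.LayerNorm with elementwise affine parameters (given the same eps), and QuickGELU is the sigmoid-based GELU approximation used in CLIP-style transformers. A quick sanity check, illustrative only:

# Illustrative sketch, not part of the commit.
x = torch.randn(2, 10, 256)
print(torch.allclose(LayerNorm(256)(x), nn.LayerNorm(256, eps=1e-12)(x), atol=1e-5))  # True
print((QuickGELU()(x) - F.gelu(x)).abs().max())  # small (~1e-2) approximation error
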
class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class TemporalTransformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)

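ResidualAttentionBlock and TemporalTransformer follow the CLIP temporal-transformer pattern: nn.MultiheadAttention defaults to sequence-first input, so tensors are expected as (L, N, D), which is why TemTRANSModule.forward below permutes NLD -> LND around the transformer call. A shape sketch under that assumption:

# Illustrative sketch, not part of the commit: width must be divisible by heads.
seq = torch.randn(77, 4, 512)                        # (L, N, D) = (tokens, batch, width)
temporal = TemporalTransformer(width=512, layers=4, heads=8)
print(temporal(seq).shape)                           # torch.Size([77, 4, 512]), shape-preserving
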
class TemTRANSModule(nn.Module):
    def __init__(self,
                 act_classes,
                 hidden,
                 dropout=0.,
                 depth=4,
                 num_head=8):
        super().__init__()

        self.use_CLS = False
        self.use_BBOX = True
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden*2))
        # 2048 -> The length of features out of last layer of ResNext
        self.fc_x = nn.Linear(2048, hidden)
        if self.use_BBOX:
            # 2048 -> The length of features out of Faster R-CNN
            self.fc_d = nn.Linear(2048, hidden*2-12)
            self.fc_b = nn.Linear(4, 12)
        else:
            self.fc_d = nn.Linear(2048, hidden*2)

        # 126 -> 63*2 (Each hand has a descriptor of length 63 compatible with H2O format)
        self.fc_h = nn.Linear(126, hidden)
        '''layers = []
        n_mlp = 3
        for i in range(n_mlp):
            if i == 0:
                in_dim, out_dim = 126, hidden
            else:
                in_dim, out_dim = hidden, hidden
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.LeakyReLU(0.2, True))
        self.fc_h = nn.Sequential(*layers)'''
        #self.fc_v = nn.Linear(768, hidden*2)

        context_length = 77
        det_classes = 1600
        self.frame_pos_embeddings = nn.Embedding(context_length, hidden*2)
        self.transformer = TemporalTransformer(width=hidden*2, layers=depth, heads=num_head)
        self.classifier_action = nn.Sequential(nn.Dropout(dropout), nn.Linear(hidden*2, act_classes))
        self.classifier_detection = nn.Sequential(nn.Dropout(dropout), nn.Linear(hidden*2, det_classes))
        # Loss functions
        self.loss_detection = nn.CrossEntropyLoss(label_smoothing=0.1)
        class_weight = torch.ones(act_classes)
        class_weight[0] = 2./float(act_classes)
        self.loss_action = nn.CrossEntropyLoss(weight=class_weight, label_smoothing=0.1)  #LabelSmoothingLoss(0.1, 37, ignore_index=-1)
        self.apply(self.init_weights)

        # video swin transformer
        #cfgs = configs['video_swin_t_p4w7']
        #self.extr_3d = VideoSwinTransformerBackbone(True, 'checkpoints/swin_tiny_patch244_window877_kinetics400_1k.pth', False, **cfgs)

    def init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, LayerNorm):
            if 'beta' in dir(module) and 'gamma' in dir(module):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, inputs):
        topK = 5
        # load features
        feat_x = inputs["feats"]  # RGB features
        lh, rh = inputs["labels"]["l_hand"], inputs["labels"]["r_hand"]  # hand poses
        feat_d = inputs["dets"]  # detections
        feat_b = inputs["bbox"]  # bounding boxes
        #det_labels = inputs[0]["dcls"]  # detections labels
        #fr_idx = inputs[0]["idx"]  # frame indices
        #labels = inputs["act"]  # action labels
        #idx = inputs["idx"]  # video clip index

        # select detections
        #det_step = np.random.randint(10) if if_train else 5
        det_step = feat_d.size(1)
        #det_step = feat_x.size(1)

        #print(det_step, feat_d.size())
        if det_step:
            num_batch, num_det, num_dim = feat_d.shape
            feat_d = [feat_d[:,k:k+topK,:] for k in range(0, num_det, det_step*topK)]
            feat_d = torch.stack(feat_d,1).reshape(num_batch, -1, num_dim)
            feat_b = [feat_b[:,k:k+topK,:] for k in range(0, num_det, det_step*topK)]
            feat_b = torch.stack(feat_b,1).reshape(num_batch, -1, 4)
            #det_labels = [det_labels[:,k:k+topK] for k in range(0, num_det, det_step*topK)]
            #det_labels = torch.stack(det_labels,1).reshape(num_batch, -1)

        feat_x = self.fc_x(feat_x)
        feat_h = self.fc_h(torch.cat([lh, rh], axis=-1).float())
        feat_x = torch.cat([feat_x, feat_h], axis=-1)
        num_batch, num_frame, num_dim = feat_x.shape

        '''
        print(f"num batch {num_batch}")
        print(f"num frame {num_frame}")
        print(f"num dim {num_dim}")
        '''

        if feat_d.ndim == 2:
            feat_d = feat_d.unsqueeze(1)
            feat_b = feat_b.unsqueeze(1)
        feat_d = self.fc_d(feat_d)

        if self.use_BBOX:
            feat_b[:,:,[0,2]] = feat_b[:,:,[0,2]]/1280.
            feat_b[:,:,[1,3]] = feat_b[:,:,[1,3]]/720.
            feat_b = self.fc_b(feat_b.float())
            feat_d = torch.cat([feat_d, feat_b], axis=-1)
        #print(feat_x.shape)
        if self.use_CLS:
            cls_tokens = self.cls_token.expand(num_batch, -1, -1)
            enc_feat = torch.cat([cls_tokens, feat_x, feat_d], axis=1)
        else:
            enc_feat = torch.cat([feat_x, feat_d], axis=1)
        enc_feat = enc_feat.contiguous()

        # add positional encoding
        if self.use_CLS:
            pos_ids = torch.arange(num_frame+1, dtype=torch.long, device=enc_feat.device)
            pos_ids_frame = pos_ids.unsqueeze(0).expand(num_batch, -1)
            pos_ids_det = pos_ids[1:].repeat(topK,1).transpose(1,0).reshape(-1, feat_d.size(1)).expand(num_batch, -1)
            frame_pos_embed = self.frame_pos_embeddings(pos_ids_frame)
            det_pos_embed = self.frame_pos_embeddings(pos_ids_det)
            #pos_embeddings = torch.cat([frame_pos_embed, frame_pos_embed[:,1:,:], det_pos_embed], axis=1)
            pos_embeddings = torch.cat([frame_pos_embed, det_pos_embed], axis=1)
        else:
            pos_ids = torch.arange(num_frame, dtype=torch.long, device=enc_feat.device)
            pos_ids_frame = pos_ids.unsqueeze(0).expand(num_batch, -1)
            if det_step:
                #pos_ids1 = pos_ids[0:num_frame:det_step]
                pos_ids1 = pos_ids[0]
                #pos_ids_det = pos_ids1.repeat(topK,1).transpose(1,0).reshape(-1, feat_d.size(1)).expand(num_batch, -1)
                pos_ids_det = pos_ids1.repeat(topK,1).transpose(1,0).reshape(-1, feat_d.size(1)).expand(num_batch, -1)
            else:
                pos_ids_det = pos_ids.repeat(topK,1).transpose(1,0).reshape(-1, feat_d.size(1)).expand(num_batch, -1)
            frame_pos_embed = self.frame_pos_embeddings(pos_ids_frame)
            det_pos_embed = self.frame_pos_embeddings(pos_ids_det)
            pos_embeddings = torch.cat([frame_pos_embed, det_pos_embed], axis=1)

        trans_feat = enc_feat + pos_embeddings

        # calculate attentions
        trans_feat = trans_feat.permute(1, 0, 2)  # NLD -> LND
        trans_feat = self.transformer(trans_feat)
        trans_feat = trans_feat.permute(1, 0, 2)  # LND -> NLD
        trans_feat = trans_feat.type(enc_feat.dtype) + enc_feat

        # classification
        if self.use_CLS:
            action_feat = trans_feat[:,0,:]
            #hand_feat = trans_feat[:,num_frame+1:2*num_frame+1,:]
            #detection_feat = trans_feat[:,2*num_frame+1:,:]
            detection_feat = trans_feat[:,num_frame+1:,:]
            action_out = self.classifier_action(action_feat)
        else:
            action_feat = trans_feat[:,:num_frame,:]
            #hand_feat = trans_feat[:,num_frame:2*num_frame,:]
            detection_feat = trans_feat[:,num_frame:,:]  #feat_b.size(2)]
            action_out = self.classifier_action(action_feat.mean(dim=1, keepdim=False))

        detection_out = self.classifier_detection(detection_feat)
        action_pred = torch.argmax(action_out, dim=1)
        action_out = torch.softmax(action_out, dim=1)

        return action_out, action_pred

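Judging from the linear layers above, TemTRANSModule.forward expects a dict with 2048-dim per-frame RGB features, a 63-dim pose descriptor per hand, 2048-dim detection descriptors, and boxes in 1280x720 pixel coordinates. A smoke-test sketch under those assumptions; the class count and tensor sizes are illustrative, not taken from the repo's data loaders:

# Illustrative smoke test only; dimensions follow the nn.Linear layers above,
# and topK matches the hard-coded value in forward.
model = TemTRANSModule(act_classes=37, hidden=256, depth=4, num_head=8)
num_frames, topK = 8, 5
inputs = {
    "feats": torch.randn(1, num_frames, 2048),                # per-frame ResNeXt features
    "labels": {"l_hand": torch.randn(1, num_frames, 63),      # left-hand pose (H2O format)
               "r_hand": torch.randn(1, num_frames, 63)},     # right-hand pose (H2O format)
    "dets": torch.randn(1, num_frames * topK, 2048),          # detection descriptors
    "bbox": torch.rand(1, num_frames * topK, 4) * 720.,       # boxes in pixel coordinates
}
action_probs, action_pred = model(inputs)
print(action_probs.shape, action_pred.shape)                  # (1, 37) probabilities, (1,) argmax
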
class LabelSmoothingLoss(nn.Module):
    """
    With label smoothing,
    KL-divergence between q_{smoothed ground truth prob.}(w)
    and p_{prob. computed by model}(w) is minimized.
    """

    def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):
        assert 0.0 < label_smoothing <= 1.0
        self.ignore_index = ignore_index
        super(LabelSmoothingLoss, self).__init__()

        self.log_softmax = nn.LogSoftmax(dim=-1)

        smoothing_value = label_smoothing / (tgt_vocab_size - 1)  # count for the ground-truth word
        one_hot = torch.full((tgt_vocab_size,), smoothing_value)
        self.register_buffer("one_hot", one_hot.unsqueeze(0))

        self.confidence = 1.0 - label_smoothing

    def forward(self, output, target):
        """
        output (FloatTensor): batch_size x n_classes
        target (LongTensor): batch_size, with indices in [-1, tgt_vocab_size-1], `-1` is ignored
        """
        valid_indices = target != self.ignore_index  # ignore examples with target value -1
        target = target[valid_indices]
        output = self.log_softmax(output[valid_indices])

        model_prob = self.one_hot.repeat(target.size(0), 1)
        model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
        return F.kl_div(output, model_prob, reduction="sum")
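LabelSmoothingLoss is defined but only referenced in a comment above, as an alternative to the weighted CrossEntropyLoss in TemTRANSModule.__init__. A hedged sketch of how it would be called, following its docstring and the commented-out LabelSmoothingLoss(0.1, 37, ignore_index=-1):

# Illustrative only: rows whose target equals the ignore_index (-1) are dropped.
criterion = LabelSmoothingLoss(0.1, 37, ignore_index=-1)
logits = torch.randn(4, 37)
targets = torch.tensor([3, 7, -1, 12])
loss = criterion(logits, targets)   # summed KL divergence over the 3 valid rows
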
@@ -0,0 +1,66 @@
from collections import OrderedDict
from multiprocessing import Pool
from typing import Dict, Tuple

import torch
from torch import nn
from torchvision.models import resnext50_32x4d  # ,convnext_tiny
import pdb
# from torchvision.models.feature_extraction import create_feature_extractor


class UnifiedFCNModule(nn.Module):
    """Class implements fully convolutional network for extracting spatial
    features from the video frames."""

    def __init__(self, net: str, num_cpts: int, obj_classes: int, verb_classes: int):
        super(UnifiedFCNModule, self).__init__()
        self.num_cpts = num_cpts
        self.obj_classes = obj_classes
        self.verb_classes = verb_classes

        self.output_layers = [8]  # 8 -> Avg. pool layer
        self.selected_out = OrderedDict()
        self.net = self._select_network(net)
        # Freeze network weights
        for param in self.net.parameters():
            param.requires_grad = False

        self.fhooks = []
        # 2048 -> The length of features out of last layer of ResNext
        self.fc1 = nn.Linear(2048, self.obj_classes + self.verb_classes)
        for i, l in enumerate(list(self.net._modules.keys())):
            if i in self.output_layers:
                self.fhooks.append(
                    getattr(self.net, l).register_forward_hook(self.forward_hook(l))
                )

        # loss function
        self.lhand_loss = None
        self.rhand_loss = None
        self.obj_pose_loss = None
        self.conf_loss = None
        self.oclass_loss = nn.CrossEntropyLoss()
        self.vclass_loss = nn.CrossEntropyLoss()

    def forward_hook(self, layer_name):
        def hook(module, input, output):
            self.selected_out[layer_name] = output

        return hook

    def _select_network(self, net_opt: str) -> nn.Module:
        net: nn.Module = None
        if net_opt == "resnext":
            net = resnext50_32x4d(pretrained=True)
        else:
            print("NN model not found. Change the feature extractor network.")

        return net

    def forward(self, data: Dict):
        x = data
        out = self.net(x)
        x = self.selected_out["avgpool"].reshape(-1, self.fc1.in_features)

        return x
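UnifiedFCNModule runs a frozen ResNeXt-50 backbone and returns the 2048-dim average-pooled features captured by the forward hook on the avgpool layer (the fc1 head and the loss members are initialized but unused in this forward pass). A hedged usage sketch; the class counts are placeholders and the frames are assumed to be already resized and normalized for the backbone:

# Illustrative sketch, not part of the commit.
fcn = UnifiedFCNModule(net="resnext", num_cpts=21, obj_classes=42, verb_classes=20)
fcn.eval()
frames = torch.randn(8, 3, 224, 224)       # batch of preprocessed RGB frames
with torch.no_grad():
    feats = fcn(frames)                     # (8, 2048) avg-pooled spatial features
print(feats.shape)
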
