import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import constant, normal
import time
import math
from functools import reduce
import torchvision.models as models
import sklearn.utils.class_weight as class_weight
from collections import OrderedDict
#from .vidswin import VideoSwinTransformerBackbone
import pdb


configs = {
    'video_swin_t_p4w7':
        dict(patch_size=(2, 4, 4),
             embed_dim=96,
             depths=[2, 2, 6, 2],
             num_heads=[3, 6, 12, 24],
             window_size=(8, 7, 7),
             mlp_ratio=4.,
             qkv_bias=True,
             qk_scale=None,
             drop_rate=0.,
             attn_drop_rate=0.,
             drop_path_rate=0.2,
             patch_norm=True,
             use_checkpoint=False
             ),
    'video_swin_s_p4w7':
        dict(patch_size=(2, 4, 4),
             embed_dim=96,
             depths=[2, 2, 18, 2],
             num_heads=[3, 6, 12, 24],
             window_size=(8, 7, 7),
             mlp_ratio=4.,
             qkv_bias=True,
             qk_scale=None,
             drop_rate=0.,
             attn_drop_rate=0.,
             drop_path_rate=0.2,
             patch_norm=True,
             use_checkpoint=False
             ),
    'video_swin_b_p4w7':
        dict(patch_size=(2, 4, 4),
             embed_dim=128,
             depths=[2, 2, 18, 2],
             num_heads=[4, 8, 16, 32],
             window_size=(8, 7, 7),
             mlp_ratio=4.,
             qkv_bias=True,
             qk_scale=None,
             drop_rate=0.,
             attn_drop_rate=0.,
             drop_path_rate=0.2,
             patch_norm=True,
             use_checkpoint=False
             )
}

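# The three entries above mirror the standard Video Swin Transformer variants (tiny/small/base):
# same 2x4x4 patch size and 8x7x7 window, differing only in embed_dim, depths and num_heads.
# They are meant to be passed to the (currently commented-out) VideoSwinTransformerBackbone,
# e.g. cfgs = configs['video_swin_t_p4w7'].
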
class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).
        """
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias

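# Note: the LayerNorm above is functionally equivalent to nn.LayerNorm(hidden_size, eps=eps)
# (biased variance, epsilon inside the square root); it is kept as a named class so that
# TemTRANSModule.init_weights can special-case it.
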
class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


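# QuickGELU is the sigmoid approximation of GELU, x * sigmoid(1.702 * x), as used in CLIP;
# it is slightly cheaper than the exact GELU and numerically very close to it.
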
class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class TemporalTransformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)

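# Both ResidualAttentionBlock and TemporalTransformer expect sequence-first input of shape
# (seq_len, batch, width), since nn.MultiheadAttention is built with the default
# batch_first=False; TemTRANSModule.forward permutes NLD -> LND accordingly before calling them.
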
class TemTRANSModule(nn.Module):
    def __init__(self,
                 act_classes,
                 hidden,
                 dropout=0.,
                 depth=4,
                 num_head=8):
        super().__init__()

        self.use_CLS = False
        self.use_BBOX = True
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden*2))
        # 2048 -> dimensionality of the features from the last layer of the ResNeXt backbone
        self.fc_x = nn.Linear(2048, hidden)
        if self.use_BBOX:
            # 2048 -> dimensionality of the region features from Faster R-CNN
            self.fc_d = nn.Linear(2048, hidden*2-12)
            self.fc_b = nn.Linear(4, 12)
        else:
            self.fc_d = nn.Linear(2048, hidden*2)

        # 126 = 63*2 (each hand has a 63-dimensional pose descriptor, compatible with the H2O format)
        self.fc_h = nn.Linear(126, hidden)
        '''layers = []
        n_mlp = 3
        for i in range(n_mlp):
            if i == 0:
                in_dim, out_dim = 126, hidden
            else:
                in_dim, out_dim = hidden, hidden
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.LeakyReLU(0.2, True))
        self.fc_h = nn.Sequential(*layers)'''
        #self.fc_v = nn.Linear(768, hidden*2)

        context_length = 77  # maximum number of token positions for frame_pos_embeddings
        det_classes = 1600
        self.frame_pos_embeddings = nn.Embedding(context_length, hidden*2)
        self.transformer = TemporalTransformer(width=hidden*2, layers=depth, heads=num_head)
        self.classifier_action = nn.Sequential(nn.Dropout(dropout), nn.Linear(hidden*2, act_classes))
        self.classifier_detection = nn.Sequential(nn.Dropout(dropout), nn.Linear(hidden*2, det_classes))
        # Loss functions
        self.loss_detection = nn.CrossEntropyLoss(label_smoothing=0.1)
        class_weight = torch.ones(act_classes)
        class_weight[0] = 2./float(act_classes)  # down-weight class 0 (presumably the background / "no action" label)
        self.loss_action = nn.CrossEntropyLoss(weight=class_weight, label_smoothing=0.1)  #LabelSmoothingLoss(0.1, 37, ignore_index=-1)
        self.apply(self.init_weights)

        # video swin transformer
        #cfgs = configs['video_swin_t_p4w7']
        #self.extr_3d = VideoSwinTransformerBackbone(True, 'checkpoints/swin_tiny_patch244_window877_kinetics400_1k.pth', False, **cfgs)

    def init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, LayerNorm):
            if 'beta' in dir(module) and 'gamma' in dir(module):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

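    # The transformer input sequence is laid out as [frame tokens | detection tokens]
    # (with an optional CLS token prepended when use_CLS is True): the action head pools
    # the frame tokens, while the detection head scores each detection token.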
    def forward(self, inputs):
        topK = 5
        # load features
        feat_x = inputs["feats"]  # RGB features
        lh, rh = inputs["labels"]["l_hand"], inputs["labels"]["r_hand"]  # hand poses
        feat_d = inputs["dets"]  # detection features
        feat_b = inputs["bbox"]  # bounding boxes
        #det_labels = inputs[0]["dcls"]  # detection labels
        #fr_idx = inputs[0]["idx"]  # frame indices
        #labels = inputs["act"]  # action labels
        #idx = inputs["idx"]  # video clip index

        # select detections
        #det_step = np.random.randint(10) if if_train else 5
        det_step = feat_d.size(1)
        #det_step = feat_x.size(1)

        #print(det_step, feat_d.size())
        if det_step:
            num_batch, num_det, num_dim = feat_d.shape
            feat_d = [feat_d[:, k:k+topK, :] for k in range(0, num_det, det_step*topK)]
            feat_d = torch.stack(feat_d, 1).reshape(num_batch, -1, num_dim)
            feat_b = [feat_b[:, k:k+topK, :] for k in range(0, num_det, det_step*topK)]
            feat_b = torch.stack(feat_b, 1).reshape(num_batch, -1, 4)
            #det_labels = [det_labels[:,k:k+topK] for k in range(0, num_det, det_step*topK)]
            #det_labels = torch.stack(det_labels,1).reshape(num_batch, -1)

        feat_x = self.fc_x(feat_x)
        feat_h = self.fc_h(torch.cat([lh, rh], axis=-1).float())
        feat_x = torch.cat([feat_x, feat_h], axis=-1)
        num_batch, num_frame, num_dim = feat_x.shape

        '''
        print(f"num batch {num_batch}")
        print(f"num frame {num_frame}")
        print(f"num dim {num_dim}")
        '''

        if feat_d.ndim == 2:
            feat_d = feat_d.unsqueeze(1)
            feat_b = feat_b.unsqueeze(1)
        feat_d = self.fc_d(feat_d)

        if self.use_BBOX:
            # normalize box coordinates by the frame resolution (1280x720)
            feat_b[:, :, [0, 2]] = feat_b[:, :, [0, 2]]/1280.
            feat_b[:, :, [1, 3]] = feat_b[:, :, [1, 3]]/720.
            feat_b = self.fc_b(feat_b.float())
            feat_d = torch.cat([feat_d, feat_b], axis=-1)
        #print(feat_x.shape)
        if self.use_CLS:
            cls_tokens = self.cls_token.expand(num_batch, -1, -1)
            enc_feat = torch.cat([cls_tokens, feat_x, feat_d], axis=1)
        else:
            enc_feat = torch.cat([feat_x, feat_d], axis=1)
        enc_feat = enc_feat.contiguous()

        # add positional encoding
        if self.use_CLS:
            pos_ids = torch.arange(num_frame+1, dtype=torch.long, device=enc_feat.device)
            pos_ids_frame = pos_ids.unsqueeze(0).expand(num_batch, -1)
            pos_ids_det = pos_ids[1:].repeat(topK, 1).transpose(1, 0).reshape(-1, feat_d.size(1)).expand(num_batch, -1)
            frame_pos_embed = self.frame_pos_embeddings(pos_ids_frame)
            det_pos_embed = self.frame_pos_embeddings(pos_ids_det)
            #pos_embeddings = torch.cat([frame_pos_embed, frame_pos_embed[:,1:,:], det_pos_embed], axis=1)
            pos_embeddings = torch.cat([frame_pos_embed, det_pos_embed], axis=1)
        else:
            pos_ids = torch.arange(num_frame, dtype=torch.long, device=enc_feat.device)
            pos_ids_frame = pos_ids.unsqueeze(0).expand(num_batch, -1)
            if det_step:
                #pos_ids1 = pos_ids[0:num_frame:det_step]
                pos_ids1 = pos_ids[0]
                pos_ids_det = pos_ids1.repeat(topK, 1).transpose(1, 0).reshape(-1, feat_d.size(1)).expand(num_batch, -1)
            else:
                pos_ids_det = pos_ids.repeat(topK, 1).transpose(1, 0).reshape(-1, feat_d.size(1)).expand(num_batch, -1)
            frame_pos_embed = self.frame_pos_embeddings(pos_ids_frame)
            det_pos_embed = self.frame_pos_embeddings(pos_ids_det)
            pos_embeddings = torch.cat([frame_pos_embed, det_pos_embed], axis=1)

        trans_feat = enc_feat + pos_embeddings

        # calculate attentions
        trans_feat = trans_feat.permute(1, 0, 2)  # NLD -> LND
        trans_feat = self.transformer(trans_feat)
        trans_feat = trans_feat.permute(1, 0, 2)  # LND -> NLD
        trans_feat = trans_feat.type(enc_feat.dtype) + enc_feat

        # classification
        if self.use_CLS:
            action_feat = trans_feat[:, 0, :]
            #hand_feat = trans_feat[:,num_frame+1:2*num_frame+1,:]
            #detection_feat = trans_feat[:,2*num_frame+1:,:]
            detection_feat = trans_feat[:, num_frame+1:, :]
            action_out = self.classifier_action(action_feat)
        else:
            action_feat = trans_feat[:, :num_frame, :]
            #hand_feat = trans_feat[:,num_frame:2*num_frame,:]
            detection_feat = trans_feat[:, num_frame:, :]
            action_out = self.classifier_action(action_feat.mean(dim=1, keepdim=False))

        detection_out = self.classifier_detection(detection_feat)  # per-detection logits (not returned here)
        action_pred = torch.argmax(action_out, dim=1)  # argmax on logits; softmax below only rescales
        action_out = torch.softmax(action_out, dim=1)

        return action_out, action_pred

class LabelSmoothingLoss(nn.Module):
    """
    With label smoothing,
    KL-divergence between q_{smoothed ground truth prob.}(w)
    and p_{prob. computed by model}(w) is minimized.
    """

    def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):
        assert 0.0 < label_smoothing <= 1.0
        super(LabelSmoothingLoss, self).__init__()
        self.ignore_index = ignore_index

        self.log_softmax = nn.LogSoftmax(dim=-1)

        smoothing_value = label_smoothing / (tgt_vocab_size - 1)  # smoothing mass spread over all classes except the ground truth
        one_hot = torch.full((tgt_vocab_size,), smoothing_value)
        self.register_buffer("one_hot", one_hot.unsqueeze(0))

        self.confidence = 1.0 - label_smoothing

    def forward(self, output, target):
        """
        output (FloatTensor): batch_size x n_classes
        target (LongTensor): batch_size, with indices in [0, tgt_vocab_size-1]; entries equal to ignore_index are ignored
        """
        valid_indices = target != self.ignore_index  # drop examples whose target equals ignore_index
        target = target[valid_indices]
        output = self.log_softmax(output[valid_indices])

        model_prob = self.one_hot.repeat(target.size(0), 1)
        model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
        return F.kl_div(output, model_prob, reduction="sum")
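

# --------------------------------------------------------------------------------------
# Minimal usage sketch (not part of the original training pipeline). It only illustrates
# the input dictionary TemTRANSModule.forward expects, as inferred from the code above;
# the batch size, clip length, number of detections and hidden size are assumed values.
# --------------------------------------------------------------------------------------
if __name__ == "__main__":
    B, T, K = 2, 8, 5  # batch size, frames per clip, top-K detections per clip (assumptions)
    dummy_inputs = {
        "feats": torch.randn(B, T, 2048),              # per-frame RGB features (e.g. ResNeXt)
        "labels": {
            "l_hand": torch.randn(B, T, 63),           # left-hand pose, 63-dim (H2O format)
            "r_hand": torch.randn(B, T, 63),           # right-hand pose, 63-dim (H2O format)
        },
        "dets": torch.randn(B, K, 2048),               # detection features (e.g. Faster R-CNN)
        "bbox": torch.rand(B, K, 4) * torch.tensor([1280., 720., 1280., 720.]),  # pixel boxes (values only for shape testing)
    }
    model = TemTRANSModule(act_classes=37, hidden=256, depth=4, num_head=8)
    action_probs, action_pred = model(dummy_inputs)
    print(action_probs.shape, action_pred.shape)       # (B, act_classes), (B,)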