From 5be40c8325dd680800223bd3160364b59f1a35e8 Mon Sep 17 00:00:00 2001 From: marsggbo <1435679023@qq> Date: Tue, 11 Apr 2023 06:34:23 -0700 Subject: [PATCH] v1.4.4 refactor vit --- hyperbox/networks/vit/vit.py | 41 +++++++++++++++++++++++++----------- setup.py | 2 +- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/hyperbox/networks/vit/vit.py b/hyperbox/networks/vit/vit.py index ae14796..c08bf45 100644 --- a/hyperbox/networks/vit/vit.py +++ b/hyperbox/networks/vit/vit.py @@ -19,6 +19,7 @@ 'ViT', 'ViT_S', 'ViT_B', + 'ViT_L', 'ViT_H', 'ViT_G', 'ViT_10B', @@ -58,10 +59,10 @@ def __init__( hidden_dim_list = keepPositiveList([int(r*hidden_dim) for r in search_ratio]) hidden_dim_list = spaces.ValueSpace(hidden_dim_list, key=f"{suffix}_hidden_dim", mask=self.mask) if len(hidden_dim_list) > 1 else hidden_dim_list[0] self.net = nn.Sequential( - ops.Linear(dim, hidden_dim_list), + ops.Linear(dim, hidden_dim_list, bias=True), nn.GELU(), nn.Dropout(dropout), - ops.Linear(hidden_dim_list, dim), + ops.Linear(hidden_dim_list, dim, bias=True), nn.Dropout(dropout) ) @@ -108,10 +109,10 @@ def __init__( qkv_dim_list = spaces.ValueSpace(qkv_dim_list, key=f"{suffix}_inner_dim", mask=self.mask) # coupled with self.inner_dim_list else: qkv_dim_list = self.inner_dim_list * 3 - self.to_qkv = ops.Linear(dim, qkv_dim_list, bias = False) + self.to_qkv = ops.Linear(dim, qkv_dim_list, bias=True) self.to_out = nn.Sequential( - ops.Linear(self.inner_dim_list, dim), + ops.Linear(self.inner_dim_list, dim, bias=True), nn.Dropout(dropout) ) @@ -177,21 +178,20 @@ def __init__( num_patches = (image_height // patch_height) * (image_width // patch_width) patch_dim = channels * patch_height * patch_width - self.to_patch_embedding = nn.Sequential( - Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), - nn.Linear(patch_dim, dim), - ) - self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) + self.patch_embeddings = nn.Conv2d(channels, dim, kernel_size=patch_size, stride=patch_size) + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, dim)) self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) self.dropout = nn.Dropout(emb_dropout) def forward(self, x): - x = self.to_patch_embedding(x) - b, n, _ = x.shape + x = self.patch_embeddings(x) + x = x.flatten(2) + x = x.transpose(-1, -2) + b, n = x.shape[:2] cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) x = torch.cat((cls_tokens, x), dim=1) - x += self.pos_embedding[:, :(n + 1)] + x += self.position_embeddings[:, :(n + 1)] x = self.dropout(x) return x @@ -344,6 +344,22 @@ def forward(self, x): emb_dropout=0.1, ) +_vit_l = dict( + image_size=224, + patch_size=16, + num_classes=1000, + dim=1024, + depth=24, + heads=16, + dim_head=64, + mlp_dim=4096, + search_ratio=[0.5, 0.75, 1], + pool='cls', + channels=3, + dropout=0.1, + emb_dropout=0.1, +) + _vit_h = dict( image_size=224, patch_size=16, @@ -395,6 +411,7 @@ def forward(self, x): ViT = partial(VisionTransformer, **_vit_b) ViT_S = partial(VisionTransformer, **_vit_s) ViT_B = partial(VisionTransformer, **_vit_b) +ViT_L = partial(VisionTransformer, **_vit_l) ViT_H = partial(VisionTransformer, **_vit_h) ViT_G = partial(VisionTransformer, **_vit_g) ViT_10B = partial(VisionTransformer, **_vit_10b) diff --git a/setup.py b/setup.py index 2511098..e6f6ee4 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ setup( name="hyperbox", # you should change "src" to your project name - version="1.4.3", + version="1.4.4", description="Hyperbox: An easy-to-use NAS framework.", author="marsggbo", url="https://github.com/marsggbo/hyperbox",