-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
marsggbo
committed
Jun 20, 2023
1 parent
7d7367c
commit 477ec2e
Showing
5 changed files
with
550 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
_target_: hyperbox_app.distributed.networks.proxylessnas.network.ProxylessNAS | ||
width_stages: [24,40,80,96,192,320] | ||
n_cell_stages: [4,4,4,4,4,1] | ||
stride_stages: [2,2,2,1,2,1] | ||
width_mult: 1 | ||
num_classes: 1000 | ||
dropout_rate: 0 | ||
bn_param: [0.1,1e-3] | ||
mask: null |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import torch | ||
import torch.nn as nn | ||
import math | ||
|
||
from copy import deepcopy | ||
from hyperbox_app.distributed.networks.proxylessnas import ops, putils | ||
from hyperbox.mutables.spaces import OperationSpace | ||
from hyperbox.networks.base_nas_network import BaseNASNetwork | ||
|
||
|
||
class ProxylessNAS(BaseNASNetwork): | ||
def __init__(self, | ||
width_stages=[24,40,80,96,192,320], | ||
n_cell_stages=[4,4,4,4,4,1], | ||
stride_stages=[2,2,2,1,2,1], | ||
width_mult=1, num_classes=1000, | ||
dropout_rate=0, bn_param=(0.1, 1e-3), | ||
mask=None): | ||
""" | ||
Parameters | ||
---------- | ||
width_stages: str | ||
width (output channels) of each cell stage in the block | ||
n_cell_stages: str | ||
number of cells in each cell stage | ||
stride_strages: str | ||
stride of each cell stage in the block | ||
width_mult : int | ||
the scale factor of width | ||
""" | ||
super(ProxylessNAS, self).__init__(mask=mask) | ||
|
||
input_channel = putils.make_divisible(32 * width_mult, 8) | ||
first_cell_width = putils.make_divisible(16 * width_mult, 8) | ||
width_stages = deepcopy(width_stages) | ||
for i in range(len(width_stages)): | ||
width_stages[i] = putils.make_divisible(width_stages[i] * width_mult, 8) | ||
# first conv | ||
first_conv = ops.ConvLayer(3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='relu6', ops_order='weight_bn_act') | ||
# first block | ||
first_block_conv = ops.OPS['3x3_MBConv1'](input_channel, first_cell_width, 1) | ||
first_block = first_block_conv | ||
|
||
input_channel = first_cell_width | ||
|
||
blocks = [first_block] | ||
|
||
stage_cnt = 0 | ||
for width, n_cell, s in zip(width_stages, n_cell_stages, stride_stages): | ||
for i in range(n_cell): | ||
if i == 0: | ||
stride = s | ||
else: | ||
stride = 1 | ||
op_candidates = [ops.OPS['3x3_MBConv3'](input_channel, width, stride), | ||
ops.OPS['3x3_MBConv6'](input_channel, width, stride), | ||
ops.OPS['5x5_MBConv3'](input_channel, width, stride), | ||
ops.OPS['5x5_MBConv6'](input_channel, width, stride), | ||
ops.OPS['7x7_MBConv3'](input_channel, width, stride), | ||
ops.OPS['7x7_MBConv6'](input_channel, width, stride)] | ||
if stride == 1 and input_channel == width: | ||
# if it is not the first one | ||
op_candidates += [ops.OPS['Zero'](input_channel, width, stride)] | ||
conv_op = OperationSpace(op_candidates, key="s{}_c{}".format(stage_cnt, i), mask=self.mask, return_mask=True) | ||
# shortcut | ||
if stride == 1 and input_channel == width: | ||
# if not first cell | ||
shortcut = ops.IdentityLayer(input_channel, input_channel) | ||
else: | ||
shortcut = None | ||
inverted_residual_block = ops.MobileInvertedResidualBlock(conv_op, shortcut, op_candidates) | ||
blocks.append(inverted_residual_block) | ||
input_channel = width | ||
stage_cnt += 1 | ||
|
||
# feature mix layer | ||
last_channel = putils.make_divisible(1280 * width_mult, 8) if width_mult > 1.0 else 1280 | ||
feature_mix_layer = ops.ConvLayer(input_channel, last_channel, kernel_size=1, use_bn=True, act_func='relu6', ops_order='weight_bn_act', ) | ||
classifier = ops.LinearLayer(last_channel, num_classes, dropout_rate=dropout_rate) | ||
|
||
self.first_conv = first_conv | ||
self.blocks = nn.ModuleList(blocks) | ||
self.feature_mix_layer = feature_mix_layer | ||
self.global_avg_pooling = nn.AdaptiveAvgPool2d(1) | ||
self.classifier = classifier | ||
|
||
# set bn param | ||
self.set_bn_param(momentum=bn_param[0], eps=bn_param[1]) | ||
|
||
def forward(self, x): | ||
self.features = [] | ||
x = self.first_conv(x) | ||
for block in self.blocks: | ||
x = block(x) | ||
self.features.append(x.detach()) | ||
x = self.feature_mix_layer(x) | ||
self.features.append(x.detach()) | ||
x = self.global_avg_pooling(x) | ||
x = x.view(x.size(0), -1) | ||
x = self.classifier(x) | ||
self.features.append(x.detach()) | ||
return x | ||
|
||
def set_bn_param(self, momentum, eps): | ||
for m in self.modules(): | ||
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): | ||
m.momentum = momentum | ||
m.eps = eps | ||
return | ||
|
||
def init_model(self, model_init='he_fout', init_div_groups=False): | ||
for m in self.modules(): | ||
if isinstance(m, nn.Conv2d): | ||
if model_init == 'he_fout': | ||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | ||
if init_div_groups: | ||
n /= m.groups | ||
m.weight.data.normal_(0, math.sqrt(2. / n)) | ||
elif model_init == 'he_fin': | ||
n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels | ||
if init_div_groups: | ||
n /= m.groups | ||
m.weight.data.normal_(0, math.sqrt(2. / n)) | ||
else: | ||
raise NotImplementedError | ||
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): | ||
m.weight.data.fill_(1) | ||
m.bias.data.zero_() | ||
elif isinstance(m, nn.Linear): | ||
stdv = 1. / math.sqrt(m.weight.size(1)) | ||
m.weight.data.uniform_(-stdv, stdv) | ||
if m.bias is not None: | ||
m.bias.data.zero_() | ||
|
||
if __name__ == '__main__': | ||
from hyperbox.mutator import RandomMutator | ||
net = ProxylessNAS() | ||
rm = RandomMutator(net) | ||
rm.reset() | ||
print(rm._cache, len(rm._cache)) |
Oops, something went wrong.