Skip to content

Commit

Permalink
add test for BNNet
Browse files Browse the repository at this point in the history
  • Loading branch information
marsggbo committed Nov 2, 2021
1 parent 9df076e commit 6196461
Showing 1 changed file with 48 additions and 27 deletions.
75 changes: 48 additions & 27 deletions hyperbox/networks/bnnas/bn_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,33 +122,54 @@ def freeze_all_params(self):
params.requires_grad = False

if __name__ == '__main__':
device='cuda'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
from hyperbox.mutator import RandomMutator
net = BNNet(search_depth=False, is_only_train_bn=False, num_classes=10).to(device)
opt = torch.optim.SGD(net.parameters(), lr=0.01)
H = torch.nn.CrossEntropyLoss()
print(f"Supernet size: {net.arch_size((2,3,64,64), 1, 1)}")
from pytorch_lightning.utilities.seed import seed_everything
net = BNNet(search_depth=False, is_only_train_bn=False, num_classes=10,
channels_list=[32],
num_blocks=[2],
strides_list=[2],
).to(device)
rm = RandomMutator(net)
for i in range(30):
x = torch.rand(64,3,64,64).to(device)
y = torch.randint(0,10,(64,)).to(device)
opt.zero_grad()

num = 10
for i in range(5):
seed_everything(i)
rm.reset()
x = torch.randn(num,3,64,64).to(device)
preds = net(x)
print(preds.argmax(-1))
net = net.eval()
for i in range(5):
# rm.reset()
print(f"Subnet size: {net.arch_size((2,3,64,64), 1, 1)}")
print(net.bn_metrics())
pred = net(x)
loss = H(pred,y)
loss.backward()
conv = net.features['layer0'][0].candidates[-1].conv[0]
bn = net.features['layer0'][0].candidates[-1].conv[7]
linear = net.classifier[0]
print('conv', conv.weight[0,:5,...].view(-1).detach(), conv.weight.requires_grad)
print('bn', bn.weight[:5], bn.weight.requires_grad)
print('linear', linear.weight[:5,0].view(-1).detach(), linear.weight.requires_grad)
print('loss', loss.item())
opt.step()
if 6>i>3:
net.freeze_except_bn()
elif i > 6:
net.defrost_all_params()
pass
x = torch.randn(num,3,64,64).to(device)
preds = net(x)
print(preds.argmax(-1))

# opt = torch.optim.SGD(net.parameters(), lr=0.01)
# H = torch.nn.CrossEntropyLoss()
# print(f"Supernet size: {net.arch_size((2,3,64,64), 1, 1)}")
# rm = RandomMutator(net)
# for i in range(30):
# x = torch.rand(64,3,64,64).to(device)
# y = torch.randint(0,10,(64,)).to(device)
# opt.zero_grad()
# # rm.reset()
# print(f"Subnet size: {net.arch_size((2,3,64,64), 1, 1)}")
# print(net.bn_metrics())
# pred = net(x)
# loss = H(pred,y)
# loss.backward()
# conv = net.features['layer0'][0].candidates[-1].conv[0]
# bn = net.features['layer0'][0].candidates[-1].conv[7]
# linear = net.classifier[0]
# print('conv', conv.weight[0,:5,...].view(-1).detach(), conv.weight.requires_grad)
# print('bn', bn.weight[:5], bn.weight.requires_grad)
# print('linear', linear.weight[:5,0].view(-1).detach(), linear.weight.requires_grad)
# print('loss', loss.item())
# opt.step()
# if 6>i>3:
# net.freeze_except_bn()
# elif i > 6:
# net.defrost_all_params()
# pass

0 comments on commit 6196461

Please sign in to comment.