# gan_generator.py
# -*- coding: utf-8 -*-
"""GAN generator.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/15CXWy3aN0qKCyoEIhssxo1jiTRLcXnwL
"""
class Gen(nn.Module):
    """Super-resolution generator network.

    Layout, as wired in ``forward``:

    * a 9x9 head convolution (``conv01``);
    * ``num_block`` residual blocks (``body``) followed by a convolution
      (``conv02``), with a skip connection from the head output around them;
    * upsampling stages (``tail``);
    * an output convolution (``last_conv``) back to ``img_feat`` channels,
      given ``nn.Tanh()`` as its activation.

    NOTE(review): ``nn``, ``conv``, ``ResBlock`` and ``Upsampler`` are not
    imported in the visible source; presumably ``torch.nn`` and the project's
    layer helpers -- confirm they are in scope where this class is defined.

    Args:
        img_feat: Channel count of the input and of the output image.
        n_feats: Channel count of the internal feature maps.
        kernel_size: Kernel size of the residual blocks, ``conv02``, the
            upsampling stages and ``last_conv``. The head convolution is
            always 9x9.
        num_block: Number of residual blocks in ``body``.
        act: Activation module passed to ``conv01``, every ``ResBlock`` and
            every ``Upsampler``. A single instance is shared by all of them.
            When omitted, a fresh ``nn.PReLU()`` is created per ``Gen``
            instance.
        scale: Upsampling factor. ``4`` is built as two ``scale=2``
            ``Upsampler`` stages; any other value is passed to a single
            ``Upsampler`` stage unchanged.
    """

    # Private sentinel for the ``act`` default. A default of ``nn.PReLU()``
    # would be evaluated once at class-definition time, so every Gen built
    # with the default would share one PReLU (and its learnable weight).
    # A sentinel, rather than ``None``, keeps an explicit ``act=None`` meaning
    # exactly what it did before: ``None`` is forwarded to the layer helpers.
    _DEFAULT_ACT = object()

    def __init__(self, img_feat=3, n_feats=64, kernel_size=3, num_block=16,
                 act=_DEFAULT_ACT, scale=4):
        super().__init__()
        if act is Gen._DEFAULT_ACT:
            act = nn.PReLU()

        self.conv01 = conv(in_channel=img_feat, out_channel=n_feats,
                           kernel_size=9, BN=False, act=act)

        resblocks = [ResBlock(channels=n_feats, kernel_size=kernel_size, act=act)
                     for _ in range(num_block)]
        self.body = nn.Sequential(*resblocks)

        # Passed act=None, unlike the other layers (presumably "no activation"
        # in ``conv`` -- confirm), before the skip connection is added.
        self.conv02 = conv(in_channel=n_feats, out_channel=n_feats,
                           kernel_size=kernel_size, BN=False, act=None)

        if scale == 4:
            # x4 is realised as two successive x2 stages.
            upsample_blocks = [Upsampler(channel=n_feats, kernel_size=kernel_size,
                                         scale=2, act=act)
                               for _ in range(2)]
        else:
            upsample_blocks = [Upsampler(channel=n_feats, kernel_size=kernel_size,
                                         scale=scale, act=act)]
        self.tail = nn.Sequential(*upsample_blocks)

        self.last_conv = conv(in_channel=n_feats, out_channel=img_feat,
                              kernel_size=kernel_size, BN=False, act=nn.Tanh())

    def forward(self, x):
        """Run the generator on ``x``.

        Args:
            x: Input with ``img_feat`` channels (presumably an NCHW image
                batch -- confirm against ``conv``).

        Returns:
            A 1-tuple ``(out,)`` holding the output of ``last_conv``.
            NOTE(review): the original ``return x,`` (trailing comma) returns
            a tuple rather than a bare tensor. It is kept so existing callers
            that index or unpack the result keep working. Confirm whether the
            trunk features were also meant to be returned.
        """
        head = self.conv01(x)
        # Skip connection around the residual trunk (body + conv02).
        feat = self.conv02(self.body(head)) + head
        out = self.last_conv(self.tail(feat))
        return (out,)