# layers.py
import theano
import theano.tensor as T
import lasagne
from lasagne.layers import InputLayer
from lasagne.layers import Layer
from lasagne.layers import NonlinearityLayer
from lasagne.layers import ElemwiseSumLayer
from lasagne.layers import ExpressionLayer
from lasagne.layers import BatchNormLayer
from lasagne.layers import batch_norm
from lasagne.layers import Pool2DLayer as PoolLayer
from lasagne.layers import Conv2DLayer as ConvLayer
from lasagne.init import Normal, Constant
from lasagne.nonlinearities import linear, rectify, sigmoid, identity
import numpy as np

class ReflectLayer(lasagne.layers.Layer):
"""
Layer for reflect padding. Based on code from https://gist.github.com/ajbrock/a3858c26282d9731191901b397b3ce9f
"""
def __init__(self, incoming, width, batch_ndim=2, **kwargs):
super(ReflectLayer, self).__init__(incoming, **kwargs)
self.width = width
self.batch_ndim = batch_ndim
def get_output_shape_for(self, input_shape):
output_shape = list(input_shape)
if isinstance(self.width, int):
widths = [self.width] * (len(input_shape) - self.batch_ndim)
else:
widths = self.width
for k, w in enumerate(widths):
if output_shape[k + self.batch_ndim] is None:
continue
else:
try:
l, r = w
except TypeError:
l = r = w
output_shape[k + self.batch_ndim] += l + r
return tuple(output_shape)
def get_output_for(self, input, **kwargs):
return self.reflect_pad(input, self.width, self.batch_ndim)
def reflect_pad(self, x, width, batch_ndim=1):
"""
Pad a tensor with a constant value.
Parameters
----------
x : tensor
width : int, iterable of int, or iterable of tuple
Padding width. If an int, pads each axis symmetrically with the same
amount in the beginning and end. If an iterable of int, defines the
symmetric padding width separately for each axis. If an iterable of
tuples of two ints, defines a seperate padding width for each beginning
and end of each axis.
batch_ndim : integer
Dimensions before the value will not be padded.
"""
        # Approach: allocate the padded output, copy x into the centre, then fill
        # each border with reversed slices (reflections) of its neighbouring rows
        # and columns.
input_shape = x.shape
input_ndim = x.ndim
output_shape = list(input_shape)
indices = [slice(None) for _ in output_shape]
if isinstance(width, int):
widths = [width] * (input_ndim - batch_ndim)
else:
widths = width
for k, w in enumerate(widths):
try:
l, r = w
except TypeError:
l = r = w
output_shape[k + batch_ndim] += l + r
indices[k + batch_ndim] = slice(l, l + input_shape[k + batch_ndim])
        # Create the output array. The reflection code below assumes a 4D tensor
        # (batch, channels, rows, cols) and a single integer `width`.
        out = T.zeros(output_shape)
        # Vertical reflections (top and bottom borders)
        out = T.set_subtensor(out[:, :, :width, width:-width], x[:, :, width:0:-1, :])
        out = T.set_subtensor(out[:, :, -width:, width:-width], x[:, :, -2:-(2 + width):-1, :])
        # Place x in the centre of the padded output
        out = T.set_subtensor(out[:, :, width:-width, width:-width], x)
        # Horizontal reflections (left and right borders), sliced from the
        # vertically padded output so the corners are filled as well
        out = T.set_subtensor(out[:, :, :, :width], out[:, :, :, (2 * width):width:-1])
        out = T.set_subtensor(out[:, :, :, -width:], out[:, :, :, -(width + 2):-(2 * width + 2):-1])
        return out
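
# Illustrative usage sketch (not part of the original file): the input shape and
# padding width below are assumptions chosen for the example.
def _example_reflect_pad():
    from lasagne.layers import get_output, get_output_shape
    # A 4D image batch: (batch, channels, rows, cols)
    l_in = InputLayer((None, 3, 32, 32))
    # Reflect-pad each spatial border by 1 pixel: 32 -> 34 in both dimensions
    l_pad = ReflectLayer(l_in, width=1)
    assert get_output_shape(l_pad) == (None, 3, 34, 34)
    return get_output(l_pad)
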
class InstanceNormLayer(Layer):
"""
An implementation of Instance Normalization, based on lasagne's BatchNormLayer.
Note that unlike the fns implementation, which uses 1e-5 as epsilon (the torch default),
this implementation uses 1e-4 as epsilon (the lasagne default).
"""
def __init__(self, incoming, num_styles=None, epsilon=1e-4,
beta=Constant(0), gamma=Constant(1), **kwargs):
super(InstanceNormLayer, self).__init__(incoming, **kwargs)
self.axes = (2, 3)
self.epsilon = epsilon
        if num_styles is None:
shape = (self.input_shape[1],)
else:
shape = (num_styles, self.input_shape[1])
if beta is None:
self.beta = None
else:
self.beta = self.add_param(beta, shape, 'beta',
trainable=True, regularizable=False)
if gamma is None:
self.gamma = None
else:
self.gamma = self.add_param(gamma, shape, 'gamma',
trainable=True, regularizable=True)
def get_output_for(self, input, style=None, **kwargs):
mean = input.mean(self.axes)
inv_std = T.inv(T.sqrt(input.var(self.axes) + self.epsilon))
pattern = [0, 1, 'x', 'x']
        if style is None:
pattern_params = ['x', 0, 'x', 'x']
beta = 0 if self.beta is None else self.beta.dimshuffle(pattern_params)
gamma = 1 if self.gamma is None else self.gamma.dimshuffle(pattern_params)
else:
pattern_params = pattern
beta = 0 if self.beta is None else self.beta[style].dimshuffle(pattern_params)
gamma = 1 if self.gamma is None else self.gamma[style].dimshuffle(pattern_params)
# if self.beta is not None:
# beta = ifelse(T.eq(style.shape[0], 1), T.addbroadcast(beta, 0), beta)
# if self.gamma is not None:
# gamma = ifelse(T.eq(style.shape[0], 1), T.addbroadcast(gamma, 0), gamma)
mean = mean.dimshuffle(pattern)
inv_std = inv_std.dimshuffle(pattern)
# normalize
normalized = (input - mean) * (gamma * inv_std) + beta
return normalized
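
# Illustrative usage sketch (not part of the original file): with num_styles set,
# InstanceNormLayer holds one (gamma, beta) pair per style, and the style to use
# is selected at run time by passing `style=...` to lasagne.layers.get_output,
# which forwards extra keyword arguments to each layer's get_output_for.
def _example_conditional_instance_norm():
    from lasagne.layers import get_output
    l_in = InputLayer((None, 16, 64, 64))
    l_norm = InstanceNormLayer(l_in, num_styles=10)  # 10 sets of per-channel scale/shift
    style = T.ivector('style')                       # style index per example in the batch
    return get_output(l_norm, style=style)
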
def instance_norm(layer, **kwargs):
"""
The equivalent of Lasagne's `batch_norm()` convenience method, but for instance normalization.
Refer: http://lasagne.readthedocs.io/en/latest/modules/layers/normalization.html#lasagne.layers.batch_norm
"""
nonlinearity = getattr(layer, 'nonlinearity', None)
if nonlinearity is not None:
layer.nonlinearity = identity
if hasattr(layer, 'b') and layer.b is not None:
del layer.params[layer.b]
layer.b = None
bn_name = (kwargs.pop('name', None) or
(getattr(layer, 'name', None) and layer.name + '_bn'))
layer = InstanceNormLayer(layer, name=bn_name, **kwargs)
if nonlinearity is not None:
nonlin_name = bn_name and bn_name + '_nonlin'
layer = NonlinearityLayer(layer, nonlinearity, name=nonlin_name)
return layer
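
# Illustrative usage sketch (not part of the original file): instance_norm() wraps
# a layer the way lasagne.layers.batch_norm() does -- the wrapped layer's bias is
# removed, its nonlinearity is applied after normalization, and extra keyword
# arguments (such as num_styles) are forwarded to InstanceNormLayer.
def _example_instance_norm_wrapper():
    l_in = InputLayer((None, 3, 32, 32))
    l_conv = instance_norm(ConvLayer(l_in, num_filters=32, filter_size=3,
                                     nonlinearity=rectify),
                           num_styles=10)
    return l_conv
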
# TODO: Add normalization
def style_conv_block(conv_in, num_styles, num_filters, filter_size, stride, nonlinearity=rectify, normalization=instance_norm):
sc_network = ReflectLayer(conv_in, filter_size//2)
sc_network = normalization(ConvLayer(sc_network, num_filters, filter_size, stride, nonlinearity=nonlinearity, W=Normal()), num_styles=num_styles)
return sc_network

def residual_block(resnet_in, num_styles=None, num_filters=None, filter_size=3, stride=1):
    if num_filters is None:
        num_filters = resnet_in.output_shape[1]
conv1 = style_conv_block(resnet_in, num_styles, num_filters, filter_size, stride)
conv2 = style_conv_block(conv1, num_styles, num_filters, filter_size, stride, linear)
res_block = ElemwiseSumLayer([conv2, resnet_in])
return res_block

def nn_upsample(upsample_in, num_styles=None, num_filters=None, filter_size=3, stride=1):
    if num_filters is None:
        num_filters = upsample_in.output_shape[1]
nn_network = ExpressionLayer(upsample_in, lambda X: X.repeat(2, 2).repeat(2, 3), output_shape='auto')
nn_network = style_conv_block(nn_network, num_styles, num_filters, filter_size, stride)
return nn_network
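
# Illustrative sketch (not part of the original file): how the blocks above are
# typically composed into a small image-transformation network (strided
# convolutions, residual blocks, nearest-neighbour upsampling). The filter counts
# and sizes here are assumptions chosen for brevity, not the repo's actual model.
def _example_transform_net(num_styles=10):
    net = InputLayer((None, 3, 256, 256))
    net = style_conv_block(net, num_styles, num_filters=32, filter_size=9, stride=1)
    net = style_conv_block(net, num_styles, num_filters=64, filter_size=3, stride=2)
    for _ in range(3):
        net = residual_block(net, num_styles)
    net = nn_upsample(net, num_styles)
    net = style_conv_block(net, num_styles, num_filters=3, filter_size=9, stride=1,
                           nonlinearity=sigmoid)
    return net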