-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
47 lines (43 loc) · 1.86 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from tensorflow.keras.layers import (Conv2DTranspose, ConvLSTM2D, Conv2D,
TimeDistributed, LayerNormalization)
from tensorflow.keras import Model
class ModelADetector(Model):
    """Spatio-temporal convolutional autoencoder built from Keras layers.

    The network encodes a sequence of frames with time-distributed strided
    convolutions. Three stacked ``ConvLSTM2D`` layers form the bottleneck over
    the time axis. Time-distributed transposed convolutions then decode back
    to frames. Judging by the class name it is presumably trained to
    reconstruct normal frames and used for anomaly detection via
    reconstruction error; this is an assumption to confirm against the
    training code.

    Input is a rank-5 tensor ``(batch, time, height, width, channels)``, the
    layout required by ``TimeDistributed(Conv2D)`` and ``ConvLSTM2D``. The
    output keeps the batch and time axes. It has a single channel whose values
    are sigmoid-activated, i.e. in ``(0, 1)``.

    Every ``Conv2D`` / ``Conv2DTranspose`` uses Keras' default ``'valid'``
    padding, so the output spatial size equals the input size only for
    compatible inputs. For example, a 227x227 frame goes
    227 -> 55 -> 26 in the encoder, stays 26 in the bottleneck (the
    ConvLSTM2D layers use ``'same'`` padding), and goes 26 -> 55 -> 227 in
    the decoder.
    """

    def __init__(self, **kwargs):
        """Build all sub-layers.

        Args:
            **kwargs: Forwarded unchanged to ``tf.keras.Model`` (for example
                ``name``, ``dtype`` or ``trainable``). Calling
                ``ModelADetector()`` with no arguments behaves exactly as
                before.
        """
        super().__init__(**kwargs)

        # NOTE: attribute names (including the ``td_con_0`` spelling) are left
        # as-is on purpose. TensorFlow object-based checkpoints key tracked
        # layers by attribute name, so renaming one would presumably break
        # loading of previously saved weights.

        # Downsampling encoder: spatial stride 4, then stride 2.
        self.td_con_0 = TimeDistributed(Conv2D(128, (11, 11), strides=4))
        self.ln_1 = LayerNormalization()
        self.td_conv_1 = TimeDistributed(Conv2D(64, (5, 5), strides=2))
        self.ln_2 = LayerNormalization()

        # Bottleneck: 64 -> 32 -> 64 channels. 'same' padding keeps the
        # spatial size, and return_sequences=True keeps every time step for
        # the decoder.
        self.conv_lstm_1 = ConvLSTM2D(64, (3, 3), padding='same',
                                      return_sequences=True)
        self.ln_3 = LayerNormalization()
        self.conv_lstm_2 = ConvLSTM2D(32, (3, 3), padding='same',
                                      return_sequences=True)
        self.ln_4 = LayerNormalization()
        self.conv_lstm_3 = ConvLSTM2D(64, (3, 3), padding='same',
                                      return_sequences=True)
        self.ln_5 = LayerNormalization()

        # Upsampling decoder: mirrors the encoder (stride 2, then stride 4).
        # It ends in a single sigmoid channel.
        self.td_convT_1 = TimeDistributed(Conv2DTranspose(128, (5, 5),
                                                          strides=2))
        self.ln_6 = LayerNormalization()
        self.td_convT_2 = TimeDistributed(Conv2DTranspose(1, (11, 11),
                                                          activation="sigmoid",
                                                          strides=4))

    def call(self, inputs):
        """Run the encoder, bottleneck and decoder on ``inputs``.

        Args:
            inputs: Rank-5 tensor ``(batch, time, height, width, channels)``.

        Returns:
            Tensor ``(batch, time, height', width', 1)`` with values in
            ``(0, 1)``. ``height'``/``width'`` equal the input size only for
            sizes compatible with the valid-padding strides (see the class
            docstring).
        """
        # Encoder.
        x = self.td_con_0(inputs)
        x = self.ln_1(x)
        x = self.td_conv_1(x)
        x = self.ln_2(x)
        # Temporal bottleneck.
        x = self.conv_lstm_1(x)
        x = self.ln_3(x)
        x = self.conv_lstm_2(x)
        x = self.ln_4(x)
        x = self.conv_lstm_3(x)
        x = self.ln_5(x)
        # Decoder.
        x = self.td_convT_1(x)
        x = self.ln_6(x)
        x = self.td_convT_2(x)
        return x