SCAN.py
import tensorflow as tf
import VAE
import MLP

# TF 1.3 introduced the statistical distribution library tf.distributions;
# fall back to tf.contrib.distributions for earlier TF versions.
try:
    distributions = tf.distributions
    kl_divergence = tf.distributions.kl_divergence
except AttributeError:
    distributions = tf.contrib.distributions
    kl_divergence = tf.contrib.distributions.kl_divergence

class SCAN(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.img_encoder = VAE.Encoder()
        self.img_decoder = VAE.Decoder()
        self.sym_encoder = MLP.Encoder()
        self.sym_decoder = MLP.Decoder()
        # Single session shared by both training stages.
        self.sess = tf.Session()
    def train(self):
        img, sym = self.read_data_sets()

        # Stage 1: train the beta-VAE on images.
        with tf.variable_scope("beta_VAE"):
            img_q_mu, img_q_sigma = self.img_encoder(img)
            img_z = distributions.Normal(img_q_mu, img_q_sigma)
            # One latent sample per element of the batch; the posterior
            # already carries the batch dimension.
            img_gen = self.img_decoder(img_z.sample())
            # Reconstruction term. Assumes the decoder returns logits for
            # pixel values in [0, 1]; mean cross-entropy as a placeholder.
            img_reconstruct_error = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=img, logits=img_gen))
            # Standard normal prior over the visual latents.
            img_z_prior = distributions.Normal(tf.zeros_like(img_q_mu),
                                               tf.ones_like(img_q_sigma))
            KL_divergence = tf.reduce_mean(kl_divergence(img_z, img_z_prior))
            KL_divergence = self.cfg.beta_vae * KL_divergence
            # Negative ELBO: reconstruction error plus beta-weighted KL.
            loss = img_reconstruct_error + KL_divergence
            # train beta-VAE
            optimizer = tf.train.AdamOptimizer(self.cfg.learning_rate)
            train_op = optimizer.minimize(loss)
        self.sess.run(tf.global_variables_initializer())
        for step in range(self.cfg.epoch):
            self.sess.run(train_op)
        beta_vae_vars = set(tf.global_variables())

        # Stage 2: train SCAN on symbols.
        with tf.variable_scope("SCAN"):
            sym_q_mu, sym_q_sigma = self.sym_encoder(sym)
            sym_z = distributions.Normal(sym_q_mu, sym_q_sigma)
            sym_gen = self.sym_decoder(sym_z.sample())
            # Reconstruction term. Assumes the symbol decoder returns logits
            # for k-hot symbol vectors.
            sym_reconstruct_error = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=sym, logits=sym_gen))
            sym_z_prior = distributions.Normal(tf.zeros_like(sym_q_mu),
                                               tf.ones_like(sym_q_sigma))
            beta_KL_divergence = tf.reduce_mean(kl_divergence(sym_z, sym_z_prior))
            beta_KL_divergence = self.cfg.beta_scan * beta_KL_divergence
            # Cross-modal term: KL between the visual and symbol posteriors.
            lambda_KL_divergence = tf.reduce_mean(kl_divergence(img_z, sym_z))
            loss = sym_reconstruct_error + beta_KL_divergence
            loss += self.cfg.lambda_scan * lambda_KL_divergence
            # train SCAN. The pre-trained beta-VAE is usually kept fixed at
            # this stage; pass var_list to minimize() if that is wanted.
            optimizer = tf.train.AdamOptimizer(self.cfg.learning_rate)
            train_op = optimizer.minimize(loss)
        # Initialise only the variables created in this second stage, so the
        # beta-VAE weights trained above are preserved.
        new_vars = [v for v in tf.global_variables() if v not in beta_vae_vars]
        self.sess.run(tf.variables_initializer(new_vars))
        for step in range(self.cfg.epoch):
            self.sess.run(train_op)
    def inference(self):
        pass

    def read_data_sets(self):
        """
        Returns:
            Data queues of image and symbol batches.
        """
        img, sym = [], []
        return img, sym
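
# A minimal usage sketch, assuming `cfg` is a plain namespace carrying the
# hyper-parameters referenced above (learning_rate, epoch, beta_vae,
# beta_scan, lambda_scan, batch_size); the values below are illustrative
# placeholders, not the repository's settings.
if __name__ == "__main__":
    from argparse import Namespace

    cfg = Namespace(
        batch_size=64,      # intended for the (stubbed) data pipeline
        learning_rate=1e-4,
        epoch=1000,
        beta_vae=4.0,       # beta weight on the image beta-VAE KL term
        beta_scan=1.0,      # beta weight on the symbol KL term
        lambda_scan=10.0,   # weight on the cross-modal KL term
    )
    model = SCAN(cfg)
    model.train()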