4
4
import torch .utils .data
5
5
from torch import nn , optim
6
6
from torch .autograd import Variable
7
+ from torch .nn import functional as F
7
8
from torchvision import datasets , transforms
8
9
from torchvision .utils import save_image
9
10
@@ -77,18 +78,15 @@ def forward(self, x):
77
78
if args .cuda :
78
79
model .cuda ()
79
80
80
- reconstruction_function = nn .BCELoss ()
81
-
82
81
83
82
def loss_function (recon_x , x , mu , logvar ):
84
- BCE = reconstruction_function (recon_x , x .view (- 1 , 784 ))
83
+ BCE = F . binary_cross_entropy (recon_x , x .view (- 1 , 784 ))
85
84
86
85
# see Appendix B from VAE paper:
87
86
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
88
87
# https://arxiv.org/abs/1312.6114
89
88
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
90
- KLD_element = mu .pow (2 ).add_ (logvar .exp ()).mul_ (- 1 ).add_ (1 ).add_ (logvar )
91
- KLD = torch .sum (KLD_element ).mul_ (- 0.5 )
89
+ KLD = - 0.5 * torch .sum (1 + logvar - mu .pow (2 ) - logvar .exp ())
92
90
# Normalise by same number of elements as in reconstruction
93
91
KLD /= args .batch_size * 784
94
92
@@ -131,8 +129,11 @@ def test(epoch):
131
129
recon_batch , mu , logvar = model (data )
132
130
test_loss += loss_function (recon_batch , data , mu , logvar ).data [0 ]
133
131
if i == 0 :
134
- save_image (recon_batch .data .cpu ().view (args .batch_size , 1 , 28 , 28 ),
135
- 'reconstruction_' + str (epoch ) + '.png' )
132
+ n = min (data .size (0 ), 8 )
133
+ comparison = torch .cat ([data [:n ],
134
+ recon_batch .view (args .batch_size , 1 , 28 , 28 )[:n ]])
135
+ save_image (comparison .data .cpu (),
136
+ 'results/reconstruction_' + str (epoch ) + '.png' , nrow = n )
136
137
137
138
test_loss /= len (test_loader .dataset )
138
139
print ('====> Test set loss: {:.4f}' .format (test_loss ))
@@ -145,4 +146,5 @@ def test(epoch):
145
146
if args .cuda :
146
147
sample = sample .cuda ()
147
148
sample = model .decode (sample ).cpu ()
148
- save_image (sample .data .view (64 , 1 , 28 , 28 ), 'sample_' + str (epoch ) + '.png' )
149
+ save_image (sample .data .view (64 , 1 , 28 , 28 ),
150
+ 'results/sample_' + str (epoch ) + '.png' )
0 commit comments