Skip to content

Commit 9fe431e

Browse files
Kaixhinsoumith
authored andcommitted
Fix VAE loss + improve reconstruction viz
Closes pytorch#225
1 parent 5f24730 commit 9fe431e

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

vae/main.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch.utils.data
55
from torch import nn, optim
66
from torch.autograd import Variable
7+
from torch.nn import functional as F
78
from torchvision import datasets, transforms
89
from torchvision.utils import save_image
910

@@ -77,18 +78,15 @@ def forward(self, x):
7778
if args.cuda:
7879
model.cuda()
7980

80-
reconstruction_function = nn.BCELoss()
81-
8281

8382
def loss_function(recon_x, x, mu, logvar):
84-
BCE = reconstruction_function(recon_x, x.view(-1, 784))
83+
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784))
8584

8685
# see Appendix B from VAE paper:
8786
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
8887
# https://arxiv.org/abs/1312.6114
8988
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
90-
KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
91-
KLD = torch.sum(KLD_element).mul_(-0.5)
89+
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
9290
# Normalise by same number of elements as in reconstruction
9391
KLD /= args.batch_size * 784
9492

@@ -131,8 +129,11 @@ def test(epoch):
131129
recon_batch, mu, logvar = model(data)
132130
test_loss += loss_function(recon_batch, data, mu, logvar).data[0]
133131
if i == 0:
134-
save_image(recon_batch.data.cpu().view(args.batch_size, 1, 28, 28),
135-
'reconstruction_' + str(epoch) + '.png')
132+
n = min(data.size(0), 8)
133+
comparison = torch.cat([data[:n],
134+
recon_batch.view(args.batch_size, 1, 28, 28)[:n]])
135+
save_image(comparison.data.cpu(),
136+
'results/reconstruction_' + str(epoch) + '.png', nrow=n)
136137

137138
test_loss /= len(test_loader.dataset)
138139
print('====> Test set loss: {:.4f}'.format(test_loss))
@@ -145,4 +146,5 @@ def test(epoch):
145146
if args.cuda:
146147
sample = sample.cuda()
147148
sample = model.decode(sample).cpu()
148-
save_image(sample.data.view(64, 1, 28, 28), 'sample_' + str(epoch) + '.png')
149+
save_image(sample.data.view(64, 1, 28, 28),
150+
'results/sample_' + str(epoch) + '.png')
File renamed without changes.

0 commit comments

Comments
 (0)