Skip to content

Commit 5f24730

Browse files
Kaixhinsoumith
authored andcommitted
Add support for CUDA
1 parent ab7cb38 commit 5f24730

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

vae/main.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def test(epoch):
131131
recon_batch, mu, logvar = model(data)
132132
test_loss += loss_function(recon_batch, data, mu, logvar).data[0]
133133
if i == 0:
134-
save_image(recon_batch.data.view(args.batch_size, 1, 28, 28),
134+
save_image(recon_batch.data.cpu().view(args.batch_size, 1, 28, 28),
135135
'reconstruction_' + str(epoch) + '.png')
136136

137137
test_loss /= len(test_loader.dataset)
@@ -141,5 +141,8 @@ def test(epoch):
141141
for epoch in range(1, args.epochs + 1):
142142
train(epoch)
143143
test(epoch)
144-
sample = model.decode(Variable(torch.randn(64, 20)))
144+
sample = Variable(torch.randn(64, 20))
145+
if args.cuda:
146+
sample = sample.cuda()
147+
sample = model.decode(sample).cpu()
145148
save_image(sample.data.view(64, 1, 28, 28), 'sample_' + str(epoch) + '.png')

0 commit comments

Comments
 (0)