(recon_x, x, mu, logvar)
| 78 | |
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss: reconstruction BCE plus KL divergence, summed.

    Both terms are summed over every element and over the batch
    (``reduction='sum'``), so the result is a scalar tensor that is not
    averaged per sample.

    Args:
        recon_x: Decoder output. Passed directly to
            ``F.binary_cross_entropy``, so its values must lie in [0, 1].
        x: Target batch. It is reshaped to ``recon_x``'s shape, so it must
            contain the same number of elements (e.g. a ``(B, 1, 28, 28)``
            image batch against a ``(B, 784)`` reconstruction).
        mu: Mean of the approximate posterior q(z|x).
        logvar: Log-variance, log(sigma^2), of q(z|x); combined
            element-wise with ``mu``, so it is expected to match its shape.

    Returns:
        Scalar tensor ``BCE + KLD``.
    """
    # Match the target to the reconstruction instead of hard-coding 784
    # (MNIST's 28*28), so other input sizes work too. ``reshape`` also
    # accepts non-contiguous targets where ``view`` would raise.
    target = x.reshape(recon_x.shape)
    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')

    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian; see
    # Appendix B of Kingma and Welling, "Auto-Encoding Variational Bayes",
    # ICLR 2014 (https://arxiv.org/abs/1312.6114):
    #   -KL = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    # so KL itself carries the leading -0.5 below.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD
| 90 | |
| 91 | |
| 92 | def train(epoch): |