| 53 | super(TorchVAELoss, self).__init__() |
| 54 | |
def extract_grads(self, X, X_recon, t_mean, t_log_var):
    """Evaluate the VAE loss with PyTorch autograd and collect its gradients.

    The loss is the batch mean of a per-row reconstruction term (binary
    cross-entropy summed over dim 1) plus a per-row KL term (also summed
    over dim 1). It is backpropagated once, and the resulting gradients
    are returned as NumPy arrays.

    Parameters
    ----------
    X : array-like
        Reconstruction targets. They are wrapped with ``requires_grad=False``,
        so no gradient is collected for them. Presumably shaped
        (n_ex, n_features), since reductions run over dim 1; TODO confirm.
    X_recon : array-like
        Reconstructions. They are clipped to ``[eps, 1 - eps]`` in NumPy
        *before* being wrapped, so the returned gradient is taken with
        respect to the clipped values.
    t_mean : array-like
        Mean of the approximate posterior.
    t_log_var : array-like
        Log-variance of the approximate posterior.

    Returns
    -------
    dict
        ``"loss"`` holds the scalar loss as a NumPy array. ``"dX_recon"``,
        ``"dt_mean"`` and ``"dt_log_var"`` hold the gradients of the loss
        with respect to the corresponding wrapped inputs.
    """
    # NOTE(review): np.finfo(float).eps is float64 machine epsilon. If
    # torchify converts to float32, 1 - eps rounds to 1.0 and the upper
    # clip has no effect. Confirm against torchify's dtype.
    tiny = np.finfo(float).eps

    # Wrap inputs in the same order as before. The gradient-carrying inputs
    # rely on torchify's default requires_grad (assumed True; TODO confirm),
    # since their .grad is read below.
    target = torchify(X, requires_grad=False)
    recon = torchify(np.clip(X_recon, tiny, 1 - tiny))
    mu = torchify(t_mean)
    log_var = torchify(t_log_var)

    # Reconstruction term: element-wise BCE, summed over dim 1.
    elementwise_bce = F.binary_cross_entropy(recon, target, reduction="none")
    recon_term = torch.sum(elementwise_bce, dim=1)

    # KL term, from Appendix B of Kingma & Welling, "Auto-Encoding
    # Variational Bayes" (ICLR 2014, https://arxiv.org/abs/1312.6114):
    #   -KL = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    # so KL is the negation of that sum.
    kl_term = -0.5 * torch.sum(1 + log_var - mu.pow(2) - torch.exp(log_var), dim=1)

    # Average the per-row totals into one scalar, then populate .grad.
    total = torch.mean(recon_term + kl_term)
    total.backward()

    return {
        "loss": total.detach().numpy(),
        "dX_recon": recon.grad.numpy(),
        "dt_mean": mu.grad.numpy(),
        "dt_log_var": log_var.grad.numpy(),
    }
| 80 | |
| 81 | |
| 82 | class TorchWGANGPLoss(nn.Module): |