MCPcopy
hub / github.com/ddbourgin/numpy-ml / extract_grads

Method extract_grads

numpy_ml/tests/nn_torch_models.py:55–79  ·  view source on GitHub ↗
(self, X, X_recon, t_mean, t_log_var)

Source from the content-addressed store, hash-verified

53 super(TorchVAELoss, self).__init__()
54
def extract_grads(self, X, X_recon, t_mean, t_log_var):
    """
    Evaluate the VAE loss with PyTorch autograd and return it with its gradients.

    The loss is the mean (over dim 0) of a per-row reconstruction term plus a
    per-row KL term:

    * reconstruction: elementwise binary cross-entropy between ``X_recon`` and
      ``X``, summed over dim 1;
    * KL: the closed form from Appendix B of Kingma & Welling, computed from
      ``t_mean`` and ``t_log_var`` and summed over dim 1.

    Sums are over dim 1 and the final mean is over dim 0. Presumably the
    inputs are 2-D (rows = samples) — TODO confirm against callers.

    Parameters
    ----------
    X : array_like
        Reconstruction targets. Converted with ``requires_grad=False``, so no
        gradient is returned for it.
    X_recon : array_like
        Reconstructions. Clipped to ``[eps, 1 - eps]`` in NumPy *before*
        conversion, so ``dX_recon`` is the gradient w.r.t. the clipped values.
    t_mean : array_like
        Mean term fed to the KL expression.
    t_log_var : array_like
        Log-variance term fed to the KL expression.

    Returns
    -------
    dict
        Keys ``"loss"``, ``"dX_recon"``, ``"dt_mean"`` and ``"dt_log_var"``,
        each holding a NumPy array (``"loss"`` is 0-d).
    """
    tiny = np.finfo(float).eps

    # The clip runs in NumPy, outside the autograd graph.
    # NOTE(review): `tiny` is float64 machine epsilon. If torchify casts to
    # float32, 1 - tiny rounds to exactly 1.0 — confirm torchify's dtype.
    target = torchify(X, requires_grad=False)
    recon = torchify(np.clip(X_recon, tiny, 1 - tiny))
    # The `.grad` reads below rely on torchify's default producing leaf
    # tensors that track gradients (presumably requires_grad=True).
    mu = torchify(t_mean)
    log_var = torchify(t_log_var)

    # Reconstruction term: per-element BCE, then summed across dim 1.
    bce_terms = F.binary_cross_entropy(recon, target, reduction="none")
    recon_loss = bce_terms.sum(dim=1)

    # KL term, Appendix B of Kingma & Welling, "Auto-Encoding Variational
    # Bayes", ICLR 2014 (https://arxiv.org/abs/1312.6114). The paper states
    #   -KL = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2),
    # so the KL itself is -0.5 times that sum.
    kl_inner = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_loss = -0.5 * kl_inner.sum(dim=1)

    total = (recon_loss + kl_loss).mean()
    total.backward()

    return {
        "loss": total.detach().numpy(),
        "dX_recon": recon.grad.numpy(),
        "dt_mean": mu.grad.numpy(),
        "dt_log_var": log_var.grad.numpy(),
    }
80
81
82class TorchWGANGPLoss(nn.Module):

Callers 2

test_VAE_loss (function) · 0.45
test_WGAN_GP_loss (function) · 0.45

Calls 2

torchify (function) · 0.85
backward (method) · 0.45

Tested by 2

test_VAE_loss (function) · 0.36
test_WGAN_GP_loss (function) · 0.36