(y, z, act_fn)
| 39 | |
| 40 | |
def torch_mse_grad(y, z, act_fn):
    """Return the gradient of a sum-reduced MSE loss w.r.t. the pre-activation ``z``.

    The loss is ``sum((act_fn(z) - y) ** 2)``, i.e. ``F.mse_loss`` with
    ``reduction="sum"``. It is differentiated with autograd.

    Parameters
    ----------
    y : array-like
        Target values. Converted to a float32 tensor.
    z : array-like
        Pre-activation values. Converted to a fresh float32 leaf tensor that
        tracks gradients.
    act_fn : callable
        Maps a tensor to a tensor, e.g. ``torch.sigmoid`` or ``nn.Sigmoid()``.

    Returns
    -------
    numpy.ndarray
        ``d loss / d z`` as a float32 array with the same shape as ``z``.
    """
    y = torch.as_tensor(y, dtype=torch.float32)
    # `torch.autograd.Variable` is deprecated: plain tensors carry autograd
    # state directly, so create z as a leaf that requires grad.
    z = torch.tensor(z, dtype=torch.float32, requires_grad=True)
    y_pred = act_fn(z)
    loss = F.mse_loss(y_pred, y, reduction="sum")
    loss.backward()
    return z.grad.numpy()
| 49 | |
| 50 | |
| 51 | class TorchVAELoss(nn.Module): |