(z)
| 13 | |
def torch_gradient_generator(fn, **kwargs):
    """Build a NumPy-facing gradient function for a PyTorch-differentiable ``fn``.

    The returned callable converts its input to a tensor, evaluates
    ``fn(tensor, **kwargs)``, sums the result to a scalar, and returns the
    gradient of that sum with respect to the input as a NumPy array of the same
    shape. For an element-wise ``fn`` this equals the element-wise derivative.

    Args:
        fn: Callable taking a tensor (plus ``kwargs``) and returning a tensor
            built from torch operations.
        **kwargs: Extra keyword arguments forwarded to ``fn`` on every call.

    Returns:
        A function ``get_grad(z)`` mapping an array-like ``z`` (typically a
        floating-point ``numpy.ndarray``) to its gradient as a ``numpy.ndarray``.
        If ``fn``'s output does not depend on ``z``, the gradient is all zeros.
    """

    def get_grad(z):
        # as_tensor shares memory with a NumPy input (like from_numpy) but also
        # accepts Python scalars/lists; detach() yields a fresh leaf so that a
        # tensor passed in by the caller is never mutated by requires_grad_().
        z_t = torch.as_tensor(z).detach().requires_grad_(True)
        out = fn(z_t, **kwargs).sum()
        if not out.requires_grad:
            # The output is not connected to the autograd graph (e.g. a
            # constant), so the gradient is identically zero, not an error.
            return torch.zeros_like(z_t).numpy()
        # autograd.grad avoids accumulating into z_t.grad; allow_unused covers
        # outputs that require grad only through tensors other than z_t.
        (grad,) = torch.autograd.grad(out, z_t, allow_unused=True)
        if grad is None:
            grad = torch.zeros_like(z_t)
        return grad.detach().numpy()

    return get_grad
| 23 |