(z)
| 20 | |
def torch_gradient_generator(fn, **kwargs):
    """Build a NumPy-facing gradient function for a torch-differentiable ``fn``.

    The returned callable:
    - converts its input to a float32 tensor;
    - evaluates ``fn(tensor, **kwargs)`` and sums the result to a scalar;
    - backpropagates, then returns d(sum)/d(input) as a NumPy array.

    Because the output is summed first, an element-wise ``fn`` yields its
    element-wise derivative.

    Args:
        fn: Callable taking a tensor (plus ``kwargs``) and returning a tensor
            computed with torch operations.
        **kwargs: Extra keyword arguments forwarded to ``fn`` on every call.

    Returns:
        A function ``get_grad(z)``. It maps an array-like ``z`` (Python scalar,
        list, NumPy array, or tensor) to a float32 ``numpy.ndarray`` with the
        same shape as ``z``.
    """

    def get_grad(z):
        # Build a fresh float32 leaf tensor that tracks gradients.
        # Unlike the legacy torch.FloatTensor(z), as_tensor treats a Python
        # scalar as a value (0-d tensor), not as a size. No deprecated
        # torch.autograd.Variable wrapper is needed.
        # detach().clone() keeps the leaf independent of the caller's data.
        z_leaf = (
            torch.as_tensor(z, dtype=torch.float32)
            .detach()
            .clone()
            .requires_grad_(True)
        )
        # Reduce to a scalar so backward() needs no explicit gradient argument.
        total = fn(z_leaf, **kwargs).sum()
        total.backward()
        if z_leaf.grad is None:
            # The output does not depend on the input, so the gradient is zero.
            return torch.zeros_like(z_leaf).numpy()
        return z_leaf.grad.numpy()

    return get_grad
| 30 |