Repository: github.com/ddbourgin/numpy-ml

Function get_grad

numpy_ml/tests/nn_torch_models.py, lines 22–27
Signature: get_grad(z)


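import torch


def torch_gradient_generator(fn, **kwargs):
    def get_grad(z):
        # Wrap the input as a float32 tensor that tracks gradients.
        # Variable is a legacy wrapper; since PyTorch 0.4 it returns a plain tensor.
        z1 = torch.autograd.Variable(torch.FloatTensor(z), requires_grad=True)
        # Reduce fn's output to a scalar so backward() needs no explicit gradient
        z2 = fn(z1, **kwargs).sum()
        z2.backward()
        # d(sum(fn(z)))/dz, converted back to a NumPy array
        grad = z1.grad.numpy()
        return grad

    return get_grad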
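A minimal sketch of how the returned closure might be exercised, assuming PyTorch and NumPy are installed. It re-states `torch_gradient_generator` using `torch.tensor(..., requires_grad=True)` in place of the legacy `Variable` wrapper (an equivalent spelling in current PyTorch, not the repository's code). It checks the autograd result for `torch.sigmoid` against the analytic derivative σ(z)(1 − σ(z)), then shows keyword arguments being bound when the generator is called, using `torch.nn.functional.leaky_relu` as an illustrative `fn`.

```python
import numpy as np
import torch
import torch.nn.functional as F


def torch_gradient_generator(fn, **kwargs):
    """Return a function mapping a NumPy array z to d(sum(fn(z)))/dz."""

    def get_grad(z):
        # float32 leaf tensor that records operations for autograd
        z1 = torch.tensor(z, dtype=torch.float32, requires_grad=True)
        # Summing yields a scalar, so backward() needs no gradient argument
        z2 = fn(z1, **kwargs).sum()
        z2.backward()
        return z1.grad.numpy()

    return get_grad


# Gradient of the elementwise sigmoid, computed by autograd
sigmoid_grad = torch_gradient_generator(torch.sigmoid)

z = np.array([-2.0, 0.0, 3.0])
s = 1.0 / (1.0 + np.exp(-z))
analytic = s * (1.0 - s)
print(np.allclose(sigmoid_grad(z), analytic, atol=1e-6))  # True

# Keyword arguments are bound once, when the generator is called
leaky_grad = torch_gradient_generator(F.leaky_relu, negative_slope=0.2)
print(leaky_grad(np.array([-1.5, 2.0])))  # approximately [0.2, 1.0]
```

Binding `negative_slope` up front means the returned function takes only the input array.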

Callers

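nothing calls this directly; get_grad is the closure returned by torch_gradient_generator, so it is reached only through that return value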

Calls 1

backward (method) · 0.45
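The single call, `backward`, is made on `fn(z1).sum()`. Backpropagating from a sum computes a vector-Jacobian product with a vector of ones, so the result equals the elementwise derivative only when `fn` acts elementwise. A small check, assuming PyTorch is installed: for softmax, whose outputs always sum to 1, the same sum-then-backward recipe gives a gradient that is numerically zero, even though the softmax Jacobian itself is not zero.

```python
import torch

z = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Same recipe as get_grad: reduce with sum(), then backpropagate
torch.softmax(z, dim=0).sum().backward()

# sum(softmax(z)) is constant (1), so its gradient is ~0 in every coordinate
print(z.grad)

# The full 3x3 Jacobian of softmax, by contrast, has nonzero entries
J = torch.autograd.functional.jacobian(lambda t: torch.softmax(t, dim=0), z.detach())
print(J.abs().max().item() > 0)  # True
```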

Tested by

no test coverage detected