Repository: github.com/ddbourgin/numpy-ml

Function get_grad

numpy_ml/tests/test_nn_activations.py:15–20

Signature: get_grad(z)


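```python
import torch


def torch_gradient_generator(fn, **kwargs):
    # Bind fn and its keyword arguments; return a closure that computes
    # the gradient of fn at a NumPy input using PyTorch autograd.
    def get_grad(z):
        # Wrap z (a floating-point NumPy array) as a leaf tensor that tracks
        # gradients. Variable is a legacy wrapper that now returns a plain tensor.
        z1 = torch.autograd.Variable(torch.from_numpy(z), requires_grad=True)
        # Reduce the output to a scalar so backward() needs no explicit
        # gradient argument; for an elementwise fn this gives fn'(z) per entry.
        z2 = fn(z1, **kwargs).sum()
        z2.backward()
        # z1.grad has the same shape as z; convert it back to NumPy.
        grad = z1.grad.numpy()
        return grad

    return get_grad
```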
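To illustrate the sum-then-differentiate idea without a PyTorch dependency, here is a minimal sketch that swaps autograd for central finite differences. The names `numeric_gradient_generator`, `sigmoid`, `leaky_relu`, and the step size `eps` are hypothetical and chosen for illustration; they are not part of numpy-ml. Because the activations are elementwise, the partial derivative of the sum over j of f(z_j) with respect to z_i is f'(z_i), so differencing the summed output one entry at a time recovers the same elementwise derivative that get_grad returns.

```python
import numpy as np


def numeric_gradient_generator(fn, eps=1e-6, **kwargs):
    # Hypothetical NumPy-only analogue of torch_gradient_generator:
    # central finite differences stand in for autograd.
    def get_grad(z):
        z = np.asarray(z, dtype=np.float64)
        grad = np.zeros_like(z)
        # Perturb one entry at a time and difference the *summed* output,
        # mirroring fn(z1, **kwargs).sum() in the autograd version.
        for idx in np.ndindex(z.shape):
            step = np.zeros_like(z)
            step[idx] = eps
            f_plus = fn(z + step, **kwargs).sum()
            f_minus = fn(z - step, **kwargs).sum()
            grad[idx] = (f_plus - f_minus) / (2 * eps)
        return grad

    return get_grad


def sigmoid(z):
    # Hypothetical elementwise activation for the demo.
    return 1.0 / (1.0 + np.exp(-z))


def leaky_relu(z, alpha=0.3):
    # Hypothetical activation with a keyword parameter, bound via **kwargs.
    return np.where(z > 0, z, alpha * z)


if __name__ == "__main__":
    z = np.array([[-2.0, -0.5], [0.5, 2.0]])
    s = sigmoid(z)
    # Analytic derivative of the sigmoid is s * (1 - s).
    print(np.allclose(numeric_gradient_generator(sigmoid)(z), s * (1 - s), atol=1e-6))  # True
    # alpha is forwarded to leaky_relu through **kwargs, as in the original.
    print(numeric_gradient_generator(leaky_relu, alpha=0.1)(z))
```

Finite differences cost two calls to fn per input entry and become inaccurate near kinks such as z = 0 for leaky_relu, whereas autograd returns the full gradient from a single backward pass.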

Callers

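Nothing calls this directly; get_grad is returned by torch_gradient_generator, and callers invoke it through that return value.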

Calls 1

backward (method, called as z2.backward()) · 0.45

Tested by

no test coverage detected