github.com/ddbourgin/numpy-ml

Function torch_gradient_generator

numpy_ml/tests/nn_torch_models.py:21–29
(fn, **kwargs)

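import torch


def torch_gradient_generator(fn, **kwargs):
    # Returns a closure that computes the gradient of `fn` with PyTorch autograd.
    def get_grad(z):
        # Wrap the NumPy input as a float32 tensor that tracks gradients.
        # (Variable is a legacy wrapper; since PyTorch 0.4 it returns a plain Tensor.)
        z1 = torch.autograd.Variable(torch.FloatTensor(z), requires_grad=True)
        # Sum to a scalar so backward() can run without an explicit gradient argument.
        z2 = fn(z1, **kwargs).sum()
        z2.backward()
        # d(sum)/dz1 is stored in z1.grad; convert it back to a NumPy array.
        grad = z1.grad.numpy()
        return grad

    return get_grad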
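Because `get_grad` backpropagates from the sum of the outputs, for an elementwise `fn` the result at each position is simply the derivative of `fn` at that entry. A minimal, self-contained sketch of this, checking the sigmoid case against its closed-form derivative s(1 - s); the function is restated here with `torch.tensor` in place of the legacy `Variable` wrapper, which is an illustrative substitution rather than the repository's code.

```python
import numpy as np
import torch


def torch_gradient_generator(fn, **kwargs):
    # Restated from the listing above, using torch.tensor instead of Variable.
    def get_grad(z):
        z1 = torch.tensor(z, dtype=torch.float32, requires_grad=True)
        # Reduce to a scalar so backward() needs no explicit gradient argument.
        z2 = fn(z1, **kwargs).sum()
        z2.backward()
        return z1.grad.numpy()

    return get_grad


# For elementwise fn, d(sum_i fn(z_i)) / dz_j = fn'(z_j), so get_grad
# returns the elementwise derivative evaluated at every entry of z.
z = np.array([[-2.0, 0.0, 3.0]])
sigmoid_grad = torch_gradient_generator(torch.sigmoid)
s = 1.0 / (1.0 + np.exp(-z))
np.testing.assert_allclose(sigmoid_grad(z), s * (1.0 - s), rtol=1e-5)
```

The sum trick only yields a per-element derivative when each output depends on a single input; for a coupled function such as softmax, it returns the gradient of the summed output instead.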

Callers 6

test_sigmoid_grad · Function · 0.70
test_elu_grad · Function · 0.70
test_tanh_grad · Function · 0.70
test_relu_grad · Function · 0.70
test_softmax_grad · Function · 0.70
test_softplus_grad · Function · 0.70
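The callers are all activation-gradient tests. A sketch of the pattern such a test might follow, assuming it compares a hand-written NumPy derivative with the autograd result; `numpy_elu_grad` is a hypothetical helper (not numpy-ml's API), and the real tests may be structured differently. It also shows how `**kwargs` (here ELU's `alpha`) is forwarded to `fn` on every call.

```python
import numpy as np
import torch
import torch.nn.functional as F


def torch_gradient_generator(fn, **kwargs):
    # Restated from the listing above, using torch.tensor instead of Variable.
    def get_grad(z):
        z1 = torch.tensor(z, dtype=torch.float32, requires_grad=True)
        fn(z1, **kwargs).sum().backward()
        return z1.grad.numpy()

    return get_grad


def numpy_elu_grad(z, alpha):
    # Hypothetical reference derivative: 1 for z > 0, alpha * exp(z) otherwise.
    return np.where(z > 0, 1.0, alpha * np.exp(z))


alpha = 0.7
# Inputs avoid z == 0, where the ELU derivative is discontinuous for alpha != 1.
z = np.array([[-1.5, -0.2, 0.4, 2.0]])
# alpha is bound once here and passed to F.elu through **kwargs.
torch_elu_grad = torch_gradient_generator(F.elu, alpha=alpha)
np.testing.assert_allclose(numpy_elu_grad(z, alpha), torch_elu_grad(z), rtol=1e-5)
```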

Calls

no outgoing calls

Tested by 6

test_sigmoid_grad · Function · 0.56
test_elu_grad · Function · 0.56
test_tanh_grad · Function · 0.56
test_relu_grad · Function · 0.56
test_softmax_grad · Function · 0.56
test_softplus_grad · Function · 0.56