hub / github.com/ddbourgin/numpy-ml / extract_grads

Method extract_grads

numpy_ml/tests/nn_torch_models.py:376–390
(self, X)

Source from the content-addressed store, hash-verified

        return self.Y

    def extract_grads(self, X):
        self.forward(X)
        self.loss = self.Y.sum()
        self.loss.backward()
        grads = {
            "Xs": X,
            "Prod": self.prod.detach().numpy(),
            "Y": self.Y.detach().numpy(),
            "dLdY": self.Y.grad.numpy(),
            "dLdProd": self.prod.grad.numpy(),
        }
        grads.update(
            {"dLdX{}".format(i + 1): xi.grad.numpy() for i, xi in enumerate(self.Xs)}
        )
        return grads


class TorchSkipConnectionIdentity(nn.Module):
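
The method reads `.grad` from `self.Y` and `self.prod`, which are intermediate (non-leaf) tensors. PyTorch only populates `.grad` on such tensors if `retain_grad()` was called on them before `backward()`, so the `forward` method (not shown in this excerpt) presumably does that. The following is a minimal standalone sketch of the same pattern, not the repository's code: the two input arrays, the use of `tanh` as the activation, and the variable names are assumptions chosen for illustration.

```python
import numpy as np
import torch

# Hypothetical inputs standing in for the list passed as X.
X = [
    np.array([[1.0, 2.0], [3.0, 4.0]]),
    np.array([[0.5, -1.0], [2.0, 0.25]]),
]
# Leaf tensors that track gradients (assumed to mirror self.Xs).
Xs = [torch.tensor(x, requires_grad=True) for x in X]

# Forward pass: elementwise product, then an activation (tanh is an arbitrary choice).
prod = Xs[0] * Xs[1]
prod.retain_grad()  # non-leaf tensor: .grad stays None unless retained
Y = torch.tanh(prod)
Y.retain_grad()

# Summing the output makes dL/dY a tensor of ones.
loss = Y.sum()
loss.backward()

# Same dictionary layout as extract_grads returns.
grads = {
    "Xs": X,
    "Prod": prod.detach().numpy(),
    "Y": Y.detach().numpy(),
    "dLdY": Y.grad.numpy(),
    "dLdProd": prod.grad.numpy(),
}
grads.update({"dLdX{}".format(i + 1): xi.grad.numpy() for i, xi in enumerate(Xs)})
```

Because the loss is a plain sum, `dLdY` is all ones and every other entry follows from the chain rule, which makes the dictionary easy to compare against a hand-written backward pass.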

Callers 1

test_MultiplyLayer · Function · 0.95

Calls 3

forward · Method · 0.95
backward · Method · 0.45
update · Method · 0.45

Tested by 1

test_MultiplyLayer · Function · 0.76
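
A test that consumes this dictionary would plausibly compare each entry with gradients computed independently. As a rough sketch of what such a reference could look like, the hypothetical helper below (`multiply_grads_reference` is not part of numpy-ml) derives the same keys by hand in NumPy, under the simplifying assumption that the activation is the identity, so that Y equals the elementwise product of the inputs.

```python
import numpy as np


def multiply_grads_reference(Xs):
    """Hand-derived gradients of L = sum(X_1 * X_2 * ... * X_n).

    Hypothetical helper for illustration; assumes an identity activation.
    """
    prod = np.ones_like(Xs[0], dtype=float)
    for x in Xs:
        prod = prod * x

    dLdY = np.ones_like(prod)  # derivative of a sum w.r.t. each element is 1
    grads = {"Prod": prod, "Y": prod.copy(), "dLdY": dLdY, "dLdProd": dLdY.copy()}

    for i in range(len(Xs)):
        # dL/dX_i is the upstream gradient times the product of all other inputs.
        d = dLdY.copy()
        for j, xj in enumerate(Xs):
            if j != i:
                d = d * xj
        grads["dLdX{}".format(i + 1)] = d
    return grads


a = np.array([1.0, 2.0, 3.0])
b = np.array([4.0, 5.0, 6.0])
c = np.array([-1.0, 0.5, 2.0])
ref = multiply_grads_reference([a, b, c])
```

Each `ref["dLdX…"]` entry could then be checked against the matching entry from `extract_grads` with `np.testing.assert_allclose`.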