github.com/ddbourgin/numpy-ml

Method extract_grads

numpy_ml/tests/nn_torch_models.py:1508–1526
(self, X)


```python
def extract_grads(self, X):
    # forward() (which ends with `return self.A`) is expected to populate
    # self.X (the input tensor), self.A (a list of per-time-step hidden
    # states) and self.layer1 (the recurrent cell).
    self.forward(X)

    # Scalar loss: the sum of every hidden state over all time steps
    self.loss = torch.stack(self.A).sum()
    self.loss.backward()

    # a.grad is only populated for the non-leaf tensors in self.A if forward()
    # called a.retain_grad() on them; self.X.grad likewise requires self.X to
    # be a leaf tensor with requires_grad=True.
    grads = {
        # inputs, parameters and outputs
        "X": self.X.detach().numpy(),
        "ba": self.layer1.bias_hh.detach().numpy(),
        "bx": self.layer1.bias_ih.detach().numpy(),
        "Wax": self.layer1.weight_ih.detach().numpy(),
        "Waa": self.layer1.weight_hh.detach().numpy(),
        "y": torch.stack(self.A).detach().numpy(),
        # gradients of the summed loss
        "dLdA": np.array([a.grad.numpy() for a in self.A]),
        "dLdWaa": self.layer1.weight_hh.grad.numpy(),
        "dLdWax": self.layer1.weight_ih.grad.numpy(),
        "dLdBa": self.layer1.bias_hh.grad.numpy(),
        "dLdBx": self.layer1.bias_ih.grad.numpy(),
        "dLdX": self.X.grad.numpy(),
    }
    return grads
```
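The forward pass is not shown on this page, but extract_grads only works if it leaves behind three things: self.layer1 (presumably a torch.nn.RNNCell, given the test name test_RNNCellFunction and the weight_ih/weight_hh/bias_ih/bias_hh attributes), self.X as a leaf tensor with requires_grad=True, and self.A as a list of per-step hidden states on which retain_grad() was called. The class below is a minimal, hypothetical stand-in under those assumptions, with X assumed to be laid out as (n_ex, n_in, n_t); TinyTorchRNN and its forward are illustrative and are not the model defined in nn_torch_models.py.

```python
import numpy as np
import torch
import torch.nn as nn


class TinyTorchRNN(nn.Module):
    """Hypothetical minimal model: a torch.nn.RNNCell unrolled over time."""

    def __init__(self, n_in, n_hid):
        super().__init__()
        # nn.RNNCell defaults to a tanh nonlinearity with bias_ih and bias_hh
        self.layer1 = nn.RNNCell(n_in, n_hid)

    def forward(self, X):
        # Assumed layout: X is (n_ex, n_in, n_t), time on the last axis.
        # Build a fresh leaf tensor so that self.X.grad gets populated.
        dtype = self.layer1.weight_ih.dtype
        self.X = torch.as_tensor(X, dtype=dtype).detach().clone().requires_grad_(True)
        a = torch.zeros(self.X.shape[0], self.layer1.hidden_size, dtype=dtype)
        self.A = []
        for t in range(self.X.shape[2]):
            a = self.layer1(self.X[:, :, t], a)
            a.retain_grad()  # keep .grad on this non-leaf hidden state
            self.A.append(a)
        return self.A

    def extract_grads(self, X):
        # Same body as the extract_grads method on this page
        self.forward(X)
        self.loss = torch.stack(self.A).sum()
        self.loss.backward()
        return {
            "X": self.X.detach().numpy(),
            "ba": self.layer1.bias_hh.detach().numpy(),
            "bx": self.layer1.bias_ih.detach().numpy(),
            "Wax": self.layer1.weight_ih.detach().numpy(),
            "Waa": self.layer1.weight_hh.detach().numpy(),
            "y": torch.stack(self.A).detach().numpy(),
            "dLdA": np.array([a.grad.numpy() for a in self.A]),
            "dLdWaa": self.layer1.weight_hh.grad.numpy(),
            "dLdWax": self.layer1.weight_ih.grad.numpy(),
            "dLdBa": self.layer1.bias_hh.grad.numpy(),
            "dLdBx": self.layer1.bias_ih.grad.numpy(),
            "dLdX": self.X.grad.numpy(),
        }


torch.manual_seed(0)
np.random.seed(0)
model = TinyTorchRNN(n_in=3, n_hid=4)
grads = model.extract_grads(np.random.randn(2, 3, 5))
print(grads["y"].shape, grads["dLdX"].shape)  # (5, 2, 4) (2, 3, 5)
print(np.allclose(grads["dLdA"][-1], 1.0))    # True: the last state feeds only the sum
```

Because parameter .grad tensors accumulate in PyTorch, calling extract_grads twice on the same model in this sketch adds the parameter gradients together unless they are cleared in between (for example with model.zero_grad()).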

Callers (1)
test_RNNCellFunction · 0.95

Calls (2)
forward (method) · 0.95
backward (method) · 0.45

Tested by (1)
test_RNNCellFunction · 0.76
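test_RNNCellFunction is listed as both the caller and the test, but its body is not shown on this page. The sketch below only illustrates the kind of check such a test could make: recompute the quantities in the grads dictionary with a hand-written NumPy backprop-through-time and compare them with PyTorch autograd. It assumes a tanh cell, a zero initial hidden state and inputs laid out as (n_ex, n_in, n_t); numpy_rnn_grads is a hypothetical helper, not numpy-ml's own RNNCell implementation.

```python
import numpy as np
import torch


def numpy_rnn_grads(X, Wax, Waa, bx, ba):
    """Hypothetical NumPy reference: tanh RNN forward pass plus BPTT.

    Loss is the sum of all hidden states; the initial state is zero.
    Shapes: X (n_ex, n_in, n_t), Wax (n_hid, n_in), Waa (n_hid, n_hid).
    Step: a_t = tanh(x_t @ Wax.T + bx + a_prev @ Waa.T + ba).
    """
    n_ex, _, n_t = X.shape
    n_hid = Waa.shape[0]
    A = [np.zeros((n_ex, n_hid))]  # A[0] is the initial state
    for t in range(n_t):
        A.append(np.tanh(X[:, :, t] @ Wax.T + bx + A[-1] @ Waa.T + ba))

    dWax, dWaa = np.zeros_like(Wax), np.zeros_like(Waa)
    dbx, dba, dX = np.zeros_like(bx), np.zeros_like(ba), np.zeros_like(X)
    dLdA = np.zeros((n_t, n_ex, n_hid))
    carry = np.zeros((n_ex, n_hid))  # gradient arriving from the next step
    for t in reversed(range(n_t)):
        dLdA[t] = 1.0 + carry                  # direct term from the sum + recurrent term
        dZ = dLdA[t] * (1.0 - A[t + 1] ** 2)   # back through tanh
        dWax += dZ.T @ X[:, :, t]
        dWaa += dZ.T @ A[t]                    # A[t] is the state fed into step t
        dbx += dZ.sum(axis=0)
        dba += dZ.sum(axis=0)
        dX[:, :, t] = dZ @ Wax
        carry = dZ @ Waa
    grads = {"dLdWax": dWax, "dLdWaa": dWaa, "dLdBx": dbx, "dLdBa": dba, "dLdX": dX}
    return np.stack(A[1:]), dLdA, grads


# Reference gradients from autograd on a float64 torch.nn.RNNCell
torch.manual_seed(0)
rng = np.random.default_rng(0)
n_ex, n_in, n_hid, n_t = 2, 3, 4, 5
cell = torch.nn.RNNCell(n_in, n_hid).double()
X = torch.tensor(rng.standard_normal((n_ex, n_in, n_t)), requires_grad=True)
a = torch.zeros(n_ex, n_hid, dtype=torch.float64)
A = []
for t in range(n_t):
    a = cell(X[:, :, t], a)
    a.retain_grad()
    A.append(a)
torch.stack(A).sum().backward()

y, dLdA, grads = numpy_rnn_grads(
    X.detach().numpy(),
    cell.weight_ih.detach().numpy(), cell.weight_hh.detach().numpy(),
    cell.bias_ih.detach().numpy(), cell.bias_hh.detach().numpy(),
)
expected = {
    "dLdWax": cell.weight_ih.grad, "dLdWaa": cell.weight_hh.grad,
    "dLdBx": cell.bias_ih.grad, "dLdBa": cell.bias_hh.grad, "dLdX": X.grad,
}
assert np.allclose(y, torch.stack(A).detach().numpy())
assert np.allclose(dLdA, np.array([h.grad.numpy() for h in A]))
for key, ref in expected.items():
    assert np.allclose(grads[key], ref.numpy()), key
print("NumPy BPTT matches autograd")
```

The same recurrence explains why dLdA is not all ones: each hidden state receives a gradient of 1 directly from the sum, plus whatever flows back through later time steps, so only the entry for the final step is exactly 1.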