| 1506 | return self.A |
| 1507 | |
def extract_grads(self, X):
    """Run one forward/backward pass on `X` and collect values and gradients.

    Calls `self.forward(X)`, uses the sum of all entries of the stacked
    `self.A` list as a scalar loss, backpropagates it, and returns the
    inputs, parameters, outputs and their gradients as NumPy arrays.

    Gradients left over from earlier backward passes are cleared first, so
    calling this method repeatedly on the same object yields the gradients
    of the current pass only instead of an accumulated sum.

    NOTE(review): reading `a.grad` for the entries of `self.A` and
    `self.X.grad` presumes `forward` makes those tensors keep their
    gradients (e.g. via `retain_grad()`); `forward` is not visible here --
    confirm.

    Parameters
    ----------
    X
        Input handed unchanged to `self.forward`.

    Returns
    -------
    grads : dict
        Keys: "X" (`self.X`), "ba" / "bx" (`layer1.bias_hh` /
        `layer1.bias_ih`), "Wax" / "Waa" (`layer1.weight_ih` /
        `layer1.weight_hh`), "y" (stacked `self.A`), and the matching loss
        gradients "dLdA", "dLdWaa", "dLdWax", "dLdBa", "dLdBx", "dLdX".
    """
    self.forward(X)

    cell = self.layer1
    # Autograd *adds* into an existing `.grad`; drop stale values so the
    # result does not depend on how many times this method has been called.
    for tensor in (
        self.X,
        cell.weight_ih,
        cell.weight_hh,
        cell.bias_ih,
        cell.bias_hh,
    ):
        tensor.grad = None

    # Stack once and reuse the result for both the loss and the "y" entry.
    stacked = torch.stack(self.A)
    self.loss = stacked.sum()
    self.loss.backward()

    grads = {
        "X": self.X.detach().numpy(),
        "ba": cell.bias_hh.detach().numpy(),
        "bx": cell.bias_ih.detach().numpy(),
        "Wax": cell.weight_ih.detach().numpy(),
        "Waa": cell.weight_hh.detach().numpy(),
        "y": stacked.detach().numpy(),
        "dLdA": np.array([a.grad.numpy() for a in self.A]),
        "dLdWaa": cell.weight_hh.grad.numpy(),
        "dLdWax": cell.weight_ih.grad.numpy(),
        "dLdBa": cell.bias_hh.grad.numpy(),
        "dLdBx": cell.bias_ih.grad.numpy(),
        "dLdX": self.X.grad.numpy(),
    }
    return grads
| 1527 | |
| 1528 | |
| 1529 | class TorchFCLayer(nn.Module): |