| 1475 | self.layer1.bias_hh = nn.Parameter(torch.FloatTensor(params["ba"].T)) |
| 1476 | |
def forward(self, X):
    """
    Unroll ``self.layer1`` over every timestep of ``X``, starting from a
    zero hidden state.

    Parameters
    ----------
    X : torch.Tensor or array-like
        Input of shape ``(n_ex, n_in, n_timesteps)``. It is sliced along the
        last axis, one timestep at a time. Non-tensor input is converted with
        ``torchify``. Tensor input must already require grad, otherwise
        ``retain_grad`` raises.

    Returns
    -------
    list of torch.Tensor
        The ``n_timesteps`` hidden states produced by ``self.layer1``, in time
        order. The zero initial state is excluded. The same list is stored on
        ``self.A``.
    """
    # Store the input on the instance, converting it to a tensor if needed.
    self.X = X if isinstance(X, torch.Tensor) else torchify(X)

    # Keep .grad on X after backward so it can be inspected later.
    self.X.retain_grad()

    n_ex, _, n_timesteps = self.X.shape
    _, n_out = self.layer1.weight_hh.shape

    # The recurrence starts from an all-zero hidden state of shape (n_ex, n_out).
    hidden = torchify(np.zeros((n_ex, n_out)))
    hidden.retain_grad()

    states = [hidden]
    for step in range(n_timesteps):
        hidden_next = self.layer1(self.X[:, :, step], hidden)
        # Retain grads on every intermediate hidden state, not just the leaves.
        hidden.retain_grad()
        states.append(hidden_next)
        hidden = hidden_next
    hidden.retain_grad()

    # Drop the zero initial state: one output per timestep.
    self.A = states[1:]
    return self.A
| 1507 | |
| 1508 | def extract_grads(self, X): |
| 1509 | self.forward(X) |