| 1460 | |
| 1461 | |
| 1462 | class TorchRNNCell(nn.Module): |
| 1463 | def __init__(self, n_in, n_hid, params, **kwargs): |
| 1464 | super(TorchRNNCell, self).__init__() |
| 1465 | |
| 1466 | self.layer1 = nn.RNNCell(n_in, n_hid, bias=True, nonlinearity="tanh") |
| 1467 | |
| 1468 | # set weights and bias to match those of RNNCell |
| 1469 | # NB: we pass the *transpose* of the RNNCell weights and biases to |
| 1470 | # pytorch, meaning we need to check against the *transpose* of our |
| 1471 | # outputs for any function of the weights |
| 1472 | self.layer1.weight_ih = nn.Parameter(torch.FloatTensor(params["Wax"].T)) |
| 1473 | self.layer1.weight_hh = nn.Parameter(torch.FloatTensor(params["Waa"].T)) |
| 1474 | self.layer1.bias_ih = nn.Parameter(torch.FloatTensor(params["bx"].T)) |
| 1475 | self.layer1.bias_hh = nn.Parameter(torch.FloatTensor(params["ba"].T)) |
| 1476 | |
def forward(self, X):
    """
    Unroll the RNN cell over every timestep of `X`.

    Parameters
    ----------
    X : torch.Tensor or other
        Input unpacked as (n_ex, n_in, n_timesteps); timestep ``t`` is read
        as ``X[:, :, t]``. Anything that is not already a ``torch.Tensor``
        is first passed through ``torchify``.

    Returns
    -------
    list of torch.Tensor
        One hidden state per timestep, in order. The all-zero initial
        state is not included. The same list is stored on ``self.A``.
    """
    self.X = X
    if not isinstance(self.X, torch.Tensor):
        self.X = torchify(self.X)

    # keep the gradient w.r.t. the input available after backward()
    self.X.retain_grad()

    n_ex, _, n_timesteps = self.X.shape
    _, n_out = self.layer1.weight_hh.shape

    # the initial hidden state is all zeros
    a0 = torchify(np.zeros((n_ex, n_out)))
    a0.retain_grad()

    # forward pass: feed each hidden state back in with the next input slice
    hidden_states = []
    prev = a0
    for t in range(n_timesteps):
        nxt = self.layer1(self.X[:, :, t], prev)
        # retain the gradient of the state we just consumed
        prev.retain_grad()
        hidden_states.append(nxt)
        prev = nxt

    # the final state is never consumed inside the loop, so mark it here
    prev.retain_grad()

    # a0 was never appended, so the outputs start at the first real state
    self.A = hidden_states
    return self.A
| 1507 | |
| 1508 | def extract_grads(self, X): |
| 1509 | self.forward(X) |
| 1510 | self.loss = torch.stack(self.A).sum() |
| 1511 | self.loss.backward() |
| 1512 | grads = { |
| 1513 | "X": self.X.detach().numpy(), |
| 1514 | "ba": self.layer1.bias_hh.detach().numpy(), |
| 1515 | "bx": self.layer1.bias_ih.detach().numpy(), |
| 1516 | "Wax": self.layer1.weight_ih.detach().numpy(), |
| 1517 | "Waa": self.layer1.weight_hh.detach().numpy(), |
| 1518 | "y": torch.stack(self.A).detach().numpy(), |
| 1519 | "dLdA": np.array([a.grad.numpy() for a in self.A]), |