| 1339 | |
| 1340 | |
| 1341 | class TorchLSTMCell(nn.Module): |
| 1342 | def __init__(self, n_in, n_out, params, **kwargs): |
| 1343 | super(TorchLSTMCell, self).__init__() |
| 1344 | |
| 1345 | Wiu = params["Wu"][n_out:, :].T |
| 1346 | Wif = params["Wf"][n_out:, :].T |
| 1347 | Wic = params["Wc"][n_out:, :].T |
| 1348 | Wio = params["Wo"][n_out:, :].T |
| 1349 | W_ih = np.vstack([Wiu, Wif, Wic, Wio]) |
| 1350 | |
| 1351 | Whu = params["Wu"][:n_out, :].T |
| 1352 | Whf = params["Wf"][:n_out, :].T |
| 1353 | Whc = params["Wc"][:n_out, :].T |
| 1354 | Who = params["Wo"][:n_out, :].T |
| 1355 | W_hh = np.vstack([Whu, Whf, Whc, Who]) |
| 1356 | |
| 1357 | self.layer1 = nn.LSTMCell(input_size=n_in, hidden_size=n_out, bias=True) |
| 1358 | assert self.layer1.weight_ih.shape == W_ih.shape |
| 1359 | assert self.layer1.weight_hh.shape == W_hh.shape |
| 1360 | self.layer1.weight_ih = nn.Parameter(torch.FloatTensor(W_ih)) |
| 1361 | self.layer1.weight_hh = nn.Parameter(torch.FloatTensor(W_hh)) |
| 1362 | |
| 1363 | b = np.concatenate( |
| 1364 | [params["bu"], params["bf"], params["bc"], params["bo"]], axis=-1 |
| 1365 | ).flatten() |
| 1366 | assert self.layer1.bias_ih.shape == b.shape |
| 1367 | assert self.layer1.bias_hh.shape == b.shape |
| 1368 | self.layer1.bias_ih = nn.Parameter(torch.FloatTensor(b)) |
| 1369 | self.layer1.bias_hh = nn.Parameter(torch.FloatTensor(b)) |
| 1370 | |
| 1371 | def forward(self, X): |
| 1372 | self.X = X |
| 1373 | if not isinstance(self.X, torch.Tensor): |
| 1374 | self.X = torchify(self.X) |
| 1375 | |
| 1376 | self.X.retain_grad() |
| 1377 | |
| 1378 | # initial hidden state is 0 |
| 1379 | n_ex, n_in, n_timesteps = self.X.shape |
| 1380 | n_out, n_out = self.layer1.weight_hh.shape |
| 1381 | |
| 1382 | # initialize hidden states |
| 1383 | a0 = torchify(np.zeros((n_ex, n_out))) |
| 1384 | c0 = torchify(np.zeros((n_ex, n_out))) |
| 1385 | a0.retain_grad() |
| 1386 | c0.retain_grad() |
| 1387 | |
| 1388 | # forward pass |
| 1389 | A, C = [], [] |
| 1390 | at = a0 |
| 1391 | ct = c0 |
| 1392 | for t in range(n_timesteps): |
| 1393 | A.append(at) |
| 1394 | C.append(ct) |
| 1395 | at1, ct1 = self.layer1(self.X[:, :, t], (at, ct)) |
| 1396 | at.retain_grad() |
| 1397 | ct.retain_grad() |
| 1398 | at = at1 |
no outgoing calls