MCPcopy
hub / github.com/ddbourgin/numpy-ml / TorchLSTMCell

Class TorchLSTMCell

numpy_ml/tests/nn_torch_models.py:1341–1459  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

1339
1340
1341class TorchLSTMCell(nn.Module):
def __init__(self, n_in, n_out, params, **kwargs):
    """
    Wrap a torch ``nn.LSTMCell`` whose weights and biases are loaded from
    `params`.

    Parameters
    ----------
    n_in : int
        Forwarded to ``nn.LSTMCell`` as ``input_size``.
    n_out : int
        Forwarded to ``nn.LSTMCell`` as ``hidden_size``. Also the row index
        at which every gate matrix in `params` is split.
    params : dict
        Must contain the gate matrices "Wu", "Wf", "Wc", "Wo" and the gate
        biases "bu", "bf", "bc", "bo". The slicing and shape assertions
        below imply each matrix has ``n_out + n_in`` rows and ``n_out``
        columns, with the first ``n_out`` rows acting on the hidden state
        -- TODO confirm against the layer that produces `params`.
    **kwargs
        Accepted but not used.

    Notes
    -----
    NOTE(review): the identical bias vector is written to both ``bias_ih``
    and ``bias_hh``. ``nn.LSTMCell`` adds the two, so the effective gate
    bias is twice the value in `params` unless those biases are zero --
    confirm this is what the caller expects.
    """
    super(TorchLSTMCell, self).__init__()

    # Gate order (u, f, c, o) is kept consistent for weights and biases.
    gate_mats = [params["Wu"], params["Wf"], params["Wc"], params["Wo"]]

    # Rows from n_out onward form the input-to-hidden block, the first
    # n_out rows the hidden-to-hidden block; each is transposed and the
    # four gates are stacked along axis 0.
    W_ih = np.vstack([W[n_out:, :].T for W in gate_mats])
    W_hh = np.vstack([W[:n_out, :].T for W in gate_mats])

    cell = nn.LSTMCell(input_size=n_in, hidden_size=n_out, bias=True)
    self.layer1 = cell

    # Check both weight shapes before overwriting either parameter.
    weights = (("weight_ih", W_ih), ("weight_hh", W_hh))
    for attr, value in weights:
        assert getattr(cell, attr).shape == value.shape
    for attr, value in weights:
        setattr(cell, attr, nn.Parameter(torch.FloatTensor(value)))

    gate_biases = [params["bu"], params["bf"], params["bc"], params["bo"]]
    b = np.concatenate(gate_biases, axis=-1).flatten()

    # Same pattern for the biases: validate both, then replace both.
    bias_attrs = ("bias_ih", "bias_hh")
    for attr in bias_attrs:
        assert getattr(cell, attr).shape == b.shape
    for attr in bias_attrs:
        setattr(cell, attr, nn.Parameter(torch.FloatTensor(b)))
1370
1371 def forward(self, X):
1372 self.X = X
1373 if not isinstance(self.X, torch.Tensor):
1374 self.X = torchify(self.X)
1375
1376 self.X.retain_grad()
1377
1378 # initial hidden state is 0
1379 n_ex, n_in, n_timesteps = self.X.shape
1380 n_out, n_out = self.layer1.weight_hh.shape
1381
1382 # initialize hidden states
1383 a0 = torchify(np.zeros((n_ex, n_out)))
1384 c0 = torchify(np.zeros((n_ex, n_out)))
1385 a0.retain_grad()
1386 c0.retain_grad()
1387
1388 # forward pass
1389 A, C = [], []
1390 at = a0
1391 ct = c0
1392 for t in range(n_timesteps):
1393 A.append(at)
1394 C.append(ct)
1395 at1, ct1 = self.layer1(self.X[:, :, t], (at, ct))
1396 at.retain_grad()
1397 ct.retain_grad()
1398 at = at1

Callers 1

test_LSTMCellFunction · 0.85

Calls

no outgoing calls

Tested by 1

test_LSTMCellFunction · 0.68