Repository: github.com/ddbourgin/numpy-ml

Function test_LSTMCell(N=15)

numpy_ml/tests/test_nn.py, lines 1482–1556

```python
def test_LSTMCell(N=15):
    # np, random_tensor, TorchLSTMCell and err_fmt are module-level names in test_nn.py
    from numpy_ml.neural_nets.layers import LSTMCell

    # N=None means keep generating random test cases indefinitely
    N = np.inf if N is None else N

    np.random.seed(12345)

    i = 1
    while i < N + 1:
        # random batch size, input dim, output dim and sequence length
        n_ex = np.random.randint(1, 10)
        n_in = np.random.randint(1, 10)
        n_out = np.random.randint(1, 10)
        n_t = np.random.randint(1, 10)
        X = random_tensor((n_ex, n_in, n_t), standardize=True)

        # initialize LSTM layer
        L1 = LSTMCell(n_out=n_out)

        # forward prop: one cell step per timestep, keeping outputs and cell states
        Cs = []
        y_preds = []
        for t in range(n_t):
            y_pred, Ct = L1.forward(X[:, :, t])
            y_preds.append(y_pred)
            Cs.append(Ct)

        # backprop through time, last timestep first, with an upstream
        # gradient of ones; insert(0, ...) keeps dLdX in chronological order
        dLdX = []
        dLdAt = np.ones_like(y_preds[-1])
        for t in reversed(range(n_t)):
            dLdXt = L1.backward(dLdAt)
            dLdX.insert(0, dLdXt)
        dLdX = np.dstack(dLdX)
        y_preds = np.dstack(y_preds)
        Cs = np.array(Cs)

        # get gold standard gradients from the PyTorch reference model
        gold_mod = TorchLSTMCell(n_in, n_out, L1.parameters)
        golds = gold_mod.extract_grads(X)

        params = [
            (X, "X"),
            (Cs, "C"),
            (y_preds, "y"),
            (L1.parameters["bo"].T, "bo"),
            (L1.parameters["bu"].T, "bu"),
            (L1.parameters["bf"].T, "bf"),
            (L1.parameters["bc"].T, "bc"),
            (L1.parameters["Wo"], "Wo"),
            (L1.parameters["Wu"], "Wu"),
            (L1.parameters["Wf"], "Wf"),
            (L1.parameters["Wc"], "Wc"),
            (L1.gradients["bo"].T, "dLdBo"),
            (L1.gradients["bu"].T, "dLdBu"),
            (L1.gradients["bf"].T, "dLdBf"),
            (L1.gradients["bc"].T, "dLdBc"),
            (L1.gradients["Wo"], "dLdWo"),
            (L1.gradients["Wu"], "dLdWu"),
            # … (listing ends here; the full function spans lines 1482–1556)
```
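The test steps the cell one timestep at a time, backpropagates an upstream gradient of ones, and checks values and gradients against a PyTorch reference (TorchLSTMCell via extract_grads). The sketch below illustrates the same pattern without numpy-ml or PyTorch. It is a hypothetical illustration, not the library's code, and rests on stated assumptions: the gate equations (sigmoid forget, update and output gates, tanh candidate, all computed from the concatenation [a_prev, x]) are the standard LSTM formulation and are only assumed to match LSTMCell's defaults; the helper names (random_params, lstm_step, run_sequence, grad_x_first_step, numeric_grad_x) are invented for this example; and the reference is central finite differences, swapped in for the PyTorch comparison the real test uses. Gradients are checked only at the first timestep, where a_prev and c_prev are zero, to keep the analytic backward pass short.

```python
import math
import random

# Hypothetical pure-Python sketch. Gate layout is the standard LSTM on
# z = [a_prev, x]; it is assumed, not verified, to match numpy-ml's LSTMCell.


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def affine(z, W, b):
    # z: length n_rows, W: n_rows x n_out nested list, b: length n_out
    return [sum(z[k] * W[k][j] for k in range(len(z))) + b[j] for j in range(len(b))]


def random_params(n_in, n_out, rng):
    # one weight matrix (n_out + n_in rows) and one bias per gate: f, u, o, c
    P = {}
    for g in "fuoc":
        P["W" + g] = [[rng.gauss(0, 0.5) for _ in range(n_out)] for _ in range(n_out + n_in)]
        P["b" + g] = [rng.gauss(0, 0.5) for _ in range(n_out)]
    return P


def lstm_step(x, a_prev, c_prev, P):
    z = a_prev + x  # list concatenation: previous hidden state, then input
    f = [sigmoid(v) for v in affine(z, P["Wf"], P["bf"])]  # forget gate
    u = [sigmoid(v) for v in affine(z, P["Wu"], P["bu"])]  # update (input) gate
    o = [sigmoid(v) for v in affine(z, P["Wo"], P["bo"])]  # output gate
    cc = [math.tanh(v) for v in affine(z, P["Wc"], P["bc"])]  # candidate cell state
    c = [f[j] * c_prev[j] + u[j] * cc[j] for j in range(len(c_prev))]
    a = [o[j] * math.tanh(c[j]) for j in range(len(c))]
    return a, c, (u, o, cc)


def run_sequence(xs, P, n_out):
    # forward prop over timesteps, mirroring the test's loop over X[:, :, t]
    a, c = [0.0] * n_out, [0.0] * n_out
    As, Cs = [], []
    for x in xs:
        a, c, _ = lstm_step(x, a, c, P)
        As.append(a)
        Cs.append(c)
    return As, Cs


def grad_x_first_step(x, P, n_out):
    # analytic dL/dx for L = sum(a) at t=0, i.e. an upstream gradient of ones
    zeros = [0.0] * n_out
    _, c, (u, o, cc) = lstm_step(x, zeros, zeros, P)
    d_o, d_u, d_c = [], [], []
    for j in range(n_out):
        th = math.tanh(c[j])
        dc = o[j] * (1.0 - th * th)  # dL/dc_j
        d_o.append(th * o[j] * (1.0 - o[j]))  # via output-gate pre-activation
        d_u.append(dc * cc[j] * u[j] * (1.0 - u[j]))  # via update-gate pre-activation
        d_c.append(dc * u[j] * (1.0 - cc[j] ** 2))  # via candidate pre-activation
        # the forget gate contributes nothing because c_prev is zero
    n_rows = n_out + len(x)
    dz = [
        sum(d_o[j] * P["Wo"][k][j] + d_u[j] * P["Wu"][k][j] + d_c[j] * P["Wc"][k][j]
            for j in range(n_out))
        for k in range(n_rows)
    ]
    return dz[n_out:]  # rows after the a_prev block belong to x


def numeric_grad_x(x, P, n_out, eps=1e-6):
    # central finite differences on the same loss, as an independent reference
    zeros = [0.0] * n_out
    g = []
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += eps
        xm[i] -= eps
        lp = sum(lstm_step(xp, zeros, zeros, P)[0])
        lm = sum(lstm_step(xm, zeros, zeros, P)[0])
        g.append((lp - lm) / (2 * eps))
    return g


rng = random.Random(12345)
n_in, n_out, n_t = 3, 4, 5
P = random_params(n_in, n_out, rng)
xs = [[rng.gauss(0, 1) for _ in range(n_in)] for _ in range(n_t)]
As, Cs = run_sequence(xs, P, n_out)
analytic = grad_x_first_step(xs[0], P, n_out)
numeric = numeric_grad_x(xs[0], P, n_out)
max_err = max(abs(p - q) for p, q in zip(analytic, numeric))
print("max |analytic - numeric| =", max_err)
```

Finite differences are a cheap, dependency-free reference but only approximate; an autograd model such as the test's PyTorch reference gives exact gradients, and the test's params list also covers parameter gradients (dLdWo, dLdBo and so on), which this sketch leaves out.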

Callers

Nothing calls this function directly.

Calls (7)

forward (method) · 0.95
backward (method) · 0.95
extract_grads (method) · 0.95
random_tensor (function) · 0.90
LSTMCell (class) · 0.90
TorchLSTMCell (class) · 0.85
err_fmt (function) · 0.70

Tested by

No test coverage detected.