(N=15)
| 1480 | |
| 1481 | |
| 1482 | def test_LSTMCell(N=15): |
| 1483 | from numpy_ml.neural_nets.layers import LSTMCell |
| 1484 | |
| 1485 | N = np.inf if N is None else N |
| 1486 | |
| 1487 | np.random.seed(12345) |
| 1488 | |
| 1489 | i = 1 |
| 1490 | while i < N + 1: |
| 1491 | n_ex = np.random.randint(1, 10) |
| 1492 | n_in = np.random.randint(1, 10) |
| 1493 | n_out = np.random.randint(1, 10) |
| 1494 | n_t = np.random.randint(1, 10) |
| 1495 | X = random_tensor((n_ex, n_in, n_t), standardize=True) |
| 1496 | |
| 1497 | # initialize LSTM layer |
| 1498 | L1 = LSTMCell(n_out=n_out) |
| 1499 | |
| 1500 | # forward prop |
| 1501 | Cs = [] |
| 1502 | y_preds = [] |
| 1503 | for t in range(n_t): |
| 1504 | y_pred, Ct = L1.forward(X[:, :, t]) |
| 1505 | y_preds.append(y_pred) |
| 1506 | Cs.append(Ct) |
| 1507 | |
| 1508 | # backprop |
| 1509 | dLdX = [] |
| 1510 | dLdAt = np.ones_like(y_preds[t]) |
| 1511 | for t in reversed(range(n_t)): |
| 1512 | dLdXt = L1.backward(dLdAt) |
| 1513 | dLdX.insert(0, dLdXt) |
| 1514 | dLdX = np.dstack(dLdX) |
| 1515 | y_preds = np.dstack(y_preds) |
| 1516 | Cs = np.array(Cs) |
| 1517 | |
| 1518 | # get gold standard gradients |
| 1519 | gold_mod = TorchLSTMCell(n_in, n_out, L1.parameters) |
| 1520 | golds = gold_mod.extract_grads(X) |
| 1521 | |
| 1522 | params = [ |
| 1523 | (X, "X"), |
| 1524 | (np.array(Cs), "C"), |
| 1525 | (y_preds, "y"), |
| 1526 | (L1.parameters["bo"].T, "bo"), |
| 1527 | (L1.parameters["bu"].T, "bu"), |
| 1528 | (L1.parameters["bf"].T, "bf"), |
| 1529 | (L1.parameters["bc"].T, "bc"), |
| 1530 | (L1.parameters["Wo"], "Wo"), |
| 1531 | (L1.parameters["Wu"], "Wu"), |
| 1532 | (L1.parameters["Wf"], "Wf"), |
| 1533 | (L1.parameters["Wc"], "Wc"), |
| 1534 | (L1.gradients["bo"].T, "dLdBo"), |
| 1535 | (L1.gradients["bu"].T, "dLdBu"), |
| 1536 | (L1.gradients["bf"].T, "dLdBf"), |
| 1537 | (L1.gradients["bc"].T, "dLdBc"), |
| 1538 | (L1.gradients["Wo"], "dLdWo"), |
| 1539 | (L1.gradients["Wu"], "dLdWu"), |
nothing calls this directly
no test coverage detected