(N=15)
| 1988 | |
| 1989 | |
| 1990 | def test_BidirectionalLSTM(N=15): |
| 1991 | from numpy_ml.neural_nets.modules import BidirectionalLSTM |
| 1992 | |
| 1993 | N = np.inf if N is None else N |
| 1994 | |
| 1995 | np.random.seed(12345) |
| 1996 | |
| 1997 | i = 1 |
| 1998 | while i < N + 1: |
| 1999 | n_ex = np.random.randint(1, 10) |
| 2000 | n_in = np.random.randint(1, 10) |
| 2001 | n_out = np.random.randint(1, 10) |
| 2002 | n_t = np.random.randint(1, 10) |
| 2003 | X = random_tensor((n_ex, n_in, n_t), standardize=True) |
| 2004 | |
| 2005 | # initialize LSTM layer |
| 2006 | L1 = BidirectionalLSTM(n_out=n_out) |
| 2007 | |
| 2008 | # forward prop |
| 2009 | y_pred = L1.forward(X) |
| 2010 | |
| 2011 | # backprop |
| 2012 | dLdA = np.ones_like(y_pred) |
| 2013 | dLdX = L1.backward(dLdA) |
| 2014 | |
| 2015 | # get gold standard gradients |
| 2016 | gold_mod = TorchBidirectionalLSTM(n_in, n_out, L1.parameters) |
| 2017 | golds = gold_mod.extract_grads(X) |
| 2018 | |
| 2019 | pms, grads = L1.parameters["components"], L1.gradients["components"] |
| 2020 | params = [ |
| 2021 | (X, "X"), |
| 2022 | (y_pred, "y"), |
| 2023 | (pms["cell_fwd"]["bo"].T, "bo_f"), |
| 2024 | (pms["cell_fwd"]["bu"].T, "bu_f"), |
| 2025 | (pms["cell_fwd"]["bf"].T, "bf_f"), |
| 2026 | (pms["cell_fwd"]["bc"].T, "bc_f"), |
| 2027 | (pms["cell_fwd"]["Wo"], "Wo_f"), |
| 2028 | (pms["cell_fwd"]["Wu"], "Wu_f"), |
| 2029 | (pms["cell_fwd"]["Wf"], "Wf_f"), |
| 2030 | (pms["cell_fwd"]["Wc"], "Wc_f"), |
| 2031 | (pms["cell_bwd"]["bo"].T, "bo_b"), |
| 2032 | (pms["cell_bwd"]["bu"].T, "bu_b"), |
| 2033 | (pms["cell_bwd"]["bf"].T, "bf_b"), |
| 2034 | (pms["cell_bwd"]["bc"].T, "bc_b"), |
| 2035 | (pms["cell_bwd"]["Wo"], "Wo_b"), |
| 2036 | (pms["cell_bwd"]["Wu"], "Wu_b"), |
| 2037 | (pms["cell_bwd"]["Wf"], "Wf_b"), |
| 2038 | (pms["cell_bwd"]["Wc"], "Wc_b"), |
| 2039 | (grads["cell_fwd"]["bo"].T, "dLdBo_f"), |
| 2040 | (grads["cell_fwd"]["bu"].T, "dLdBu_f"), |
| 2041 | (grads["cell_fwd"]["bf"].T, "dLdBf_f"), |
| 2042 | (grads["cell_fwd"]["bc"].T, "dLdBc_f"), |
| 2043 | (grads["cell_fwd"]["Wo"], "dLdWo_f"), |
| 2044 | (grads["cell_fwd"]["Wu"], "dLdWu_f"), |
| 2045 | (grads["cell_fwd"]["Wf"], "dLdWf_f"), |
| 2046 | (grads["cell_fwd"]["Wc"], "dLdWc_f"), |
| 2047 | (grads["cell_bwd"]["bo"].T, "dLdBo_b"), |
nothing calls this directly
no test coverage detected