Function test_BidirectionalLSTM

Repository: github.com/ddbourgin/numpy-ml

File: numpy_ml/tests/test_nn.py, lines 1990–2069

def test_BidirectionalLSTM(N=15):
    """Compare BidirectionalLSTM outputs, parameters, and gradients against a
    PyTorch-based reference (TorchBidirectionalLSTM) over N random trials."""
    from numpy_ml.neural_nets.modules import BidirectionalLSTM

    # N=None runs trials indefinitely
    N = np.inf if N is None else N

    np.random.seed(12345)

    i = 1
    while i < N + 1:
        # draw a random batch size, input/output widths, and sequence length
        n_ex = np.random.randint(1, 10)
        n_in = np.random.randint(1, 10)
        n_out = np.random.randint(1, 10)
        n_t = np.random.randint(1, 10)
        X = random_tensor((n_ex, n_in, n_t), standardize=True)

        # initialize bidirectional LSTM layer
        L1 = BidirectionalLSTM(n_out=n_out)

        # forward prop
        y_pred = L1.forward(X)

        # backprop an upstream gradient of ones, i.e. dL/dy for L = sum(y_pred)
        dLdA = np.ones_like(y_pred)
        dLdX = L1.backward(dLdA)

        # get gold standard gradients
        gold_mod = TorchBidirectionalLSTM(n_in, n_out, L1.parameters)
        golds = gold_mod.extract_grads(X)

        pms, grads = L1.parameters["components"], L1.gradients["components"]
        params = [
            (X, "X"),
            (y_pred, "y"),
            (pms["cell_fwd"]["bo"].T, "bo_f"),
            (pms["cell_fwd"]["bu"].T, "bu_f"),
            (pms["cell_fwd"]["bf"].T, "bf_f"),
            (pms["cell_fwd"]["bc"].T, "bc_f"),
            (pms["cell_fwd"]["Wo"], "Wo_f"),
            (pms["cell_fwd"]["Wu"], "Wu_f"),
            (pms["cell_fwd"]["Wf"], "Wf_f"),
            (pms["cell_fwd"]["Wc"], "Wc_f"),
            (pms["cell_bwd"]["bo"].T, "bo_b"),
            (pms["cell_bwd"]["bu"].T, "bu_b"),
            (pms["cell_bwd"]["bf"].T, "bf_b"),
            (pms["cell_bwd"]["bc"].T, "bc_b"),
            (pms["cell_bwd"]["Wo"], "Wo_b"),
            (pms["cell_bwd"]["Wu"], "Wu_b"),
            (pms["cell_bwd"]["Wf"], "Wf_b"),
            (pms["cell_bwd"]["Wc"], "Wc_b"),
            (grads["cell_fwd"]["bo"].T, "dLdBo_f"),
            (grads["cell_fwd"]["bu"].T, "dLdBu_f"),
            (grads["cell_fwd"]["bf"].T, "dLdBf_f"),
            (grads["cell_fwd"]["bc"].T, "dLdBc_f"),
            (grads["cell_fwd"]["Wo"], "dLdWo_f"),
            (grads["cell_fwd"]["Wu"], "dLdWu_f"),
            (grads["cell_fwd"]["Wf"], "dLdWf_f"),
            (grads["cell_fwd"]["Wc"], "dLdWc_f"),
            (grads["cell_bwd"]["bo"].T, "dLdBo_b"),
            # ... (excerpt ends here; the rest of the function is not shown)
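The test follows a common pattern: draw random shapes, run `forward`, push an upstream gradient of ones through `backward` (which makes the implied loss the sum of the outputs), and, judging by the `params` list of values paired with gold-standard labels, compare each entry against a reference. The sketch below reproduces that pattern on a toy layer so it runs with NumPy alone. It is not numpy-ml code: `ToyTanhLayer`, `numerical_grad`, and `run_trials` are hypothetical names, and central finite differences are swapped in for the PyTorch-based reference (`TorchBidirectionalLSTM`) that the real test uses.

```python
import numpy as np


class ToyTanhLayer:
    """Hypothetical stand-in layer computing Y = tanh(X @ W)."""

    def __init__(self, n_in, n_out, rng):
        self.W = rng.standard_normal((n_in, n_out))

    def forward(self, X):
        self.X = X
        self.Y = np.tanh(X @ self.W)
        return self.Y

    def backward(self, dLdY):
        # Chain rule through tanh: d tanh(z)/dz = 1 - tanh(z)**2
        dZ = dLdY * (1.0 - self.Y ** 2)
        self.dW = self.X.T @ dZ  # dL/dW, cached like a layer's parameter gradient
        return dZ @ self.W.T     # dL/dX, returned like L1.backward's result


def numerical_grad(f, A, eps=1e-6):
    """Central-difference gradient of the scalar function f() w.r.t. array A.

    A is perturbed in place, one entry at a time, and restored afterwards.
    """
    grad = np.zeros_like(A)
    for idx in np.ndindex(A.shape):
        orig = A[idx]
        A[idx] = orig + eps
        f_plus = f()
        A[idx] = orig - eps
        f_minus = f()
        A[idx] = orig  # restore the original entry
        grad[idx] = (f_plus - f_minus) / (2 * eps)
    return grad


def run_trials(n_trials=5, seed=12345, atol=1e-5):
    """Randomized forward/backward check; raises AssertionError on a mismatch."""
    rng = np.random.default_rng(seed)
    for trial in range(1, n_trials + 1):
        # Fresh random shapes every trial, mirroring the test above
        n_ex, n_in, n_out = rng.integers(1, 10, size=3)
        X = rng.standard_normal((n_ex, n_in))
        layer = ToyTanhLayer(n_in, n_out, rng)

        Y = layer.forward(X)
        # An upstream gradient of ones is the gradient of the scalar loss L = sum(Y)
        dLdX = layer.backward(np.ones_like(Y))

        def loss():
            return np.tanh(X @ layer.W).sum()

        checks = [
            (layer.dW, numerical_grad(loss, layer.W), "dLdW"),
            (dLdX, numerical_grad(loss, X), "dLdX"),
        ]
        for mine, gold, label in checks:
            np.testing.assert_allclose(
                mine, gold, atol=atol, err_msg=f"trial {trial}: {label}"
            )
    return n_trials
```

Calling `run_trials()` returns the number of trials once every check passes. Central differences need two loss evaluations per entry, which is cheap for these tiny shapes but is one reason an autograd framework is the more practical reference for a full LSTM.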

Callers

Nothing calls this function directly.

Calls

forward (method)
backward (method)
extract_grads (method)
random_tensor (function)
BidirectionalLSTM (class)
err_fmt (function)
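In the test, the layer's parameters are split into `cell_fwd` and `cell_bwd` components, one cell per direction. A minimal sketch of how a bidirectional layer can combine two such cells: one scans the time axis left to right, the other right to left, and their hidden states are concatenated along the feature axis at each step. The input layout `(n_ex, n_in, n_t)` matches `X` in the test; everything else is an assumption. `ToyRNNCell` and `bidirectional_forward` are hypothetical, a plain tanh recurrent cell is swapped in for an LSTM cell to keep it short, and concatenating (rather than, say, summing) the two directions is assumed, not taken from numpy-ml.

```python
import numpy as np


class ToyRNNCell:
    """Hypothetical tanh recurrent cell standing in for an LSTM cell:
    h_t = tanh(x_t @ W + h_{t-1} @ U + b)."""

    def __init__(self, n_in, n_out, rng):
        self.W = 0.1 * rng.standard_normal((n_in, n_out))
        self.U = 0.1 * rng.standard_normal((n_out, n_out))
        self.b = np.zeros(n_out)

    def run(self, X, reverse=False):
        """Scan X of shape (n_ex, n_in, n_t); return hidden states (n_ex, n_out, n_t)."""
        n_ex, _, n_t = X.shape
        n_out = self.U.shape[0]
        H = np.zeros((n_ex, n_out, n_t))
        h = np.zeros((n_ex, n_out))  # zero initial hidden state
        steps = range(n_t - 1, -1, -1) if reverse else range(n_t)
        for t in steps:
            h = np.tanh(X[:, :, t] @ self.W + h @ self.U + self.b)
            H[:, :, t] = h  # store at the original time index in both directions
        return H


def bidirectional_forward(X, cell_fwd, cell_bwd):
    """Concatenate forward and backward hidden states: (n_ex, 2 * n_out, n_t)."""
    H_fwd = cell_fwd.run(X)
    H_bwd = cell_bwd.run(X, reverse=True)
    return np.concatenate([H_fwd, H_bwd], axis=1)


# Usage: 3 examples, 4 input features, 5 time steps, 2 hidden units per direction
rng = np.random.default_rng(0)
X = rng.standard_normal((3, 4, 5))
Y = bidirectional_forward(X, ToyRNNCell(4, 2, rng), ToyRNNCell(4, 2, rng))
print(Y.shape)  # (3, 4, 5)
```

Storing each backward state at its original time index is what lets the two directions be concatenated step by step: at every t, the output combines a summary of the past from the forward cell with a summary of the future from the backward cell.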

Tested by

No test coverage detected.