Repository: github.com/ddbourgin/numpy-ml

Function test_WaveNetModule

numpy_ml/tests/test_nn.py, lines 2072–2166
Signature: test_WaveNetModule(N=10)


def test_WaveNetModule(N=10):
    from numpy_ml.neural_nets.modules import WavenetResidualModule
    # np, random_tensor, pad1D and TorchWavenetModule are imported at module level in test_nn.py

    # N=None means run trials indefinitely
    N = np.inf if N is None else N

    np.random.seed(12345)

    i = 1
    while i < N + 1:
        # draw random shapes and hyperparameters for this trial
        n_ex = np.random.randint(1, 10)
        l_in = np.random.randint(1, 10)
        ch_residual, ch_dilation = np.random.randint(1, 5), np.random.randint(1, 5)
        f_width = min(l_in, np.random.randint(1, 5))
        d = np.random.randint(0, 5)

        # impulse input on the main path: all zeros except a single 1
        X_main = np.zeros_like(
            random_tensor((n_ex, l_in, ch_residual), standardize=True)
        )
        X_main[0][0][0] = 1.0
        X_skip = np.zeros_like(
            random_tensor((n_ex, l_in, ch_residual), standardize=True)
        )

        # initialize the WaveNet residual module
        L1 = WavenetResidualModule(
            ch_residual=ch_residual,
            ch_dilation=ch_dilation,
            kernel_width=f_width,
            dilation=d,
        )

        # forward prop
        Y_main, Y_skip = L1.forward(X_main, X_skip)

        # backprop with upstream gradients of ones, i.e. the loss is the sum of both outputs
        dLdY_skip = np.ones_like(Y_skip)
        dLdY_main = np.ones_like(Y_main)
        dLdX_main, dLdX_skip = L1.backward(dLdY_skip, dLdY_main)

        # "same" padding used by the module's 1x1 convolution
        _, conv_1x1_pad = pad1D(
            L1._dv["multiply_gate_out"], "same", kernel_width=1, stride=1, dilation=0
        )
        # skip trials with asymmetric padding; i is not incremented, so a new trial is drawn
        if conv_1x1_pad[0] != conv_1x1_pad[1]:
            print("Skipping")
            continue

        conv_1x1_pad = conv_1x1_pad[0]

        # get gold standard gradients from the PyTorch reference module
        gold_mod = TorchWavenetModule(L1.parameters, L1.hyperparameters, conv_1x1_pad)
        golds = gold_mod.extract_grads(X_main, X_skip)

        dv = L1.derived_variables
        pc = L1.parameters["components"]
        gr = L1.gradients["components"]

        params = [
            (L1.X_main, "X_main"),
            # (excerpt ends here; the full function runs to line 2166 of test_nn.py)
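The test only continues when `pad1D` reports equal left and right padding for the 1x1 convolution, presumably because the PyTorch reference pads both ends of a `torch.nn.Conv1d` input by the same integer `padding`. The sketch below shows the usual "same"-padding arithmetic for a 1D convolution. It assumes numpy-ml's dilation convention (dilation=0 is an ordinary convolution) and a TensorFlow-style split that puts any odd leftover on the right; `same_pad_1d` and `effective_kernel_width` are hypothetical helpers, and numpy-ml's own `pad1D` may split the total differently.

```python
import math


def effective_kernel_width(kernel_width, dilation):
    # Assumes dilation=0 is an ordinary convolution and each unit of
    # dilation inserts one gap between adjacent kernel taps.
    return kernel_width * (dilation + 1) - dilation


def same_pad_1d(l_in, kernel_width, stride=1, dilation=0):
    """Return (left, right) zero-padding so that l_out == ceil(l_in / stride).

    Assumed convention: when the total is odd, the extra column goes on the right.
    """
    k_eff = effective_kernel_width(kernel_width, dilation)
    l_out = math.ceil(l_in / stride)
    total = max(0, (l_out - 1) * stride + k_eff - l_in)
    left = total // 2
    return left, total - left


print(same_pad_1d(7, kernel_width=1))              # (0, 0): a 1x1 conv needs no padding
print(same_pad_1d(7, kernel_width=2))              # (0, 1): asymmetric
print(same_pad_1d(7, kernel_width=3, dilation=1))  # (2, 2)
print(same_pad_1d(7, kernel_width=2, dilation=2))  # (1, 2): asymmetric
```

Under this formula a kernel width of 1 with stride 1 always gives a total of 0, so the skip branch would only fire if the padding routine used a different convention.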

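To see why an impulse input is a useful probe, here is a minimal NumPy sketch of a WaveNet-style gated residual block. It assumes the standard layout (one causal dilated convolution whose output feeds both a tanh and a sigmoid gate, their product passed through a 1x1 convolution, then added to both the main and the skip path) and numpy-ml's dilation convention (taps `dilation + 1` steps apart). The helper names and weight shapes are hypothetical; this is not numpy-ml's `WavenetResidualModule`.

```python
import numpy as np


def causal_dilated_conv1d(X, W, dilation):
    """Causal 1D cross-correlation.

    X: (n_ex, l_in, in_ch), W: (kernel_width, in_ch, out_ch).
    Assumes dilation=0 is an ordinary convolution (taps dilation + 1 steps apart).
    """
    n_ex, l_in, _ = X.shape
    k, _, out_ch = W.shape
    step = dilation + 1
    pad = (k - 1) * step  # left padding only: output t never sees inputs after t
    Xp = np.pad(X, ((0, 0), (pad, 0), (0, 0)))
    Y = np.zeros((n_ex, l_in, out_ch))
    for t in range(l_in):
        for j in range(k):
            # tap j reads original position t - (k - 1 - j) * step
            Y[:, t, :] += Xp[:, t + j * step, :] @ W[j]
    return Y


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gated_residual_block(X_main, X_skip, W_dil, W_1x1, dilation):
    """Hypothetical WaveNet-style block; not numpy-ml's implementation."""
    Z = causal_dilated_conv1d(X_main, W_dil, dilation)  # (n_ex, l_in, ch_dilation)
    gated = np.tanh(Z) * sigmoid(Z)  # gated activation
    H = gated @ W_1x1  # a 1x1 conv is a per-timestep matrix multiply
    return X_main + H, X_skip + H  # main (residual) and skip outputs


rng = np.random.default_rng(0)
n_ex, l_in, ch_residual, ch_dilation, kernel_width, dilation = 2, 8, 3, 4, 2, 1
W_dil = rng.standard_normal((kernel_width, ch_residual, ch_dilation))
W_1x1 = rng.standard_normal((ch_dilation, ch_residual))

# impulse input, mirroring the test: all zeros except X_main[0, 0, 0]
X_main = np.zeros((n_ex, l_in, ch_residual))
X_main[0, 0, 0] = 1.0
X_skip = np.zeros_like(X_main)

Y_main, Y_skip = gated_residual_block(X_main, X_skip, W_dil, W_1x1, dilation)
# timesteps of example 0 reached by the impulse: t = 0 and t = dilation + 1
print(np.flatnonzero(np.abs(Y_skip[0]).sum(axis=1)).tolist())  # [0, 2]
```

Because every other input is zero, the nonzero outputs trace exactly the positions the dilated kernel can reach, which makes shape or dilation mistakes easy to spot.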
Callers: nothing calls this directly.

Calls (8):
forward (method, 0.95)
backward (method, 0.95)
extract_grads (method, 0.95)
random_tensor (function, 0.90)
pad1D (function, 0.90)
TorchWavenetModule (class, 0.85)
err_fmt (function, 0.70)

Tested by: no test coverage detected.
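The test checks numpy-ml's backward pass against a PyTorch reference (`TorchWavenetModule.extract_grads`). Where PyTorch is not available, a common substitute is a central-difference gradient check; the sketch below swaps in that technique on a hypothetical gated unit `tanh(XW) * sigmoid(XW)`, not on the real module. Seeding the backward pass with an upstream gradient of ones, as the test does with `np.ones_like(Y)`, is equivalent to differentiating the loss `sum(Y)`.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def forward(X, W):
    """Hypothetical gated unit: Y = tanh(XW) * sigmoid(XW)."""
    Z = X @ W
    return np.tanh(Z) * sigmoid(Z)


def backward(X, W, dY):
    """Analytic dL/dX for the gated unit, given the upstream gradient dY."""
    Z = X @ W
    t, s = np.tanh(Z), sigmoid(Z)
    dZ = dY * ((1 - t**2) * s + t * s * (1 - s))  # product rule on tanh * sigmoid
    return dZ @ W.T


def numeric_grad(f, X, eps=1e-6):
    """Central-difference gradient of the scalar function f at X."""
    G = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        orig = X[idx]
        X[idx] = orig + eps
        f_plus = f(X)
        X[idx] = orig - eps
        f_minus = f(X)
        X[idx] = orig  # restore the perturbed entry
        G[idx] = (f_plus - f_minus) / (2 * eps)
    return G


rng = np.random.default_rng(1)
X = rng.standard_normal((4, 3))
W = rng.standard_normal((3, 5))

# upstream gradient of ones <=> loss = sum of all outputs
analytic = backward(X, W, np.ones((4, 5)))
numeric = numeric_grad(lambda X_: forward(X_, W).sum(), X)
print(np.allclose(analytic, numeric, atol=1e-6))  # True
```

Finite differences cost one pair of forward passes per input entry, which is fine for the small random shapes the test draws but explains why a vectorized framework reference is attractive for larger checks.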