(N=10)
| 2070 | |
| 2071 | |
| 2072 | def test_WaveNetModule(N=10): |
| 2073 | from numpy_ml.neural_nets.modules import WavenetResidualModule |
| 2074 | |
| 2075 | N = np.inf if N is None else N |
| 2076 | |
| 2077 | np.random.seed(12345) |
| 2078 | |
| 2079 | i = 1 |
| 2080 | while i < N + 1: |
| 2081 | n_ex = np.random.randint(1, 10) |
| 2082 | l_in = np.random.randint(1, 10) |
| 2083 | ch_residual, ch_dilation = np.random.randint(1, 5), np.random.randint(1, 5) |
| 2084 | f_width = min(l_in, np.random.randint(1, 5)) |
| 2085 | d = np.random.randint(0, 5) |
| 2086 | |
| 2087 | X_main = np.zeros_like( |
| 2088 | random_tensor((n_ex, l_in, ch_residual), standardize=True) |
| 2089 | ) |
| 2090 | X_main[0][0][0] = 1.0 |
| 2091 | X_skip = np.zeros_like( |
| 2092 | random_tensor((n_ex, l_in, ch_residual), standardize=True) |
| 2093 | ) |
| 2094 | |
| 2095 | # initialize the WavenetResidualModule under test |
| 2096 | L1 = WavenetResidualModule( |
| 2097 | ch_residual=ch_residual, |
| 2098 | ch_dilation=ch_dilation, |
| 2099 | kernel_width=f_width, |
| 2100 | dilation=d, |
| 2101 | ) |
| 2102 | |
| 2103 | # forward prop |
| 2104 | Y_main, Y_skip = L1.forward(X_main, X_skip) |
| 2105 | |
| 2106 | # backprop |
| 2107 | dLdY_skip = np.ones_like(Y_skip) |
| 2108 | dLdY_main = np.ones_like(Y_main) |
| 2109 | dLdX_main, dLdX_skip = L1.backward(dLdY_skip, dLdY_main) |
| 2110 | |
| 2111 | _, conv_1x1_pad = pad1D( |
| 2112 | L1._dv["multiply_gate_out"], "same", kernel_width=1, stride=1, dilation=0 |
| 2113 | ) |
| 2114 | if conv_1x1_pad[0] != conv_1x1_pad[1]: |
| 2115 | print("Skipping") |
| 2116 | continue |
| 2117 | |
| 2118 | conv_1x1_pad = conv_1x1_pad[0] |
| 2119 | |
| 2120 | # get gold standard gradients |
| 2121 | gold_mod = TorchWavenetModule(L1.parameters, L1.hyperparameters, conv_1x1_pad) |
| 2122 | golds = gold_mod.extract_grads(X_main, X_skip) |
| 2123 | |
| 2124 | dv = L1.derived_variables |
| 2125 | pc = L1.parameters["components"] |
| 2126 | gr = L1.gradients["components"] |
| 2127 | |
| 2128 | params = [ |
| 2129 | (L1.X_main, "X_main"), |
nothing calls this directly
no test coverage detected