(N=15)
| 2172 | |
| 2173 | |
| 2174 | def test_pad1D(N=15): |
| 2175 | from numpy_ml.neural_nets.layers import Conv1D |
| 2176 | from .nn_torch_models import TorchCausalConv1d, torchify |
| 2177 | |
| 2178 | np.random.seed(12345) |
| 2179 | |
| 2180 | N = np.inf if N is None else N |
| 2181 | |
| 2182 | i = 1 |
| 2183 | while i < N + 1: |
| 2184 | p = np.random.choice(["same", "causal"]) |
| 2185 | n_ex = np.random.randint(1, 10) |
| 2186 | l_in = np.random.randint(1, 10) |
| 2187 | n_in, n_out = np.random.randint(1, 3), np.random.randint(1, 3) |
| 2188 | f_width = min(l_in, np.random.randint(1, 5)) |
| 2189 | s = np.random.randint(1, 3) |
| 2190 | d = np.random.randint(0, 5) |
| 2191 | |
| 2192 | X = random_tensor((n_ex, l_in, n_in), standardize=True) |
| 2193 | X_pad, _ = pad1D(X, p, kernel_width=f_width, stride=s, dilation=d) |
| 2194 | |
| 2195 | # initialize Conv2D layer |
| 2196 | L1 = Conv1D(out_ch=n_out, kernel_width=f_width, pad=0, stride=s, dilation=d) |
| 2197 | |
| 2198 | # forward prop |
| 2199 | try: |
| 2200 | y_pred = L1.forward(X_pad) |
| 2201 | except ValueError: |
| 2202 | continue |
| 2203 | |
| 2204 | # ignore n. output channels |
| 2205 | print("Trial {}".format(i)) |
| 2206 | print("p={} d={} s={} l_in={} f_width={}".format(p, d, s, l_in, f_width)) |
| 2207 | print("n_ex={} n_in={} n_out={}".format(n_ex, n_in, n_out)) |
| 2208 | assert y_pred.shape[:2] == X.shape[:2], print( |
| 2209 | "y_pred.shape={} X.shape={}".format(y_pred.shape, X.shape) |
| 2210 | ) |
| 2211 | |
| 2212 | if p == "causal": |
| 2213 | gold = TorchCausalConv1d( |
| 2214 | in_channels=n_in, |
| 2215 | out_channels=n_out, |
| 2216 | kernel_size=f_width, |
| 2217 | stride=s, |
| 2218 | dilation=d + 1, |
| 2219 | bias=True, |
| 2220 | ) |
| 2221 | if s != 1: |
| 2222 | print( |
| 2223 | "TorchCausalConv1D does not do `same` padding for stride > 1. Skipping" |
| 2224 | ) |
| 2225 | continue |
| 2226 | |
| 2227 | XT = torchify(np.moveaxis(X, [0, 1, 2], [0, -1, -2])) |
| 2228 | else: |
| 2229 | gold = nn.Conv1d( |
| 2230 | in_channels=n_in, |
| 2231 | out_channels=n_out, |
nothing calls this directly
no test coverage detected