MCPcopy
hub / github.com/ddbourgin/numpy-ml / test_pad1D

Function test_pad1D

numpy_ml/tests/test_nn.py:2174–2270  ·  view source on GitHub ↗
(N=15)

Source from the content-addressed store, hash-verified

2172
2173
2174def test_pad1D(N=15):
2175 from numpy_ml.neural_nets.layers import Conv1D
2176 from .nn_torch_models import TorchCausalConv1d, torchify
2177
2178 np.random.seed(12345)
2179
2180 N = np.inf if N is None else N
2181
2182 i = 1
2183 while i < N + 1:
2184 p = np.random.choice(["same", "causal"])
2185 n_ex = np.random.randint(1, 10)
2186 l_in = np.random.randint(1, 10)
2187 n_in, n_out = np.random.randint(1, 3), np.random.randint(1, 3)
2188 f_width = min(l_in, np.random.randint(1, 5))
2189 s = np.random.randint(1, 3)
2190 d = np.random.randint(0, 5)
2191
2192 X = random_tensor((n_ex, l_in, n_in), standardize=True)
2193 X_pad, _ = pad1D(X, p, kernel_width=f_width, stride=s, dilation=d)
2194
2195 # initialize Conv2D layer
2196 L1 = Conv1D(out_ch=n_out, kernel_width=f_width, pad=0, stride=s, dilation=d)
2197
2198 # forward prop
2199 try:
2200 y_pred = L1.forward(X_pad)
2201 except ValueError:
2202 continue
2203
2204 # ignore n. output channels
2205 print("Trial {}".format(i))
2206 print("p={} d={} s={} l_in={} f_width={}".format(p, d, s, l_in, f_width))
2207 print("n_ex={} n_in={} n_out={}".format(n_ex, n_in, n_out))
2208 assert y_pred.shape[:2] == X.shape[:2], print(
2209 "y_pred.shape={} X.shape={}".format(y_pred.shape, X.shape)
2210 )
2211
2212 if p == "causal":
2213 gold = TorchCausalConv1d(
2214 in_channels=n_in,
2215 out_channels=n_out,
2216 kernel_size=f_width,
2217 stride=s,
2218 dilation=d + 1,
2219 bias=True,
2220 )
2221 if s != 1:
2222 print(
2223 "TorchCausalConv1D does not do `same` padding for stride > 1. Skipping"
2224 )
2225 continue
2226
2227 XT = torchify(np.moveaxis(X, [0, 1, 2], [0, -1, -2]))
2228 else:
2229 gold = nn.Conv1d(
2230 in_channels=n_in,
2231 out_channels=n_out,

Callers

nothing calls this directly

Calls 7

forward (method) · 0.95
random_tensor (function) · 0.90
pad1D (function) · 0.90
Conv1D (class) · 0.90
TorchCausalConv1d (class) · 0.85
torchify (function) · 0.85
err_fmt (function) · 0.70

Tested by

no test coverage detected