Function test_Conv1D(N=15)

Defined in numpy_ml/tests/test_nn.py, lines 1267–1345, of github.com/ddbourgin/numpy-ml.

```python
# np (NumPy), nn (torch.nn), random_tensor, TorchLinearActivation and
# TorchConv1DLayer are module-level names imported elsewhere in test_nn.py.
def test_Conv1D(N=15):
    from numpy_ml.neural_nets.layers import Conv1D
    from numpy_ml.neural_nets.activations import Tanh, ReLU, Sigmoid, Affine

    N = np.inf if N is None else N

    np.random.seed(12345)

    acts = [
        (Tanh(), nn.Tanh(), "Tanh"),
        (Sigmoid(), nn.Sigmoid(), "Sigmoid"),
        (ReLU(), nn.ReLU(), "ReLU"),
        (Affine(), TorchLinearActivation(), "Affine"),
    ]

    i = 1
    while i < N + 1:
        n_ex = np.random.randint(1, 10)
        l_in = np.random.randint(1, 10)
        n_in, n_out = np.random.randint(1, 3), np.random.randint(1, 3)
        f_width = min(l_in, np.random.randint(1, 5))
        p, s = np.random.randint(0, 5), np.random.randint(1, 3)
        d = np.random.randint(0, 5)

        # effective (dilated) kernel width and resulting output length
        fc = f_width * (d + 1) - d
        l_out = int(1 + (l_in + 2 * p - fc) / s)

        # skip configurations that produce an empty output
        if l_out <= 0:
            continue

        X = random_tensor((n_ex, l_in, n_in), standardize=True)

        # randomly select an activation function
        act_fn, torch_fn, act_fn_name = acts[np.random.randint(0, len(acts))]

        # initialize Conv1D layer
        L1 = Conv1D(
            out_ch=n_out,
            kernel_width=f_width,
            act_fn=act_fn,
            pad=p,
            stride=s,
            dilation=d,
        )

        # forward prop
        y_pred = L1.forward(X)

        # backprop
        dLdy = np.ones_like(y_pred)
        dLdX = L1.backward(dLdy)

        # get gold-standard gradients
        gold_mod = TorchConv1DLayer(
            n_in, n_out, torch_fn, L1.parameters, L1.hyperparameters
        )
        golds = gold_mod.extract_grads(X)

        # ... (lines 1325–1345 not shown)
```
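The loop derives the output length from the sampled hyperparameters before building the layer. In this convention a dilation of `d` inserts `d` gaps between kernel taps, so `d = 0` is an ordinary convolution and the kernel spans `fc = f_width * (d + 1) - d` input positions. PyTorch's `nn.Conv1d` documents the same span as `dilation * (kernel_size - 1) + 1`, with `dilation = 1` meaning no gaps, so the two conventions should agree when the PyTorch dilation is `d + 1`. The sketch below checks that arithmetic over the same ranges the test samples from; both helper names are hypothetical and not part of numpy-ml.

```python
import math


def out_len_numpy_ml(l_in, f_width, pad, stride, d):
    """Output length as computed in test_Conv1D (d = 0 means no dilation)."""
    fc = f_width * (d + 1) - d  # span of the dilated kernel
    return int(1 + (l_in + 2 * pad - fc) / stride)


def out_len_torch_style(l_in, kernel_size, padding, stride, dilation):
    """Output-length formula from the PyTorch nn.Conv1d docs (dilation = 1 means none)."""
    return math.floor((l_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)


# Exhaustively compare over the ranges np.random.randint draws from in the test.
mismatches = []
for l_in in range(1, 10):
    for k in range(1, 5):
        f_width = min(l_in, k)
        for pad in range(0, 5):
            for stride in range(1, 3):
                for d in range(0, 5):
                    a = out_len_numpy_ml(l_in, f_width, pad, stride, d)
                    if a <= 0:
                        continue  # the test resamples these configurations
                    b = out_len_torch_style(l_in, f_width, pad, stride, d + 1)
                    if a != b:
                        mismatches.append((l_in, f_width, pad, stride, d))

print(len(mismatches))  # 0
```

For every configuration the test keeps (`l_out > 0`), `int(...)` truncation and `floor` coincide because the argument is positive, which is why the two formulas agree.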

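The input is sampled with shape `(n_ex, l_in, n_in)`, i.e. channels last, whereas PyTorch's `nn.Conv1d` expects `(N, C_in, L)`, so a PyTorch-based reference has to move the channel axis before comparing. As a self-contained reference point, here is a hypothetical loop-based forward pass for a channels-last, padded, strided, dilated 1D convolution in NumPy. It assumes weights laid out as `(f_width, in_ch, out_ch)`, applies no activation, and is not numpy-ml's implementation.

```python
import numpy as np


def conv1d_forward(X, W, b, pad, stride, d):
    """Loop-based channels-last 1D cross-correlation (hypothetical reference).

    X: (n_ex, l_in, in_ch); W: (f_width, in_ch, out_ch); b: (out_ch,).
    d is the number of gaps between kernel taps (d = 0: no dilation).
    """
    n_ex, l_in, in_ch = X.shape
    f_width, _, out_ch = W.shape
    step = d + 1                      # distance between consecutive taps
    fc = f_width * (d + 1) - d        # span of the dilated kernel
    l_out = int(1 + (l_in + 2 * pad - fc) / stride)
    Xp = np.pad(X, ((0, 0), (pad, pad), (0, 0)))  # zero-pad only the length axis
    Y = np.empty((n_ex, l_out, out_ch))
    for t in range(l_out):
        start = t * stride
        window = Xp[:, start : start + fc : step, :]  # (n_ex, f_width, in_ch)
        Y[:, t, :] = np.einsum("nfc,fco->no", window, W) + b
    return Y


rng = np.random.default_rng(0)
X = rng.standard_normal((2, 7, 3))  # (n_ex, l_in, n_in), as in the test
W = rng.standard_normal((3, 3, 2))
b = np.zeros(2)
Y = conv1d_forward(X, W, b, pad=1, stride=2, d=1)
print(Y.shape)  # (2, 3, 2)

# PyTorch's nn.Conv1d expects (N, C_in, L): move the channel axis to position 1.
X_ncl = X.transpose(0, 2, 1)
```

With `f_width = 3`, `d = 1`, `pad = 1` and `stride = 2`, the kernel spans 5 positions of the 9-long padded input, giving 3 output steps.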
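The test backpropagates `dLdy = np.ones_like(y_pred)`. By the chain rule, an all-ones upstream gradient makes `backward` return the gradient of the scalar `sum(y_pred)`, so any reference that differentiates the summed output should produce matching numbers. The sketch below illustrates that principle without PyTorch: a hypothetical single-channel dilated convolution with a hand-written backward pass, checked against central finite differences of `sum(y)`. It is an illustration of the idea, not how `TorchConv1DLayer` is implemented.

```python
import numpy as np


def conv1d_single(x, w, pad, stride, dilation):
    """Hypothetical single-channel 1D cross-correlation (dilation = 0: no gaps)."""
    xp = np.pad(x, pad)  # zero-pad both ends
    step = dilation + 1
    fc = len(w) * (dilation + 1) - dilation
    l_out = int(1 + (len(xp) - fc) / stride)
    return np.array(
        [np.dot(xp[t * stride : t * stride + fc : step], w) for t in range(l_out)]
    )


def conv1d_single_backward(x, w, pad, stride, dilation, dLdy):
    """Gradient of the loss w.r.t. x, given the upstream gradient dLdy."""
    step = dilation + 1
    fc = len(w) * (dilation + 1) - dilation
    dxp = np.zeros(len(x) + 2 * pad)
    for t, g in enumerate(dLdy):
        # output t read the padded inputs at these positions, weighted by w
        dxp[t * stride : t * stride + fc : step] += g * w
    return dxp[pad : pad + len(x)]  # drop gradient flowing into the zero padding


def numerical_grad(f, x, eps=1e-6):
    """Central-difference gradient of a scalar function f at x."""
    grad = np.zeros_like(x)
    for i in range(x.size):
        xp, xm = x.copy(), x.copy()
        xp[i] += eps
        xm[i] -= eps
        grad[i] = (f(xp) - f(xm)) / (2 * eps)
    return grad


rng = np.random.default_rng(12345)
x = rng.standard_normal(9)
w = rng.standard_normal(3)
pad, stride, dilation = 2, 2, 1

y = conv1d_single(x, w, pad, stride, dilation)
# An all-ones upstream gradient, as in the test, means "differentiate sum(y)".
analytic = conv1d_single_backward(x, w, pad, stride, dilation, np.ones_like(y))
numeric = numerical_grad(lambda v: conv1d_single(v, w, pad, stride, dilation).sum(), x)
print(np.allclose(analytic, numeric, atol=1e-6))  # True
```

Because the convolution is linear in `x`, central differences are accurate up to floating-point rounding here; with a nonlinear activation the tolerance would need to account for truncation error as well.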