(N=15)
| 1346 | |
| 1347 | |
| 1348 | def test_Deconv2D(N=15): |
| 1349 | from numpy_ml.neural_nets.layers import Deconv2D |
| 1350 | from numpy_ml.neural_nets.activations import Tanh, ReLU, Sigmoid, Affine |
| 1351 | |
| 1352 | N = np.inf if N is None else N |
| 1353 | |
| 1354 | np.random.seed(12345) |
| 1355 | |
| 1356 | acts = [ |
| 1357 | (Tanh(), nn.Tanh(), "Tanh"), |
| 1358 | (Sigmoid(), nn.Sigmoid(), "Sigmoid"), |
| 1359 | (ReLU(), nn.ReLU(), "ReLU"), |
| 1360 | (Affine(), TorchLinearActivation(), "Affine"), |
| 1361 | ] |
| 1362 | |
| 1363 | i = 1 |
| 1364 | while i < N + 1: |
| 1365 | n_ex = np.random.randint(1, 10) |
| 1366 | in_rows = np.random.randint(1, 10) |
| 1367 | in_cols = np.random.randint(1, 10) |
| 1368 | n_in, n_out = np.random.randint(1, 3), np.random.randint(1, 3) |
| 1369 | f_shape = ( |
| 1370 | min(in_rows, np.random.randint(1, 5)), |
| 1371 | min(in_cols, np.random.randint(1, 5)), |
| 1372 | ) |
| 1373 | p, s = np.random.randint(0, 5), np.random.randint(1, 3) |
| 1374 | |
| 1375 | out_rows = s * (in_rows - 1) - 2 * p + f_shape[0] |
| 1376 | out_cols = s * (in_cols - 1) - 2 * p + f_shape[1] |
| 1377 | |
| 1378 | if out_rows <= 0 or out_cols <= 0: |
| 1379 | continue |
| 1380 | |
| 1381 | X = random_tensor((n_ex, in_rows, in_cols, n_in), standardize=True) |
| 1382 | |
| 1383 | # randomly select an activation function |
| 1384 | act_fn, torch_fn, act_fn_name = acts[np.random.randint(0, len(acts))] |
| 1385 | |
| 1386 | # initialize Deconv2D layer |
| 1387 | L1 = Deconv2D( |
| 1388 | out_ch=n_out, kernel_shape=f_shape, act_fn=act_fn, pad=p, stride=s |
| 1389 | ) |
| 1390 | |
| 1391 | # forward prop |
| 1392 | try: |
| 1393 | y_pred = L1.forward(X) |
| 1394 | except ValueError: |
| 1395 | print("Improper dimensions; retrying") |
| 1396 | continue |
| 1397 | |
| 1398 | # backprop |
| 1399 | dLdy = np.ones_like(y_pred) |
| 1400 | dLdX = L1.backward(dLdy) |
| 1401 | |
| 1402 | # get gold standard gradients |
| 1403 | gold_mod = TorchDeconv2DLayer( |
| 1404 | n_in, n_out, torch_fn, L1.parameters, L1.hyperparameters |
| 1405 | ) |
nothing calls this directly
no test coverage detected