(N=15)
| 1428 | |
| 1429 | |
def test_Pool2D(N=15):
    """Compare the numpy-ml ``Pool2D`` layer against a PyTorch reference.

    Runs ``N`` randomized trials (indefinitely when ``N`` is None). Each trial
    draws a random batch size, spatial input size, channel count, kernel
    shape, padding and stride. It then checks three things against the
    reference produced by ``TorchPool2DLayer``:

    - the layer's cached input (``L1.X[0]``);
    - its forward output;
    - its gradient w.r.t. the input under an all-ones upstream gradient.

    Parameters
    ----------
    N : int or None
        Number of random trials to run. ``None`` means run forever.
    """
    from numpy_ml.neural_nets.layers import Pool2D

    N = np.inf if N is None else N

    # Fixed seed so every run draws the same sequence of configurations.
    np.random.seed(12345)

    i = 1
    while i < N + 1:
        n_ex = np.random.randint(1, 10)
        in_rows = np.random.randint(1, 10)
        in_cols = np.random.randint(1, 10)
        n_in = np.random.randint(1, 3)
        # Clamp the kernel so it never exceeds the input's spatial extent.
        f_shape = (
            min(in_rows, np.random.randint(1, 5)),
            min(in_cols, np.random.randint(1, 5)),
        )
        # Padding is drawn below min(f_shape) // 2 (it is 0 for small
        # kernels). Stride is 1 or 2.
        p, s = np.random.randint(0, max(1, min(f_shape) // 2)), np.random.randint(1, 3)

        # Only average pooling is exercised. NOTE(review): "max" mode was
        # deliberately disabled here. A plausible cause is that PyTorch's
        # MaxPool2d pads with -inf while this layer may pad with zeros, so
        # windows overlapping the border could disagree. Confirm before
        # re-enabling it.
        mode = "average"

        # Expected output size for symmetric padding p. The numerator is
        # never negative because the kernel is clamped to the input size, so
        # floor division gives exactly floor((H + 2p - k) / s) + 1.
        out_rows = (in_rows + 2 * p - f_shape[0]) // s + 1
        out_cols = (in_cols + 2 * p - f_shape[1]) // s + 1

        X = random_tensor((n_ex, in_rows, in_cols, n_in), standardize=True)
        print("\nmode: {}".format(mode))
        print("pad={}, stride={}, f_shape={}, n_ex={}".format(p, s, f_shape, n_ex))
        print("in_rows={}, in_cols={}, n_in={}".format(in_rows, in_cols, n_in))
        print("out_rows={}, out_cols={}, n_out={}".format(out_rows, out_cols, n_in))

        # initialize Pool2D layer
        L1 = Pool2D(kernel_shape=f_shape, pad=p, stride=s, mode=mode)

        # forward prop
        y_pred = L1.forward(X)

        # Check the output size independently of the torch reference. Pooling
        # acts per channel, so the channel count is unchanged.
        expected_shape = (n_ex, out_rows, out_cols, n_in)
        assert y_pred.shape == expected_shape, "expected output shape {}, got {}".format(
            expected_shape, y_pred.shape
        )

        # Backprop an all-ones upstream gradient through the layer.
        dLdy = np.ones_like(y_pred)
        dLdX = L1.backward(dLdy)

        # get gold standard gradients
        gold_mod = TorchPool2DLayer(n_in, L1.hyperparameters)
        golds = gold_mod.extract_grads(X)

        params = [(L1.X[0], "X"), (y_pred, "y"), (dLdX, "dLdX")]
        for ix, (mine, label) in enumerate(params):
            assert_almost_equal(
                mine, golds[label], err_msg=err_fmt(params, golds, ix), decimal=4
            )
            print("\tPASSED {}".format(label))
        i += 1
| 1480 | |
| 1481 | |
| 1482 | def test_LSTMCell(N=15): |
nothing calls this directly
no test coverage detected