(N=15)
| 1009 | |
| 1010 | |
def test_BatchNorm2D(N=15):
    """Compare ``BatchNorm2D`` against a PyTorch reference on random inputs.

    Runs up to ``N`` randomized trials. When ``N`` is None the bound becomes
    ``np.inf``, so the loop never terminates on its own. Each trial does the
    following:

    - Draws a random standardized input tensor.
    - Runs it forward and backward through a fresh ``BatchNorm2D`` layer,
      using an all-ones upstream gradient.
    - Checks the layer's input, momentum/epsilon, scaler/intercept,
      running mean, output and gradients against the values returned by
      ``TorchBatchNormLayer.extract_grads``, to 3 decimal places.

    Parameters
    ----------
    N : int or None
        Number of trials to run; ``None`` means no upper bound.
    """
    from numpy_ml.neural_nets.layers import BatchNorm2D

    n_trials = np.inf if N is None else N

    # Fixed seed so the generated sequence of test cases is reproducible.
    np.random.seed(12345)

    trial = 1
    while trial < n_trials + 1:
        # Keep the draw order below unchanged: it determines which cases the
        # seed produces.
        n_ex = np.random.randint(2, 10)
        in_rows = np.random.randint(1, 10)
        in_cols = np.random.randint(1, 10)
        n_in = np.random.randint(1, 3)

        # Input of shape (n_ex, in_rows, in_cols, n_in); the last axis is
        # presumably the channel axis (n_in is passed to the reference as the
        # feature count) -- TODO confirm.
        X = random_tensor((n_ex, in_rows, in_cols, n_in), standardize=True)
        layer = BatchNorm2D()

        y_pred = layer.forward(X)

        # An all-ones upstream gradient corresponds to a plain sum loss.
        dLdX = layer.backward(np.ones_like(X))

        # Build the reference from the layer's own parameters and settings so
        # both implementations start from identical state.
        reference = TorchBatchNormLayer(
            n_in,
            layer.parameters,
            mode="2D",
            epsilon=layer.epsilon,
            momentum=layer.momentum,
        )
        golds = reference.extract_grads(X, Y_true=None)

        hparams = layer.hyperparameters
        weights = layer.parameters
        grads = layer.gradients

        # (value computed by the layer, key into ``golds``) pairs.
        # NOTE(review): running_var is not compared (it was commented out in
        # the original); the reason is not visible here.
        params = [
            (layer.X[0], "X"),  # presumably the cached forward input -- confirm
            (hparams["momentum"], "momentum"),
            (hparams["epsilon"], "epsilon"),
            (weights["scaler"].T, "scaler"),
            (weights["intercept"], "intercept"),
            (weights["running_mean"], "running_mean"),
            (y_pred, "y"),
            (grads["scaler"], "dLdScaler"),
            (grads["intercept"], "dLdIntercept"),
            (dLdX, "dLdX"),
        ]

        print("Trial {}".format(trial))
        for ix, (mine, label) in enumerate(params):
            msg = err_fmt(params, golds, ix)
            assert_almost_equal(mine, golds[label], err_msg=msg, decimal=3)
            print("\tPASSED {}".format(label))

        trial += 1
| 1065 | |
| 1066 | |
| 1067 | def test_RNNCell(N=15): |
nothing calls this directly
no test coverage detected