MCPcopy
hub / github.com/ddbourgin/numpy-ml / test_BatchNorm2D

Function test_BatchNorm2D

numpy_ml/tests/test_nn.py:1011–1064  ·  view source on GitHub ↗
(N=15)

Source from the content-addressed store, hash-verified

1009
1010
def test_BatchNorm2D(N=15):
    """Check ``BatchNorm2D`` forward/backward results against a PyTorch reference.

    For each trial, a standardized random 4D input of shape
    ``(n_ex, in_rows, in_cols, n_in)`` is pushed through a freshly constructed
    ``BatchNorm2D`` layer, followed by a backward pass under a plain sum loss
    (an upstream gradient of all ones). The layer's cached input,
    hyperparameters, parameters, output, and gradients are then compared
    against ``TorchBatchNormLayer`` to 3 decimal places.

    Parameters
    ----------
    N : int or None
        Number of random trials to run. If None, trials run indefinitely.
    """
    from numpy_ml.neural_nets.layers import BatchNorm2D

    N = np.inf if N is None else N

    np.random.seed(12345)

    i = 1
    while i < N + 1:
        # NOTE(review): at least 2 examples presumably guarantees more than one
        # value per channel even when in_rows == in_cols == 1; PyTorch batch
        # norm rejects a single value per channel in training mode.
        n_ex = np.random.randint(2, 10)
        in_rows = np.random.randint(1, 10)
        in_cols = np.random.randint(1, 10)
        n_in = np.random.randint(1, 3)

        # initialize BatchNorm2D layer; n_in is the size of the last axis and is
        # passed to the reference layer below (presumably the channel count)
        X = random_tensor((n_ex, in_rows, in_cols, n_in), standardize=True)
        L1 = BatchNorm2D()

        # forward prop
        y_pred = L1.forward(X)

        # standard sum loss: dL/dy is all ones, shaped like the layer *output*
        dLdy = np.ones_like(y_pred)
        dLdX = L1.backward(dLdy)

        # get gold standard gradients from a reference layer configured with the
        # same parameters, epsilon, and momentum
        gold_mod = TorchBatchNormLayer(
            n_in, L1.parameters, mode="2D", epsilon=L1.epsilon, momentum=L1.momentum
        )
        golds = gold_mod.extract_grads(X, Y_true=None)

        # (value computed by this layer, key into `golds`)
        # running_var is intentionally not compared. NOTE(review): presumably
        # because PyTorch stores an *unbiased* variance estimate in its running
        # average while this layer may use the biased one -- confirm before
        # adding a running_var check.
        params = [
            (L1.X[0], "X"),
            (L1.hyperparameters["momentum"], "momentum"),
            (L1.hyperparameters["epsilon"], "epsilon"),
            (L1.parameters["scaler"].T, "scaler"),
            (L1.parameters["intercept"], "intercept"),
            (L1.parameters["running_mean"], "running_mean"),
            (y_pred, "y"),
            (L1.gradients["scaler"], "dLdScaler"),
            (L1.gradients["intercept"], "dLdIntercept"),
            (dLdX, "dLdX"),
        ]

        print("Trial {}".format(i))
        for ix, (mine, label) in enumerate(params):
            assert_almost_equal(
                mine, golds[label], err_msg=err_fmt(params, golds, ix), decimal=3
            )

            print("\tPASSED {}".format(label))

        i += 1
1065
1066
1067def test_RNNCell(N=15):

Callers

nothing calls this directly

Calls (7)

forward — method · 0.95
backward — method · 0.95
extract_grads — method · 0.95
random_tensor — function · 0.90
BatchNorm2D — class · 0.90
TorchBatchNormLayer — class · 0.85
err_fmt — function · 0.70

Tested by

no test coverage detected