MCPcopy
hub / github.com/ddbourgin/numpy-ml / test_AddLayer

Function test_AddLayer

numpy_ml/tests/test_nn.py:955–1008  ·  view source on GitHub ↗
(N=15)

Source from the content-addressed store, hash-verified

953
954
def test_AddLayer(N=15):
    """Compare numpy_ml's ``Add`` layer against a PyTorch reference.

    Each trial:

    1. draws ``n_entries`` (2-4) random standardized input matrices that all
       share the shape ``(n_ex, n_in)``;
    2. picks one activation at random from ``acts``;
    3. runs a forward pass and a backward pass (with an all-ones upstream
       gradient) through ``numpy_ml.neural_nets.layers.Add``;
    4. checks the inputs, the output and every per-input gradient against the
       gold values returned by ``TorchAddLayer(torch_fn).extract_grads(Xs)``.

    Parameters
    ----------
    N : int or None
        Number of random trials to run. ``None`` runs indefinitely
        (it is mapped to ``np.inf``).
    """
    from numpy_ml.neural_nets.layers import Add
    from numpy_ml.neural_nets.activations import Tanh, ReLU, Sigmoid, Affine

    N = np.inf if N is None else N

    np.random.seed(12345)

    # (numpy_ml activation, reference activation module, display name)
    acts = [
        (Tanh(), nn.Tanh(), "Tanh"),
        (Sigmoid(), nn.Sigmoid(), "Sigmoid"),
        (ReLU(), nn.ReLU(), "ReLU"),
        (Affine(), TorchLinearActivation(), "Affine"),
    ]

    i = 1
    while i < N + 1:
        n_ex = np.random.randint(1, 100)
        n_in = np.random.randint(1, 100)
        n_entries = np.random.randint(2, 5)
        # All summands must share one shape for the elementwise addition.
        Xs = [
            random_tensor((n_ex, n_in), standardize=True)
            for _ in range(n_entries)
        ]

        act_fn, torch_fn, _ = acts[np.random.randint(0, len(acts))]

        # initialize Add layer
        L1 = Add(act_fn)

        # forward prop
        y_pred = L1.forward(Xs)

        # backprop an all-ones upstream gradient
        dLdy = np.ones_like(y_pred)
        dLdXs = L1.backward(dLdy)

        # get gold standard values / gradients from the reference module
        gold_mod = TorchAddLayer(torch_fn)
        golds = gold_mod.extract_grads(Xs)

        params = [(Xs, "Xs"), (y_pred, "Y")]
        # Index with `j`, not `i`: `i` is the trial counter of the enclosing
        # loop, and reusing the name invites accidentally clobbering it.
        # Gradient labels are 1-based: dLdX1, dLdX2, ...
        params.extend(
            [(dldxi, "dLdX{}".format(j + 1)) for j, dldxi in enumerate(dLdXs)]
        )

        print("\nTrial {}".format(i))
        print("n_ex={}, n_in={}".format(n_ex, n_in))
        print("n_entries={}, act_fn={}".format(n_entries, str(act_fn)))
        for ix, (mine, label) in enumerate(params):
            # NOTE(review): decimal=1 is a very loose tolerance, presumably to
            # absorb precision differences between numpy_ml and the reference
            # implementation -- confirm before tightening.
            assert_almost_equal(
                mine, golds[label], err_msg=err_fmt(params, golds, ix), decimal=1
            )
            print("\tPASSED {}".format(label))
        i += 1
1009
1010
1011def test_BatchNorm2D(N=15):

Callers

nothing calls this directly

Calls 12

forward (Method) · 0.95
backward (Method) · 0.95
extract_grads (Method) · 0.95
Tanh (Class) · 0.90
Sigmoid (Class) · 0.90
ReLU (Class) · 0.90
Affine (Class) · 0.90
random_tensor (Function) · 0.90
Add (Class) · 0.90
TorchAddLayer (Class) · 0.85
err_fmt (Function) · 0.70

Tested by

no test coverage detected