MCPcopy Index your code
hub / github.com/ddbourgin/numpy-ml / test_MultiplyLayer

Function test_MultiplyLayer

numpy_ml/tests/test_nn.py:899–952  ·  view source on GitHub ↗
(N=15)

Source from the content-addressed store, hash-verified

897
898
def test_MultiplyLayer(N=15):
    """Compare the numpy-ml ``Multiply`` layer against a PyTorch reference.

    Each trial draws a random number (2-4) of standardized input tensors that
    all share one random ``(n_ex, n_in)`` shape, plus a randomly chosen
    activation. It runs a forward pass and a backward pass (with an all-ones
    upstream gradient) through ``Multiply``. It then checks the inputs, the
    output, and the gradient w.r.t. every input against the values returned
    by ``TorchMultiplyLayer.extract_grads``.

    Parameters
    ----------
    N : int or None
        Number of random trials to run. ``None`` runs trials indefinitely.
    """
    from numpy_ml.neural_nets.layers import Multiply
    from numpy_ml.neural_nets.activations import Tanh, ReLU, Sigmoid, Affine

    N = np.inf if N is None else N

    np.random.seed(12345)

    # (numpy-ml activation, matching torch module, display name)
    acts = [
        (Tanh(), nn.Tanh(), "Tanh"),
        (Sigmoid(), nn.Sigmoid(), "Sigmoid"),
        (ReLU(), nn.ReLU(), "ReLU"),
        (Affine(), TorchLinearActivation(), "Affine"),
    ]

    trial = 1
    while trial < N + 1:
        n_ex = np.random.randint(1, 100)
        n_in = np.random.randint(1, 100)
        n_entries = np.random.randint(2, 5)
        # every input shares the same (n_ex, n_in) shape
        Xs = [
            random_tensor((n_ex, n_in), standardize=True)
            for _ in range(n_entries)
        ]

        act_fn, torch_fn, _ = acts[np.random.randint(0, len(acts))]

        # initialize Multiply layer
        L1 = Multiply(act_fn)

        # forward prop
        y_pred = L1.forward(Xs)

        # backprop with an all-ones upstream gradient
        dLdy = np.ones_like(y_pred)
        dLdXs = L1.backward(dLdy)

        # get gold standard gradients
        gold_mod = TorchMultiplyLayer(torch_fn)
        golds = gold_mod.extract_grads(Xs)

        # labels must match the keys produced by TorchMultiplyLayer.extract_grads
        params = [(Xs, "Xs"), (y_pred, "Y")]
        params.extend(
            (dldxi, "dLdX{}".format(k + 1)) for k, dldxi in enumerate(dLdXs)
        )

        print("\nTrial {}".format(trial))
        print("n_ex={}, n_in={}".format(n_ex, n_in))
        print("n_entries={}, act_fn={}".format(n_entries, str(act_fn)))
        for ix, (mine, label) in enumerate(params):
            # loose tolerance: only 1 decimal place is compared
            assert_almost_equal(
                mine, golds[label], err_msg=err_fmt(params, golds, ix), decimal=1
            )
            print("\tPASSED {}".format(label))
        trial += 1
953
954
955def test_AddLayer(N=15):

Callers

nothing calls this directly

Calls 12

forward · method · 0.95
backward · method · 0.95
extract_grads · method · 0.95
Tanh · class · 0.90
Sigmoid · class · 0.90
ReLU · class · 0.90
Affine · class · 0.90
random_tensor · function · 0.90
Multiply · class · 0.90
TorchMultiplyLayer · class · 0.85
err_fmt · function · 0.70

Tested by

no test coverage detected