MCPcopy
hub / github.com/ddbourgin/numpy-ml / test_FullyConnected

Function test_FullyConnected

numpy_ml/tests/test_nn.py:640–696  ·  view source on GitHub ↗
(N=15)

Source from the content-addressed store, hash-verified

638
639
def test_FullyConnected(N=15):
    """Check the numpy-ml ``FullyConnected`` layer against a reference layer.

    Each trial draws random layer dimensions and one of four activations, then
    runs a forward pass and a backward pass (with an all-ones upstream
    gradient) through ``FullyConnected``. The input, output, parameters and
    gradients are compared, to 3 decimal places, with the values returned by
    ``TorchFCLayer.extract_grads`` on the same input.

    Parameters
    ----------
    N : int or None
        Number of random trials to run. ``None`` means run indefinitely.
    """
    from numpy_ml.neural_nets.layers import FullyConnected
    from numpy_ml.neural_nets.activations import Tanh, ReLU, Sigmoid, Affine

    np.random.seed(12345)

    if N is None:
        N = np.inf

    # Each entry: (numpy-ml activation, reference activation, display name).
    activations = [
        (Tanh(), nn.Tanh(), "Tanh"),
        (Sigmoid(), nn.Sigmoid(), "Sigmoid"),
        (ReLU(), nn.ReLU(), "ReLU"),
        (Affine(), TorchLinearActivation(), "Affine"),
    ]

    # Keys into the reference's gold dict, in the order values are checked.
    labels = ["X", "y", "W", "b", "dLdy", "dLdW", "dLdB", "dLdX"]

    trial = 0
    while trial < N:
        trial += 1

        # NOTE: the sequence of np.random draws below (three randints, the
        # random tensor, then the activation pick) determines each trial's
        # data for the fixed seed; keep the order stable.
        n_ex = np.random.randint(1, 100)
        n_in = np.random.randint(1, 100)
        n_out = np.random.randint(1, 100)
        X = random_tensor((n_ex, n_in), standardize=True)

        choice = np.random.randint(0, len(activations))
        act_fn, torch_fn, act_fn_name = activations[choice]

        # The layer is built without n_in; it only sees the input width at
        # forward time.
        layer = FullyConnected(n_out=n_out, act_fn=act_fn)
        y_pred = layer.forward(X)

        # All-ones upstream gradient; it is itself compared to the
        # reference's "dLdy" entry below.
        dLdy = np.ones_like(y_pred)
        dLdX = layer.backward(dLdy)

        # The reference layer is seeded with this layer's parameters.
        reference = TorchFCLayer(n_in, n_out, torch_fn, layer.parameters)
        golds = reference.extract_grads(X)

        # W and dLdW are transposed before comparison -- presumably to match
        # the reference's weight layout; confirm against TorchFCLayer.
        values = [
            layer.X[0],
            y_pred,
            layer.parameters["W"].T,
            layer.parameters["b"],
            dLdy,
            layer.gradients["W"].T,
            layer.gradients["b"],
            dLdX,
        ]
        params = list(zip(values, labels))

        print("\nTrial {}\nact_fn={}".format(trial, act_fn_name))
        for ix, (mine, label) in enumerate(params):
            assert_almost_equal(
                mine, golds[label], err_msg=err_fmt(params, golds, ix), decimal=3
            )
            print("\tPASSED {}".format(label))
697

Callers

nothing calls this directly

Calls 12

forward (method) · 0.95
backward (method) · 0.95
extract_grads (method) · 0.95
Tanh (class) · 0.90
Sigmoid (class) · 0.90
ReLU (class) · 0.90
Affine (class) · 0.90
random_tensor (function) · 0.90
FullyConnected (class) · 0.90
TorchFCLayer (class) · 0.85
err_fmt (function) · 0.70

Tested by

no test coverage detected