hub / github.com/ddbourgin/numpy-ml / test_SkipConnectionConvModule

Function test_SkipConnectionConvModule

numpy_ml/tests/test_nn.py:1822–1987
(N=15)

Source from the content-addressed store, hash-verified

def test_SkipConnectionConvModule(N=15):
    from numpy_ml.neural_nets.modules import SkipConnectionConvModule
    from numpy_ml.neural_nets.activations import Tanh, ReLU, Sigmoid, Affine

    N = np.inf if N is None else N

    np.random.seed(12345)

    acts = [
        (Tanh(), nn.Tanh(), "Tanh"),
        (Sigmoid(), nn.Sigmoid(), "Sigmoid"),
        (ReLU(), nn.ReLU(), "ReLU"),
        (Affine(), TorchLinearActivation(), "Affine"),
    ]

    i = 1
    while i < N + 1:
        n_ex = np.random.randint(2, 10)
        in_rows = np.random.randint(2, 10)
        in_cols = np.random.randint(2, 10)
        n_in = np.random.randint(2, 5)
        n_out1 = np.random.randint(2, 5)
        n_out2 = np.random.randint(2, 5)
        f_shape1 = (
            min(in_rows, np.random.randint(1, 5)),
            min(in_cols, np.random.randint(1, 5)),
        )
        f_shape2 = (
            min(in_rows, np.random.randint(1, 5)),
            min(in_cols, np.random.randint(1, 5)),
        )
        f_shape_skip = (
            min(in_rows, np.random.randint(1, 5)),
            min(in_cols, np.random.randint(1, 5)),
        )

        s1 = np.random.randint(1, 5)
        s2 = np.random.randint(1, 5)
        s_skip = np.random.randint(1, 5)

        # randomly select an activation function
        act_fn, torch_fn, act_fn_name = acts[np.random.randint(0, len(acts))]

        X = random_tensor((n_ex, in_rows, in_cols, n_in), standardize=True)

        p1 = (np.random.randint(1, 5), np.random.randint(1, 5))
        p2 = (np.random.randint(1, 5), np.random.randint(1, 5))

        # initialize SkipConnectionConv module
        L1 = SkipConnectionConvModule(
            out_ch1=n_out1,
            out_ch2=n_out2,
            kernel_shape1=f_shape1,
            kernel_shape2=f_shape2,
            kernel_shape_skip=f_shape_skip,
            stride1=s1,
            stride2=s2,
            stride_skip=s_skip,
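
The listing above samples input sizes, kernel shapes, strides, and the tuples p1 and p2 at random, clamping each kernel dimension to the input size with min(). The sketch below is not part of numpy-ml; it only illustrates how such sampled values could determine a convolution's output size, using the standard formula floor((n + 2p - k) / s) + 1. The helpers conv_out_dim and sample_config are hypothetical names, and treating the sampled pair as symmetric (row, col) zero-padding is an assumption, since the excerpt ends before p1 and p2 are used.

```python
import numpy as np


def conv_out_dim(in_dim, kernel, stride, pad):
    """Output length along one axis, assuming symmetric zero-padding `pad`."""
    return (in_dim + 2 * pad - kernel) // stride + 1


def sample_config(rng):
    """Draw one random configuration using the same ranges as the listing."""
    # randint's upper bound is exclusive, so these are integers in 2..9
    in_rows = rng.randint(2, 10)
    in_cols = rng.randint(2, 10)
    # clamp each kernel dimension so it never exceeds the unpadded input
    kernel = (min(in_rows, rng.randint(1, 5)), min(in_cols, rng.randint(1, 5)))
    stride = rng.randint(1, 5)
    # assumption: the pair is (row padding, col padding)
    pad = (rng.randint(1, 5), rng.randint(1, 5))
    return in_rows, in_cols, kernel, stride, pad


# A local RandomState keeps the sketch from touching NumPy's global seed.
rng = np.random.RandomState(12345)
configs = [sample_config(rng) for _ in range(15)]

for in_rows, in_cols, kernel, stride, pad in configs:
    out_rows = conv_out_dim(in_rows, kernel[0], stride, pad[0])
    out_cols = conv_out_dim(in_cols, kernel[1], stride, pad[1])
    print((in_rows, in_cols), kernel, stride, pad, "->", (out_rows, out_cols))
```

Because every kernel dimension is clamped to the input and the padding is at least 1, each sampled configuration in this sketch yields an output of at least one row and one column.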

Callers

nothing calls this directly

Calls 12

forward (Method) · 0.95
backward (Method) · 0.95
extract_grads (Method) · 0.95
Tanh (Class) · 0.90
Sigmoid (Class) · 0.90
ReLU (Class) · 0.90
Affine (Class) · 0.90
random_tensor (Function) · 0.90
err_fmt (Function) · 0.70

Tested by

no test coverage detected