(N=15)
| 1128 | |
| 1129 | |
| 1130 | def test_Conv2D(N=15): |
| 1131 | from numpy_ml.neural_nets.layers import Conv2D |
| 1132 | from numpy_ml.neural_nets.activations import Tanh, ReLU, Sigmoid, Affine |
| 1133 | |
| 1134 | N = np.inf if N is None else N |
| 1135 | |
| 1136 | np.random.seed(12345) |
| 1137 | |
| 1138 | acts = [ |
| 1139 | (Tanh(), nn.Tanh(), "Tanh"), |
| 1140 | (Sigmoid(), nn.Sigmoid(), "Sigmoid"), |
| 1141 | (ReLU(), nn.ReLU(), "ReLU"), |
| 1142 | (Affine(), TorchLinearActivation(), "Affine"), |
| 1143 | ] |
| 1144 | |
| 1145 | i = 1 |
| 1146 | while i < N + 1: |
| 1147 | n_ex = np.random.randint(1, 10) |
| 1148 | in_rows = np.random.randint(1, 10) |
| 1149 | in_cols = np.random.randint(1, 10) |
| 1150 | n_in, n_out = np.random.randint(1, 3), np.random.randint(1, 3) |
| 1151 | f_shape = ( |
| 1152 | min(in_rows, np.random.randint(1, 5)), |
| 1153 | min(in_cols, np.random.randint(1, 5)), |
| 1154 | ) |
| 1155 | p, s = np.random.randint(0, 5), np.random.randint(1, 3) |
| 1156 | d = np.random.randint(0, 5) |
| 1157 | |
| 1158 | fr, fc = f_shape[0] * (d + 1) - d, f_shape[1] * (d + 1) - d |
| 1159 | out_rows = int(1 + (in_rows + 2 * p - fr) / s) |
| 1160 | out_cols = int(1 + (in_cols + 2 * p - fc) / s) |
| 1161 | |
| 1162 | if out_rows <= 0 or out_cols <= 0: |
| 1163 | continue |
| 1164 | |
| 1165 | X = random_tensor((n_ex, in_rows, in_cols, n_in), standardize=True) |
| 1166 | |
| 1167 | # randomly select an activation function |
| 1168 | act_fn, torch_fn, act_fn_name = acts[np.random.randint(0, len(acts))] |
| 1169 | |
| 1170 | # initialize Conv2D layer |
| 1171 | L1 = Conv2D( |
| 1172 | out_ch=n_out, |
| 1173 | kernel_shape=f_shape, |
| 1174 | act_fn=act_fn, |
| 1175 | pad=p, |
| 1176 | stride=s, |
| 1177 | dilation=d, |
| 1178 | ) |
| 1179 | |
| 1180 | # forward prop |
| 1181 | y_pred = L1.forward(X) |
| 1182 | |
| 1183 | # backprop |
| 1184 | dLdy = np.ones_like(y_pred) |
| 1185 | dLdX = L1.backward(dLdy) |
| 1186 | |
| 1187 | # get gold standard gradients |
nothing calls this directly
no test coverage detected