(self, X)
| 234 | self.layer1.bias = nn.Parameter(torch.FloatTensor(intercept)) |
| 235 | |
def forward(self, X):
    """Run ``self.layer1`` on ``X`` and cache the input and output.

    Parameters
    ----------
    X : numpy array or torch.Tensor
        Input batch. A 4-D input is treated as channels-last, (N, H, W, C),
        and is reordered to channels-first, (N, C, H, W), before being passed
        to ``self.layer1``.

    Side effects
    ------------
    Sets ``self.X`` to the tensor fed to ``self.layer1`` and ``self.Y`` to the
    layer output. ``retain_grad()`` is called on ``self.Y`` so that its
    ``.grad`` is kept after a backward pass.
    """
    # (N, H, W, C) -> (N, C, H, W)
    if X.ndim == 4:
        if isinstance(X, torch.Tensor):
            # np.moveaxis would call Tensor.transpose(order), whose signature
            # (dim0, dim1) differs from ndarray.transpose and raises TypeError,
            # so tensors are reordered with permute instead.
            X = X.permute(0, 3, 1, 2)
        else:
            X = np.moveaxis(X, [0, 1, 2, 3], [0, -2, -1, -3])

    if not isinstance(X, torch.Tensor):
        # Convert only after the reorder, so the tensor stored in self.X is
        # the one produced by torchify (presumably a grad-tracking leaf --
        # confirm against torchify's definition).
        X = torchify(X)

    self.X = X
    self.Y = self.layer1(self.X)
    self.Y.retain_grad()
| 247 | |
| 248 | def extract_grads(self, X, Y_true=None): |
| 249 | self.forward(X) |
no test coverage detected