(self, X)
| 155 | self.layer1.bias = nn.Parameter(torch.FloatTensor(intercept)) |
| 156 | |
def forward(self, X):
    """Run ``self.layer1`` on ``X`` and cache the input/output tensors.

    Parameters
    ----------
    X : array-like or torch.Tensor
        Layer input. A 4-D input is treated as channels-last
        ``(N, H, W, C)`` and reordered to channels-first ``(N, C, H, W)``
        before being passed to ``self.layer1``. Non-tensor inputs are
        converted with ``torchify``.

    Side effects
    ------------
    Sets ``self.X`` (the tensor actually fed to ``self.layer1``) and
    ``self.Y`` (the layer output). ``retain_grad()`` is called on ``self.Y``
    so its gradient is kept after a backward pass.
    """
    if isinstance(X, torch.Tensor):
        # np.moveaxis would call Tensor.transpose(order), and torch's
        # transpose only takes (dim0, dim1). Use permute for the same
        # (N, H, W, C) -> (N, C, H, W) reordering instead.
        if X.ndim == 4:
            X = X.permute(0, 3, 1, 2)
            # permute returns a non-leaf view; keep its .grad after backward.
            if X.requires_grad:
                X.retain_grad()
    else:
        # (N, H, W, C) -> (N, C, H, W). np.ndim also accepts plain
        # sequences, which have no ``.ndim`` attribute.
        if np.ndim(X) == 4:
            X = np.moveaxis(X, [0, 1, 2, 3], [0, -2, -1, -3])
        X = torchify(X)

    self.X = X
    self.Y = self.layer1(self.X)
    self.Y.retain_grad()
| 168 | |
| 169 | def extract_grads(self, X, Y_true=None): |
| 170 | self.forward(X) |
no test coverage detected