(self, X)
| 458 | self.batchnorm2.bias = nn.Parameter(torch.FloatTensor(intercept)) |
| 459 | |
def forward(self, X):
    """Run the skip-connection block forward, caching every intermediate.

    Computes::

        Y = act_fn(batchnorm2(conv2(batchnorm1(act_fn(conv1(X))))) + X)

    Each intermediate tensor is stored on ``self`` and has ``retain_grad()``
    called on it, so its ``.grad`` is populated after a backward pass.

    Parameters
    ----------
    X : torch.Tensor or array-like
        Input batch. Non-tensor input is treated as channels-last
        ``(N, H, W, C)``; it is transposed to ``(N, C, H, W)`` and then
        converted with ``torchify``. Tensor input is used as-is (presumably
        already ``(N, C, H, W)`` -- TODO confirm with callers).

    Returns
    -------
    torch.Tensor
        The block output, also cached as ``self.Y``.
    """
    if not isinstance(X, torch.Tensor):
        # (N, H, W, C) -> (N, C, H, W): torch conv/batchnorm layers are
        # channels-first, while the incoming array is channels-last.
        X = np.moveaxis(X, [0, 1, 2, 3], [0, -2, -1, -3])
        X = torchify(X)

    # NOTE(review): retain_grad() raises on tensors with requires_grad=False,
    # so X must require grad; presumably torchify sets this -- confirm.
    self.X = X
    self.X.retain_grad()

    self.conv1_out = self.conv1(self.X)
    self.conv1_out.retain_grad()

    self.act_fn1_out = self.act_fn(self.conv1_out)
    self.act_fn1_out.retain_grad()

    self.batchnorm1_out = self.batchnorm1(self.act_fn1_out)
    self.batchnorm1_out.retain_grad()

    self.conv2_out = self.conv2(self.batchnorm1_out)
    self.conv2_out.retain_grad()

    self.batchnorm2_out = self.batchnorm2(self.conv2_out)
    self.batchnorm2_out.retain_grad()

    # Identity skip connection: add the untouched input back in.
    self.layer3_in = self.batchnorm2_out + self.X
    self.layer3_in.retain_grad()

    self.Y = self.act_fn(self.layer3_in)
    self.Y.retain_grad()

    # Return the output so calling the module yields it (not just None);
    # callers that read the cached self.Y are unaffected.
    return self.Y
| 489 | |
| 490 | def extract_grads(self, X): |
| 491 | self.forward(X) |
no test coverage detected