(self, X)
| 811 | self.batchnorm_skip.bias = nn.Parameter(torch.FloatTensor(intercept)) |
| 812 | |
def forward(self, X):
    """Run the residual conv block on ``X``, caching every intermediate.

    Main path:  conv1 -> act_fn -> batchnorm1 -> conv2 -> batchnorm2
    Skip path:  conv_skip -> batchnorm_skip
    The two path outputs are summed (``self.layer3_in``) and passed through
    ``act_fn`` again to produce ``self.Y``. Note that the activation is
    applied *before* ``batchnorm1`` on the main path.

    Every intermediate tensor is stored on ``self`` and has
    ``retain_grad()`` called on it. PyTorch only keeps ``.grad`` for
    non-leaf tensors when asked, so this makes each ``.grad`` available
    after a backward pass. ``extract_grads`` calls this method first,
    presumably to read these cached tensors and their gradients (confirm
    against its body).

    Parameters
    ----------
    X : torch.Tensor or array-like
        If ``X`` is not already a ``torch.Tensor``, it is treated as a 4-D
        channels-last array (N, H, W, C). It is moved to channels-first
        (N, C, H, W) and converted with ``torchify``. A tensor input is
        used as-is (presumably already channels-first; TODO confirm).
        ``self.X.retain_grad()`` raises for a tensor with
        ``requires_grad=False``, so tensor inputs must require grad.

    Returns
    -------
    torch.Tensor
        ``self.Y``, the block output. It is returned so that calling the
        module directly (``module(X)``) yields the output instead of
        ``None``. It is also still cached on ``self.Y``.
    """
    if not isinstance(X, torch.Tensor):
        # (N, H, W, C) -> (N, C, H, W)
        X = np.moveaxis(X, [0, 1, 2, 3], [0, -2, -1, -3])
        X = torchify(X)

    self.X = X
    self.X.retain_grad()

    # Main path.
    self.conv1_out = self.conv1(self.X)
    self.conv1_out.retain_grad()

    self.act_fn1_out = self.act_fn(self.conv1_out)
    self.act_fn1_out.retain_grad()

    self.batchnorm1_out = self.batchnorm1(self.act_fn1_out)
    self.batchnorm1_out.retain_grad()

    self.conv2_out = self.conv2(self.batchnorm1_out)
    self.conv2_out.retain_grad()

    self.batchnorm2_out = self.batchnorm2(self.conv2_out)
    self.batchnorm2_out.retain_grad()

    # Skip path: branches from the block input, not from the main path.
    self.c_skip_out = self.conv_skip(self.X)
    self.c_skip_out.retain_grad()

    self.bn_skip_out = self.batchnorm_skip(self.c_skip_out)
    self.bn_skip_out.retain_grad()

    # Residual merge followed by the second application of the activation.
    self.layer3_in = self.batchnorm2_out + self.bn_skip_out
    self.layer3_in.retain_grad()

    self.Y = self.act_fn(self.layer3_in)
    self.Y.retain_grad()
    return self.Y
| 848 | |
| 849 | def extract_grads(self, X): |
| 850 | self.forward(X) |
no test coverage detected