(self, X_main, X_skip)
| 616 | assert self.conv_1x1.bias.shape == b.flatten().shape |
| 617 | |
def forward(self, X_main, X_skip):
    """Run the gated residual block forward, caching every intermediate.

    Pipeline: dilated conv -> tanh(.) * sigmoid(.) gate -> 1x1 conv. The
    1x1-conv output is added both to the main input (residual path) and to
    the skip input (skip path). Every intermediate tensor is stored on
    ``self`` and has ``retain_grad()`` called on it, so that after a later
    ``backward()`` its ``.grad`` can be read back.

    Parameters
    ----------
    X_main : numpy array
        Main-path input in (N, W, C) layout; it is transposed to
        (N, C, W) before the convolutions.
    X_skip : numpy array or None
        Skip-path input in the same (N, W, C) layout. When ``None``, a
        zero tensor shaped like the 1x1-conv output is used instead.

    Returns
    -------
    None
        Results are stored on ``self.Y_main`` and ``self.Y_skip``, in the
        same channels-first layout as the transposed inputs.
    """
    # (N, W, C) -> (N, C, W). NOTE(review): conv_dilation / conv_1x1 are
    # presumably torch Conv1d modules, which expect channels-first input.
    # torchify is assumed to return a float tensor with requires_grad=True.
    self.X_main = torchify(np.moveaxis(X_main, [0, 1, 2], [0, -1, -2]))
    self.X_main.retain_grad()

    self.conv_dilation_out = self.conv_dilation(self.X_main)
    self.conv_dilation_out.retain_grad()

    # Gated activation: the same dilated-conv output feeds both gates.
    self.tanh_out = torch.tanh(self.conv_dilation_out)
    self.sigm_out = torch.sigmoid(self.conv_dilation_out)

    self.tanh_out.retain_grad()
    self.sigm_out.retain_grad()

    self.multiply_gate_out = self.tanh_out * self.sigm_out
    self.multiply_gate_out.retain_grad()

    self.conv_1x1_out = self.conv_1x1(self.multiply_gate_out)
    self.conv_1x1_out.retain_grad()

    if X_skip is None:
        # No skip input: start the skip path from zeros. The tensor must
        # require grad: otherwise retain_grad() below raises in recent
        # PyTorch releases, and X_skip.grad would never be populated.
        self.X_skip = torch.zeros_like(self.conv_1x1_out, requires_grad=True)
    else:
        self.X_skip = torchify(np.moveaxis(X_skip, [0, 1, 2], [0, -1, -2]))
    self.X_skip.retain_grad()

    # The 1x1-conv output feeds both the skip and the residual (main) path.
    self.Y_skip = self.X_skip + self.conv_1x1_out
    self.Y_main = self.X_main + self.conv_1x1_out

    self.Y_skip.retain_grad()
    self.Y_main.retain_grad()
| 649 | |
| 650 | def extract_grads(self, X_main, X_skip): |
| 651 | self.forward(X_main, X_skip) |
no test coverage detected