| 714 | |
| 715 | class TorchSkipConnectionConv(nn.Module): |
| 716 | def __init__( |
| 717 | self, act_fn, pad1, pad2, pad_skip, params, hparams, momentum=0.9, epsilon=1e-5 |
| 718 | ): |
| 719 | super(TorchSkipConnectionConv, self).__init__() |
| 720 | |
| 721 | self.conv1 = nn.Conv2d( |
| 722 | hparams["in_ch"], |
| 723 | hparams["out_ch1"], |
| 724 | hparams["kernel_shape1"], |
| 725 | padding=pad1, |
| 726 | stride=hparams["stride1"], |
| 727 | bias=True, |
| 728 | ) |
| 729 | |
| 730 | self.act_fn = act_fn |
| 731 | |
| 732 | self.batchnorm1 = nn.BatchNorm2d( |
| 733 | num_features=hparams["out_ch1"], |
| 734 | momentum=1 - momentum, |
| 735 | eps=epsilon, |
| 736 | affine=True, |
| 737 | ) |
| 738 | |
| 739 | self.conv2 = nn.Conv2d( |
| 740 | hparams["out_ch1"], |
| 741 | hparams["out_ch2"], |
| 742 | hparams["kernel_shape2"], |
| 743 | padding=pad2, |
| 744 | stride=hparams["stride2"], |
| 745 | bias=True, |
| 746 | ) |
| 747 | |
| 748 | self.batchnorm2 = nn.BatchNorm2d( |
| 749 | num_features=hparams["out_ch2"], |
| 750 | momentum=1 - momentum, |
| 751 | eps=epsilon, |
| 752 | affine=True, |
| 753 | ) |
| 754 | |
| 755 | self.conv_skip = nn.Conv2d( |
| 756 | hparams["in_ch"], |
| 757 | hparams["out_ch2"], |
| 758 | hparams["kernel_shape_skip"], |
| 759 | padding=pad_skip, |
| 760 | stride=hparams["stride_skip"], |
| 761 | bias=True, |
| 762 | ) |
| 763 | |
| 764 | self.batchnorm_skip = nn.BatchNorm2d( |
| 765 | num_features=hparams["out_ch2"], |
| 766 | momentum=1 - momentum, |
| 767 | eps=epsilon, |
| 768 | affine=True, |
| 769 | ) |
| 770 | |
| 771 | orig, W_swap = [0, 1, 2, 3], [-2, -1, -3, -4] |
| 772 | # (f[0], f[1], n_in, n_out) -> (n_out, n_in, f[0], f[1]) |
| 773 | W = params["components"]["conv1"]["W"] |