Low Resolution Branch of MODNet
| 90 | #------------------------------------------------------------------------------ |
| 91 | |
class LRBranch(nn.Module):
    """ Low Resolution Branch of MODNet

    Runs the backbone on the input, passes the feature at index 4 through an
    SE block, then applies two rounds of (2x bilinear upsample -> 5x5 conv).

    Args:
        backbone: object exposing ``enc_channels`` (indexable up to index 4)
            and a ``forward(img)`` method returning an indexable collection
            with at least five entries.
            NOTE(review): presumably an ``nn.Module`` encoder -- confirm.

    ``forward`` returns the tuple ``(lr8x, enc2x, enc4x)``, where ``enc2x`` and
    ``enc4x`` are backbone outputs 0 and 1 passed through unchanged.
    """

    def __init__(self, backbone):
        super().__init__()

        channels = backbone.enc_channels
        ch_8x, ch_16x, ch_32x = channels[2], channels[3], channels[4]

        # Registration order and attribute names are kept as-is; they define
        # the module's state-dict layout.
        self.backbone = backbone
        self.se_block = SEBlock(ch_32x, ch_32x, reduction=4)
        self.conv_lr16x = Conv2dIBNormRelu(ch_32x, ch_16x, 5, stride=1, padding=2)
        self.conv_lr8x = Conv2dIBNormRelu(ch_16x, ch_8x, 5, stride=1, padding=2)
        # Not referenced by forward() in this class.
        # NOTE(review): presumably retained so checkpoints that contain its
        # weights still load -- confirm before removing.
        self.conv_lr = Conv2dIBNormRelu(
            ch_8x, 1,
            kernel_size=3, stride=2, padding=1,
            with_ibn=False, with_relu=False,
        )

    def forward(self, img):
        features = self.backbone.forward(img)
        enc2x = features[0]
        enc4x = features[1]

        # Start from the feature at index 4 and refine it in two identical
        # stages: double the spatial size, then convolve.
        lr = self.se_block(features[4])
        for refine in (self.conv_lr16x, self.conv_lr8x):
            lr = F.interpolate(lr, scale_factor=2.0, mode='bilinear', align_corners=False)
            lr = refine(lr)

        return lr, enc2x, enc4x
| 118 | |
| 119 | |
| 120 | class HRBranch(nn.Module): |