High Resolution Branch of MODNet
| 118 | |
| 119 | |
class HRBranch(nn.Module):
    """High-resolution branch of MODNet.

    Fuses the input image with two encoder feature maps and the upsampled
    output of the low-resolution branch, and returns refined features with
    ``hr_channels`` channels at half the spatial size of ``img``.

    Args:
        hr_channels: Channel width used for every internal feature map of
            this branch (and the channel count ``lr8x`` must have).
        enc_channels: Indexable collection of encoder channel counts; only
            ``enc_channels[0]`` (channels of ``enc2x``) and
            ``enc_channels[1]`` (channels of ``enc4x``) are read here.
    """

    def __init__(self, hr_channels, enc_channels):
        super().__init__()
        h = hr_channels

        # NOTE: submodules are registered in this exact order on purpose; the
        # order fixes state_dict key order and the sequence in which weights
        # are randomly initialised.

        # enc2x -> h channels, then fuse with the half-res image (+3 channels)
        # and downsample with a stride-2 conv.
        self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], h, 1, stride=1, padding=0)
        self.conv_enc2x = Conv2dIBNormRelu(h + 3, h, 3, stride=2, padding=1)

        # enc4x -> h channels, then fuse with the stage-1 output (h + h).
        self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], h, 1, stride=1, padding=0)
        self.conv_enc4x = Conv2dIBNormRelu(2 * h, 2 * h, 3, stride=1, padding=1)

        # Input = stage-2 features (2h) + upsampled lr8x (h) + quarter-res image (3).
        self.conv_hr4x = nn.Sequential(
            Conv2dIBNormRelu(3 * h + 3, 2 * h, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * h, 2 * h, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * h, h, 3, stride=1, padding=1),
        )

        # Input = upsampled quarter-res features (h) + projected enc2x (h).
        self.conv_hr2x = nn.Sequential(
            Conv2dIBNormRelu(2 * h, 2 * h, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * h, h, 3, stride=1, padding=1),
            Conv2dIBNormRelu(h, h, 3, stride=1, padding=1),
            Conv2dIBNormRelu(h, h, 3, stride=1, padding=1),
        )

        # NOTE(review): conv_hr is not used by forward() below. Presumably it
        # is kept so checkpoints containing its weights still load with
        # strict=True — confirm before removing.
        self.conv_hr = nn.Sequential(
            Conv2dIBNormRelu(h + 3, h, 3, stride=1, padding=1),
            Conv2dIBNormRelu(h, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),
        )

    def _bilinear(self, x, factor: float):
        """Resize ``x`` by ``factor`` with bilinear interpolation (align_corners=False)."""
        return F.interpolate(x, scale_factor=factor, mode='bilinear', align_corners=False)

    def forward(self, img, enc2x, enc4x, lr8x):
        """Compute half-resolution refined features.

        Channels are concatenated along dim 1. ``img`` must have 3 channels.
        ``enc2x`` must match ``img``'s half resolution, ``enc4x`` its quarter
        resolution, and ``lr8x`` (with ``hr_channels`` channels) its eighth
        resolution; otherwise the concatenations below fail.

        Returns:
            Tensor with ``hr_channels`` channels at half the size of ``img``.
        """
        # Downscaled copies of the image (0.5 == 1/2, 0.25 == 1/4 exactly).
        img_half = self._bilinear(img, 0.5)
        img_quarter = self._bilinear(img, 0.25)

        # Stage 1: half-res fusion, stride-2 conv brings it to quarter res.
        enc2x_hr = self.tohr_enc2x(enc2x)
        feat4x = self.conv_enc2x(torch.cat((img_half, enc2x_hr), dim=1))

        # Stage 2: fuse with the projected quarter-res encoder features.
        enc4x_hr = self.tohr_enc4x(enc4x)
        feat4x = self.conv_enc4x(torch.cat((feat4x, enc4x_hr), dim=1))

        # Stage 3: bring in the low-resolution branch (1/8 -> 1/4) and the image.
        lr_quarter = self._bilinear(lr8x, 2.0)
        feat4x = self.conv_hr4x(torch.cat((feat4x, lr_quarter, img_quarter), dim=1))

        # Stage 4: upsample to half res and fuse with projected enc2x again.
        feat2x = self._bilinear(feat4x, 2.0)
        return self.conv_hr2x(torch.cat((feat2x, enc2x_hr), dim=1))
| 168 | |
| 169 | |
| 170 | class FusionBranch(nn.Module): |