MCP · copy · Index your code
hub / github.com/ZHKKKe/MODNet / FusionBranch

Class FusionBranch

torchscript/modnet_torchscript.py:170–194  ·  view source on GitHub ↗

Fusion Branch of MODNet

Source from the content-addressed store, hash-verified

168
169
class FusionBranch(nn.Module):
    """Fusion Branch of MODNet.

    Combines a low-resolution feature map (``lr8x``), a higher-resolution
    feature map (``hr2x``) and the input image (``img``) into a matte that is
    passed through a sigmoid, so every output value lies in (0, 1).

    NOTE(review): ``Conv2dIBNormRelu`` is assumed to take
    ``(in_channels, out_channels, kernel_size, ...)`` positionally. Confirm
    against its definition. The channel comments below rely on that order.

    The attribute names ``conv_lr4x``, ``conv_f2x`` and ``conv_f`` and the
    layer order inside ``conv_f`` are what PyTorch uses to name parameters.
    Keep them stable so previously saved weights still load.
    """

    def __init__(self, hr_channels, enc_channels):
        super().__init__()
        half_channels = int(hr_channels / 2)

        # Refines the upsampled lr8x features: enc_channels[2] -> hr_channels.
        self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)

        # Consumes [upsampled low-res features, hr2x] concatenated on dim 1,
        # hence 2 * hr_channels inputs.
        self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)

        # Head over [fused features, img]: the "+ 3" accounts for img's channels.
        # The last layer projects to 1 channel with IBN and ReLU disabled;
        # forward applies the sigmoid instead.
        self.conv_f = nn.Sequential(
            Conv2dIBNormRelu(hr_channels + 3, half_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(half_channels, 1, 1, stride=1, padding=0,
                             with_ibn=False, with_relu=False),
        )

    def forward(self, img, lr8x, hr2x):
        """Fuse the two feature maps with the image and predict a matte.

        Each ``F.interpolate`` call below doubles the spatial size. The two
        channel-wise concatenations therefore require ``hr2x`` to be 4x and
        ``img`` to be 8x the spatial size of ``lr8x``.

        Args:
            img: input image tensor. It must carry 3 channels to match
                ``conv_f``'s ``hr_channels + 3`` inputs (presumably RGB).
            lr8x: low-resolution features with ``enc_channels[2]`` channels.
            hr2x: high-resolution features with ``hr_channels`` channels.

        Returns:
            ``torch.sigmoid`` of the head's output.
        """
        # Upsample x2, then refine.
        lr4x = self.conv_lr4x(
            F.interpolate(lr8x, scale_factor=2.0, mode='bilinear', align_corners=False)
        )

        # Upsample x2 again so the result can be stacked with hr2x.
        lr2x = F.interpolate(lr4x, scale_factor=2.0, mode='bilinear', align_corners=False)
        fused2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1))

        # Final x2 upsample to img's size, then append the image channels.
        fused = F.interpolate(fused2x, scale_factor=2.0, mode='bilinear', align_corners=False)
        logits = self.conv_f(torch.cat((fused, img), dim=1))
        return torch.sigmoid(logits)
195
196
197#------------------------------------------------------------------------------

Callers 1

`__init__` · Method · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected