MCPcopy Index your code
hub / github.com/ZHKKKe/MODNet / HRBranch

Class HRBranch

torchscript/modnet_torchscript.py:120–167  ·  view source on GitHub ↗

High Resolution Branch of MODNet

Sourced from the content-addressed store (hash-verified)

118
119
class HRBranch(nn.Module):
    """High-resolution branch of MODNet.

    Fuses the input image with two projected encoder skip features and with
    the upsampled ``lr8x`` input, then returns refined ``hr_channels``-channel
    features at roughly half of the input image resolution.

    Args:
        hr_channels: width of the branch's internal feature maps.
        enc_channels: indexable; ``enc_channels[0]`` and ``enc_channels[1]``
            are the channel counts of ``enc2x`` and ``enc4x`` respectively.

    ``forward(img, enc2x, enc4x, lr8x)`` input requirements. These are implied
    by the channel-wise ``torch.cat`` calls (non-channel dims must match) and
    by bilinear interpolation, which needs 4-D (N, C, H, W) tensors:
        img:   3 channels.
        enc2x: ``enc_channels[0]`` channels, same H/W as ``img`` scaled by 0.5.
        enc4x: ``enc_channels[1]`` channels, same H/W as the stride-2 output
               of ``conv_enc2x`` (about ``img`` / 4).
        lr8x:  ``hr_channels`` channels, with H/W such that a 2x upsample
               matches that ~``img`` / 4 map.

    Returns:
        An ``hr_channels``-channel tensor at twice the ~``img`` / 4 resolution.

    Note:
        ``conv_hr`` is constructed but not used by ``forward`` in this class.
        It is presumably kept so checkpoints from the training model load
        with matching state_dict keys. Verify that before removing it.
    """

    def __init__(self, hr_channels, enc_channels):
        super().__init__()
        two_hr = 2 * hr_channels

        def stack3x3(widths):
            # Chain of 3x3, stride-1, padding-1 blocks: widths[i] -> widths[i + 1].
            return nn.Sequential(*[
                Conv2dIBNormRelu(c_in, c_out, 3, stride=1, padding=1)
                for c_in, c_out in zip(widths[:-1], widths[1:])
            ])

        # Skip path at img/2: 1x1 projection, then a stride-2 3x3 conv over
        # [img_half (3 ch), projected skip (hr_channels ch)].
        self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)
        self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)

        # Skip path at ~img/4: 1x1 projection, then fuse with the stride-2 output.
        self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)
        self.conv_enc4x = Conv2dIBNormRelu(two_hr, two_hr, 3, stride=1, padding=1)

        # Input: features (2h) + upsampled lr8x (h) + img_quarter (3).
        self.conv_hr4x = stack3x3([3 * hr_channels + 3, two_hr, two_hr, hr_channels])
        # Input: upsampled features (h) + projected img/2 skip (h).
        self.conv_hr2x = stack3x3([two_hr, two_hr, hr_channels, hr_channels, hr_channels])

        # Not called from forward below; kept as-is (see the class docstring).
        self.conv_hr = nn.Sequential(
            Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0,
                             with_ibn=False, with_relu=False),
        )

    def forward(self, img, enc2x, enc4x, lr8x):
        # Downscaled copies of the image, each taken directly from `img`.
        img_half = F.interpolate(img, scale_factor=0.5, mode='bilinear', align_corners=False)
        img_quarter = F.interpolate(img, scale_factor=0.25, mode='bilinear', align_corners=False)

        # img/2 -> ~img/4 via the stride-2 conv.
        skip2x = self.tohr_enc2x(enc2x)
        feat = self.conv_enc2x(torch.cat((img_half, skip2x), dim=1))

        skip4x = self.tohr_enc4x(enc4x)
        feat = self.conv_enc4x(torch.cat((feat, skip4x), dim=1))

        lr_up = F.interpolate(lr8x, scale_factor=2.0, mode='bilinear', align_corners=False)
        feat = self.conv_hr4x(torch.cat((feat, lr_up, img_quarter), dim=1))

        # Back up 2x and refine together with the projected img/2 skip.
        feat = F.interpolate(feat, scale_factor=2.0, mode='bilinear', align_corners=False)
        return self.conv_hr2x(torch.cat((feat, skip2x), dim=1))
168
169
170class FusionBranch(nn.Module):

Callers 1

__init__ · Method · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected