Feature fusion block.
| 283 | |
| 284 | |
class FeatureFusionBlock(nn.Module):
    """Feature fusion block.

    Optionally adds a second (skip) feature map, processed by
    ``resConfUnit1``, onto the main input. It then passes the sum through
    ``resConfUnit2`` and upsamples the result by a factor of 2 with bilinear
    interpolation.
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        # NOTE: "resConf" (sic) is kept deliberately. Renaming these
        # attributes would change the state_dict keys and break loading of
        # existing checkpoints.
        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Forward pass.

        Args:
            *xs (tensor): one or two feature maps. ``xs[0]`` is the main
                input. If exactly two tensors are given, ``xs[1]`` is passed
                through ``resConfUnit1`` and added to ``xs[0]``. Bilinear
                interpolation presumably means 4-D (N, C, H, W) inputs;
                TODO confirm with callers.

        Returns:
            tensor: fused feature map, upsampled 2x (bilinear,
            ``align_corners=True``).
        """
        output = xs[0]

        if len(xs) == 2:
            # Out-of-place add. Using ``output += ...`` here would write into
            # the caller's ``xs[0]`` tensor, and autograd rejects that
            # in-place op when ``xs[0]`` is a leaf tensor requiring grad.
            output = output + self.resConfUnit1(xs[1])

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=True
        )

        return output
| 318 | |
| 319 | |
| 320 |