| 105 | self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False) |
| 106 | |
| 107 | def forward(self, img): |
| 108 | enc_features = self.backbone.forward(img) |
| 109 | enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4] |
| 110 | |
| 111 | enc32x = self.se_block(enc32x) |
| 112 | lr16x = F.interpolate(enc32x, scale_factor=2.0, mode='bilinear', align_corners=False) |
| 113 | lr16x = self.conv_lr16x(lr16x) |
| 114 | lr8x = F.interpolate(lr16x, scale_factor=2.0, mode='bilinear', align_corners=False) |
| 115 | lr8x = self.conv_lr8x(lr8x) |
| 116 | |
| 117 | return lr8x, enc2x, enc4x |
| 118 | |
| 119 | |
| 120 | class HRBranch(nn.Module): |