(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True)
| 203 | """ |
| 204 | |
def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True):
    """Build the backbone and the LR / HR / fusion branches, then initialize weights.

    Args:
        in_channels: Number of input channels; forwarded to the backbone
            constructor selected by ``backbone_arch``.
        hr_channels: Channel width passed to ``HRBranch`` and ``FusionBranch``.
        backbone_arch: Key into ``SUPPORTED_BACKBONES`` selecting the backbone
            constructor.
        backbone_pretrained: If true, ``self.backbone.load_pretrained_ckpt()``
            is called after the weight-initialization pass.

    Raises:
        KeyError: If ``backbone_arch`` is not a key of ``SUPPORTED_BACKBONES``
            (the message lists the supported keys).
    """
    super(MODNet, self).__init__()

    self.in_channels = in_channels
    self.hr_channels = hr_channels
    self.backbone_arch = backbone_arch
    self.backbone_pretrained = backbone_pretrained

    try:
        backbone_cls = SUPPORTED_BACKBONES[self.backbone_arch]
    except KeyError:
        # Keep the KeyError type callers may rely on, but say what is valid.
        raise KeyError(
            'Unsupported backbone_arch {!r}; supported: {}'.format(
                self.backbone_arch, list(SUPPORTED_BACKBONES)))
    self.backbone = backbone_cls(self.in_channels)

    # Both the HR and fusion branches are sized from the backbone's encoder channels.
    self.lr_branch = LRBranch(self.backbone)
    self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)
    self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)

    # Custom init over every registered submodule (presumably including the
    # backbone's layers, since it is assigned as an attribute above).
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            self._init_conv(m)
        elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
            self._init_norm(m)

    # Done after the init loop so the checkpoint load is not clobbered by
    # _init_conv/_init_norm (presumably it restores the backbone's weights).
    if self.backbone_pretrained:
        self.backbone.load_pretrained_ckpt()
| 227 | |
| 228 | def forward(self, img): |
| 229 | # NOTE |
nothing calls this directly
no test coverage detected