(self, num_classes, trunk='wrn38', criterion=None)
| 167 | DeepLabV3Plus-based mscale segmentation model |
| 168 | """ |
def __init__(self, num_classes, trunk='wrn38', criterion=None):
    """Assemble the DeepLabV3+ style multi-scale segmentation network.

    Args:
        num_classes: output channels of the segmentation head's final
            1x1 convolution.
        trunk: backbone name handed to ``get_trunk``.
        criterion: loss object kept on ``self.criterion``; this
            constructor only stores it.
    """
    super(MscaleV3Plus, self).__init__()
    self.criterion = criterion

    fine_ch = 48            # reduced width of the s2 (low-level) features
    bottleneck_ch = 256     # reduced width of the ASPP output
    head_ch = 256           # hidden width of both prediction heads
    decoder_ch = bottleneck_ch + fine_ch

    # NOTE: submodules are created in the same order as before on purpose:
    # it determines parameter / state_dict ordering and the sequence of
    # random draws made by each layer's default initialisation.
    self.backbone, s2_ch, _, high_level_ch = get_trunk(trunk)
    self.aspp, aspp_out_ch = get_aspp(high_level_ch,
                                      bottleneck_ch=bottleneck_ch,
                                      output_stride=8)
    self.bot_fine = nn.Conv2d(s2_ch, fine_ch, kernel_size=1, bias=False)
    self.bot_aspp = nn.Conv2d(aspp_out_ch, bottleneck_ch, kernel_size=1,
                              bias=False)

    def conv_stack(in_ch, out_ch):
        """Return layers: two (3x3 conv, Norm2d, ReLU) stages + 1x1 conv."""
        layers = []
        stage_in = in_ch
        for _stage in range(2):
            layers.append(nn.Conv2d(stage_in, head_ch, kernel_size=3,
                                    padding=1, bias=False))
            layers.append(Norm2d(head_ch))
            layers.append(nn.ReLU(inplace=True))
            stage_in = head_ch
        layers.append(nn.Conv2d(head_ch, out_ch, kernel_size=1, bias=False))
        return layers

    # Semantic segmentation prediction head (one output channel per class).
    self.final = nn.Sequential(*conv_stack(decoder_ch, num_classes))

    # Scale-attention prediction head: a single channel squashed by a
    # sigmoid. Its input is twice the decoder width -- presumably two
    # concatenated decoder feature maps; TODO confirm against _fwd.
    self.scale_attn = nn.Sequential(*conv_stack(2 * decoder_ch, 1),
                                    nn.Sigmoid())

    # `final` is always passed to initialize_weights; the other decoder
    # pieces only when cfg.OPTIONS.INIT_DECODER is set. Call order is
    # bot_fine, bot_aspp, scale_attn, final -- unchanged from before.
    if cfg.OPTIONS.INIT_DECODER:
        for decoder_part in (self.bot_fine, self.bot_aspp, self.scale_attn):
            initialize_weights(decoder_part)
    initialize_weights(self.final)
| 209 | |
| 210 | def _fwd(self, x): |
| 211 | x_size = x.size() |
nothing calls this directly
no test coverage detected