(self, num_classes, trunk='hrnetv2', criterion=None)
| 69 | ASPP-based Segmentation network |
| 70 | """ |
def __init__(self, num_classes, trunk='hrnetv2', criterion=None):
    """Assemble the backbone, ASPP module, 1x1 bottleneck and seg head.

    Args:
        num_classes: number of output channels requested from the final
            segmentation head (passed as ``out_ch`` to ``make_seg_head``).
        trunk: backbone name handed to ``get_trunk``.
        criterion: stored unchanged on ``self.criterion``; presumably used
            by ``forward`` for the loss (not visible here -- confirm).
    """
    super(ASPP, self).__init__()
    self.criterion = criterion

    # get_trunk yields a 4-tuple; only the backbone itself and the channel
    # count of its highest-level feature map are needed here.
    backbone, _, _, trunk_out_ch = get_trunk(trunk)
    self.backbone = backbone

    aspp, aspp_out_ch = get_aspp(
        trunk_out_ch,
        bottleneck_ch=cfg.MODEL.ASPP_BOT_CH,
        output_stride=8,
    )
    self.aspp = aspp

    # 1x1 conv projecting the ASPP output to a fixed 256-channel width,
    # which is also the input width of the segmentation head.
    head_in_ch = 256
    self.bot_aspp = nn.Conv2d(aspp_out_ch, head_in_ch, kernel_size=1, bias=False)
    self.final = make_seg_head(in_ch=head_in_ch, out_ch=num_classes)

    # Only the layers built in this constructor are initialized; the
    # backbone is not passed (presumably it keeps its own weights -- confirm).
    initialize_weights(self.final, self.bot_aspp, self.aspp)
| 83 | |
| 84 | def forward(self, inputs): |
| 85 | x = inputs['images'] |
no test coverage detected