(self, num_classes, trunk='wrn38', criterion=None)
| 39 | stride8 only |
| 40 | """ |
def __init__(self, num_classes, trunk='wrn38', criterion=None):
    """Build the trunk, the ASPP module and the convolution layers.

    Args:
        num_classes: number of output channels of the final 1x1
            convolution (``conv_up5``).
        trunk: name forwarded to ``get_trunk`` as ``trunk_name``.
        criterion: stored unchanged on ``self.criterion``; it is not
            used anywhere else in this constructor.
    """
    super(DeeperS8, self).__init__()

    # NOTE: the attribute assignments below are kept in their original
    # order on purpose -- nn.Module registers submodules in assignment
    # order, which determines parameter / state_dict ordering.
    self.criterion = criterion

    # get_trunk returns the trunk plus three channel counts.
    # NOTE(review): the "s2"/"s4" names presumably refer to stride-2 and
    # stride-4 feature maps -- confirm against get_trunk.
    backbone, s2_width, s4_width, top_width = get_trunk(
        trunk_name=trunk,
        output_stride=8,
    )
    self.trunk = backbone

    aspp_module, aspp_width = get_aspp(
        top_width,
        bottleneck_ch=256,
        output_stride=8,
    )
    self.aspp = aspp_module

    # Channel widths used by the layers built below.
    s2_proj_width = 32
    s4_proj_width = 64
    fused_width = 256

    # 1x1, bias-free projections.
    self.convs2 = nn.Conv2d(in_channels=s2_width,
                            out_channels=s2_proj_width,
                            kernel_size=1,
                            bias=False)
    self.convs4 = nn.Conv2d(in_channels=s4_width,
                            out_channels=s4_proj_width,
                            kernel_size=1,
                            bias=False)
    self.conv_up1 = nn.Conv2d(in_channels=aspp_width,
                              out_channels=fused_width,
                              kernel_size=1,
                              bias=False)

    # 5x5 blocks whose input width is the fused width plus one of the
    # projection widths (256 + 64 and 256 + 32 respectively).
    self.conv_up2 = ConvBnRelu(fused_width + s4_proj_width, fused_width,
                               kernel_size=5, padding=2)
    self.conv_up3 = ConvBnRelu(fused_width + s2_proj_width, fused_width,
                               kernel_size=5, padding=2)

    # Final 1x1 convolution producing one channel per class.
    self.conv_up5 = nn.Conv2d(in_channels=fused_width,
                              out_channels=num_classes,
                              kernel_size=1,
                              bias=False)
| 56 | |
| 57 | def forward(self, inputs, gts=None): |
| 58 | assert 'images' in inputs |
nothing calls this directly
no test coverage detected