(self,
backbone_name, backbone_file, deploy,
bins=(1, 2, 3, 6), dropout=0.1, classes=2,
zoom_factor=8, use_ppm=True, criterion=nn.CrossEntropyLoss(ignore_index=255), BatchNorm=nn.BatchNorm2d,
pretrained=True)
| 29 | |
| 30 | class PSPNet(nn.Module): |
| 31 | def __init__(self, |
| 32 | backbone_name, backbone_file, deploy, |
| 33 | bins=(1, 2, 3, 6), dropout=0.1, classes=2, |
| 34 | zoom_factor=8, use_ppm=True, criterion=nn.CrossEntropyLoss(ignore_index=255), BatchNorm=nn.BatchNorm2d, |
| 35 | pretrained=True): |
| 36 | super(PSPNet, self).__init__() |
| 37 | assert 2048 % len(bins) == 0 |
| 38 | assert classes > 1 |
| 39 | assert zoom_factor in [1, 2, 4, 8] |
| 40 | self.zoom_factor = zoom_factor |
| 41 | self.use_ppm = use_ppm |
| 42 | self.criterion = criterion |
| 43 | |
| 44 | repvgg_fn = get_RepVGG_func_by_name(backbone_name) |
| 45 | backbone = repvgg_fn(deploy) |
| 46 | if pretrained: |
| 47 | checkpoint = torch.load(backbone_file) |
| 48 | if 'state_dict' in checkpoint: |
| 49 | checkpoint = checkpoint['state_dict'] |
| 50 | ckpt = {k.replace('module.', ''): v for k, v in checkpoint.items()} # strip the names |
| 51 | backbone.load_state_dict(ckpt) |
| 52 | |
| 53 | self.layer0, self.layer1, self.layer2, self.layer3, self.layer4 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3, backbone.stage4 |
| 54 | |
| 55 | # The last two stages should have stride=1 for semantic segmentation |
| 56 | # Note that the stride of 1x1 should be the same as the 3x3 |
| 57 | # Use dilation following the implementation of PSPNet |
| 58 | secondlast_channel = 0 |
| 59 | for n, m in self.layer3.named_modules(): |
| 60 | if ('rbr_dense' in n or 'rbr_reparam' in n) and isinstance(m, nn.Conv2d): |
| 61 | m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) |
| 62 | print('change dilation, padding, stride of ', n) |
| 63 | secondlast_channel = m.out_channels |
| 64 | elif 'rbr_1x1' in n and isinstance(m, nn.Conv2d): |
| 65 | m.stride = (1, 1) |
| 66 | print('change stride of ', n) |
| 67 | last_channel = 0 |
| 68 | for n, m in self.layer4.named_modules(): |
| 69 | if ('rbr_dense' in n or 'rbr_reparam' in n) and isinstance(m, nn.Conv2d): |
| 70 | m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1) |
| 71 | print('change dilation, padding, stride of ', n) |
| 72 | last_channel = m.out_channels |
| 73 | elif 'rbr_1x1' in n and isinstance(m, nn.Conv2d): |
| 74 | m.stride = (1, 1) |
| 75 | print('change stride of ', n) |
| 76 | |
| 77 | fea_dim = last_channel |
| 78 | aux_in = secondlast_channel |
| 79 | |
| 80 | if use_ppm: |
| 81 | self.ppm = PPM(fea_dim, int(fea_dim/len(bins)), bins, BatchNorm) |
| 82 | fea_dim *= 2 |
| 83 | |
| 84 | self.cls = nn.Sequential( |
| 85 | nn.Conv2d(fea_dim, 512, kernel_size=3, padding=1, bias=False), |
| 86 | BatchNorm(512), |
| 87 | nn.ReLU(inplace=True), |
| 88 | nn.Dropout2d(p=dropout), |
nothing calls this directly
no test coverage detected