Operations performed: a 1x1 conv to depth channels; 3x3 convs to depth channels with dilation 6, 12, and 18; image pooling; concatenation of all branches; and a final 1x1 conv.
| 160 | |
| 161 | |
class AtrousSpatialPyramidPoolingModule(nn.Module):
    """Atrous Spatial Pyramid Pooling (ASPP) head.

    Runs several parallel branches over the same input feature map and
    concatenates their outputs along the channel dimension (dim 1):

    * image pooling: global average pool -> 1x1 conv -> Norm2d -> ReLU,
      then upsampled back to the input's spatial size
    * 1x1 conv -> Norm2d -> ReLU
    * one 3x3 dilated conv -> Norm2d -> ReLU per entry in ``rates``
      (dilation 6, 12, 18 by default at output stride 16)

    Every conv projects to ``reduction_dim`` channels, so the output has
    ``reduction_dim * (len(rates) + 2)`` channels (assuming ``Norm2d`` and
    ``Upsample`` preserve the channel count). No final 1x1 projection is
    applied inside this module; the concatenated tensor is returned as-is.

    Args:
        in_dim: Number of channels of the input feature map.
        reduction_dim: Number of output channels of every branch.
        output_stride: 8 or 16. With 8, every dilation rate is doubled;
            with 16, ``rates`` is used unchanged.
        rates: Dilation rates of the 3x3 branches (at output stride 16).

    Raises:
        ValueError: If ``output_stride`` is neither 8 nor 16.
    """

    def __init__(self, in_dim, reduction_dim=256, output_stride=16,
                 rates=(6, 12, 18)):
        super(AtrousSpatialPyramidPoolingModule, self).__init__()

        if output_stride == 8:
            # Stride 8 uses twice the dilation of the stride-16 defaults.
            rates = [2 * r for r in rates]
        elif output_stride != 16:
            # Must raise an Exception instance: raising a plain string is a
            # TypeError in Python 3 and hides this message.
            raise ValueError(
                'output stride of {} not supported'.format(output_stride))

        # 1x1 branch first, then one dilated 3x3 branch per rate. Creation
        # order is kept stable so parameter init and state_dict keys match.
        branches = [
            nn.Sequential(
                nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
                Norm2d(reduction_dim),
                nn.ReLU(inplace=True)),
        ]
        for r in rates:
            # For a 3x3 kernel with stride 1, padding == dilation keeps the
            # spatial size unchanged.
            branches.append(nn.Sequential(
                nn.Conv2d(in_dim, reduction_dim, kernel_size=3,
                          dilation=r, padding=r, bias=False),
                Norm2d(reduction_dim),
                nn.ReLU(inplace=True)))
        self.features = nn.ModuleList(branches)

        # Image-level features: global context pooled to 1x1.
        self.img_pooling = nn.AdaptiveAvgPool2d(1)
        self.img_conv = nn.Sequential(
            nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
            Norm2d(reduction_dim),
            nn.ReLU(inplace=True))

    def forward(self, x):
        """Return the channel-wise concatenation of all ASPP branches.

        Args:
            x: Input feature map, presumably (N, in_dim, H, W); the spatial
                size used for upsampling is taken from ``x.size()[2:]``.

        Returns:
            Tensor with the branch outputs concatenated along dim 1: the
            image-pooling branch first, then the 1x1 branch, then the dilated
            branches in ``rates`` order.
        """
        spatial_size = x.size()[2:]

        img_features = self.img_pooling(x)
        img_features = self.img_conv(img_features)
        img_features = Upsample(img_features, spatial_size)

        # Gather all branch outputs and concatenate once, instead of
        # re-copying a growing tensor on every loop iteration.
        outputs = [img_features]
        outputs.extend(branch(x) for branch in self.features)
        return torch.cat(outputs, 1)
| 219 |