(self, in_channels, key_channels, out_channels, scale=1,
dropout=0.1)
| 126 | for each pixel. |
| 127 | """ |
def __init__(self, in_channels, key_channels, out_channels, scale=1,
             dropout=0.1):
    """Assemble the object-attention block and the 1x1 fusion head.

    Args:
        in_channels: Channel count passed to ``ObjectAttentionBlock`` and,
            when ``cfg.MODEL.OCR_ASPP`` is set, to ``get_aspp``; it also
            determines the input width of the fusion conv.
        key_channels: Forwarded to ``ObjectAttentionBlock``.
        out_channels: Number of channels produced by the 1x1 fusion conv.
        scale: Forwarded to ``ObjectAttentionBlock``.
        dropout: Probability handed to the trailing ``nn.Dropout2d``.
    """
    super(SpatialOCR_Module, self).__init__()

    self.object_context_block = ObjectAttentionBlock(
        in_channels, key_channels, scale)

    # The fusion conv always consumes 2 * in_channels (presumably the
    # context and the input features concatenated in forward -- confirm
    # there); the optional ASPP branch widens it by its own output width.
    fused_channels = 2 * in_channels
    if cfg.MODEL.OCR_ASPP:
        self.aspp, aspp_channels = get_aspp(
            in_channels,
            bottleneck_ch=cfg.MODEL.ASPP_BOT_CH,
            output_stride=8,
        )
        fused_channels += aspp_channels

    # Layers are created in the same order as they run: conv -> BN+ReLU
    # -> dropout.
    fusion_layers = [
        nn.Conv2d(fused_channels, out_channels, kernel_size=1, padding=0,
                  bias=False),
        BNReLU(out_channels),
        nn.Dropout2d(dropout),
    ]
    self.conv_bn_dropout = nn.Sequential(*fusion_layers)
| 148 | |
| 149 | def forward(self, feats, proxy_feats): |
| 150 | context = self.object_context_block(feats, proxy_feats) |
no test coverage detected