(self, num_classes, trunk='hrnetv2', criterion=None)
| 127 | OCR net |
| 128 | """ |
def __init__(self, num_classes, trunk='hrnetv2', criterion=None):
    """Assemble the trunk, the ASPP head and the OCR block.

    Args:
        num_classes: Part of the constructor interface, but not referenced
            anywhere in this method. NOTE(review): presumably the class count
            reaches OCR_block some other way -- confirm.
        trunk: Backbone identifier handed straight to ``get_trunk``.
        criterion: Stored unchanged on ``self.criterion``.
    """
    super(OCRNetASPP, self).__init__()
    self.criterion = criterion

    # get_trunk yields four values; only the first (kept as the backbone) and
    # the last (presumably the high-level feature channel count -- confirm)
    # are used here.
    backbone, _, _, trunk_out_ch = get_trunk(trunk)
    self.backbone = backbone

    # get_aspp yields two values: the ASPP head stored on self.aspp, and a
    # channel count that is then used to size the OCR block.
    aspp_module, aspp_channels = get_aspp(
        trunk_out_ch,
        bottleneck_ch=256,
        output_stride=8,
    )
    self.aspp = aspp_module
    self.ocr = OCR_block(aspp_channels)
| 137 | |
| 138 | def forward(self, inputs): |
| 139 | assert 'images' in inputs |