`build_backbone2d(cfg)`: builds a 2D feature-extractor backbone network from Detectron2.
| 12 | |
| 13 | |
def build_backbone2d(cfg):
    """Build the 2D image feature extractor from a Detectron2 backbone.

    The function builds a Detectron2 backbone from ``cfg`` and follows it
    with an ``FPNFeature`` head. The head is configured with
    ``cfg.MODEL.BACKBONE3D.CHANNELS[0]`` output channels, a fixed output
    stride of 4, and ``cfg.MODEL.FPN.NORM`` normalization. If
    ``cfg.MODEL.BACKBONE.WEIGHTS`` is set, the state dict stored at that
    path is loaded into the Detectron2 backbone only, not into the FPN head.

    Args:
        cfg: Detectron2-style config node. This function reads
            ``MODEL.BACKBONE3D.CHANNELS``, ``MODEL.FPN.NORM`` and
            ``MODEL.BACKBONE.WEIGHTS``; ``d2_build_backbone`` reads whatever
            else it needs from the same node.

    Returns:
        tuple: ``(nn.Sequential(backbone, feature_extractor), output_stride)``.
    """
    # The channel count comes from the first entry of the 3D backbone's
    # CHANNELS (presumably so these 2D features feed its first stage).
    output_dim = cfg.MODEL.BACKBONE3D.CHANNELS[0]
    norm = cfg.MODEL.FPN.NORM
    output_stride = 4  # TODO: make configurable

    backbone = d2_build_backbone(cfg)
    feature_extractor = FPNFeature(
        backbone.output_shape(), output_dim, output_stride, norm)

    # Load pretrained backbone weights. Without map_location, torch.load
    # restores tensors to the device they were saved from, so a checkpoint
    # written from a GPU run would fail on a CPU-only machine. Loading to CPU
    # is always safe: load_state_dict copies the values into the backbone's
    # existing parameters.
    # NOTE(review): torch.load unpickles the file -- only point WEIGHTS at
    # trusted checkpoints (consider weights_only=True on torch >= 1.13).
    weights_path = cfg.MODEL.BACKBONE.WEIGHTS
    if weights_path:
        state_dict = torch.load(weights_path, map_location="cpu")
        backbone.load_state_dict(state_dict)

    return nn.Sequential(backbone, feature_extractor), output_stride
| 31 | |
| 32 | |
| 33 | class FPNFeature(nn.Module): |
No test coverage detected.