(self, args, device, distributed: bool = False, local_rank: int = 0)
| 33 | |
| 34 | class SegDetectorModel(nn.Module): |
def __init__(self, args, device, distributed: bool = False, local_rank: int = 0) -> None:
    """Build the detection network and its loss, wrap both, and move them to ``device``.

    Args:
        args: Mapping-style model configuration. It is passed whole to
            ``BasicModel``; this method itself reads ``args['loss_class']``
            (required) plus the optional ``'loss_args'`` (positional) and
            ``'loss_kwargs'`` (keyword) entries used to build the loss.
        device: Target passed to ``self.to`` once both submodules are wrapped.
        distributed: Forwarded to ``parallelize`` for the model and the loss.
        local_rank: Forwarded to ``parallelize`` for the model and the loss.
    """
    super().__init__()
    # Kept as a function-local import, as in the original. NOTE(review):
    # presumably this avoids a circular import with the decoders package; confirm.
    from ..decoders.seg_detector_loss import SegDetectorLossBuilder

    # NOTE(review): the original comment says this wrapping is "for loading
    # models" -- presumably so state-dict keys match checkpoints saved from a
    # wrapped model; confirm against the checkpoint-loading code.
    self.model = parallelize(BasicModel(args), distributed, local_rank)

    loss_builder = SegDetectorLossBuilder(
        args['loss_class'],
        *args.get('loss_args', ()),
        **args.get('loss_kwargs', {}),
    )
    self.criterion = parallelize(loss_builder.build(), distributed, local_rank)

    self.device = device
    # Done last, after both submodules have been wrapped and registered.
    self.to(self.device)
| 48 | |
| 49 | @staticmethod |
| 50 | def model_name(args): |
nothing calls this directly
no test coverage detected