Loads the yolo model from file. :param model_path: Path to model definition file (.cfg) :type model_path: str :param weights_path: Path to weights or checkpoint file (.weights or .pth) :type weights_path: str :return: Returns model :rtype: Darknet
(model_path, weights_path=None)
| 318 | |
| 319 | |
def load_model(model_path, weights_path=None):
    """Loads the yolo model from file.

    Builds a :class:`Darknet` model from the given config and moves it to the
    CUDA device when one is available (otherwise the CPU). It then
    initializes its parameters with ``weights_init_normal``. If
    ``weights_path`` is given, it loads pretrained weights on top.

    :param model_path: Path to model definition file (.cfg)
    :type model_path: str
    :param weights_path: Path to weights or checkpoint file (.weights or .pth).
        A path ending in ``.pth`` is loaded as a PyTorch ``state_dict``
        checkpoint. Any other path is passed to
        ``Darknet.load_darknet_weights``. Path-like objects such as
        ``pathlib.Path`` are accepted as well as strings. If falsy (``None``
        or ``""``), no pretrained weights are loaded.
    :type weights_path: str or os.PathLike
    :return: Returns model, already placed on the selected device
    :rtype: Darknet
    """
    import os  # local import: only needed for os.fspath below

    device = torch.device("cuda" if torch.cuda.is_available()
                          else "cpu")  # Select device for inference
    model = Darknet(model_path).to(device)

    # Start from the normal initialization; any pretrained weights loaded
    # below are applied on top of it.
    model.apply(weights_init_normal)

    # If pretrained weights are specified, start from checkpoint or weight file
    if weights_path:
        # Normalize path-like objects (e.g. pathlib.Path, which has no
        # ``endswith``) to their string form; plain strings pass through.
        weights_path = os.fspath(weights_path)
        if weights_path.endswith(".pth"):
            # Load checkpoint weights. map_location places the loaded tensors
            # on the selected device. NOTE: torch.load unpickles the file, so
            # only load checkpoints from trusted sources.
            model.load_state_dict(torch.load(weights_path, map_location=device))
        else:
            # Load darknet weights
            model.load_darknet_weights(weights_path)
    return model