(exp, args)
| 299 | |
| 300 | |
| 301 | def main(exp, args): |
| 302 | if not args.experiment_name: |
| 303 | args.experiment_name = exp.exp_name |
| 304 | |
| 305 | output_dir = osp.join(exp.output_dir, args.experiment_name) |
| 306 | os.makedirs(output_dir, exist_ok=True) |
| 307 | |
| 308 | if args.save_result: |
| 309 | vis_folder = osp.join(output_dir, "track_vis") |
| 310 | os.makedirs(vis_folder, exist_ok=True) |
| 311 | |
| 312 | if args.trt: |
| 313 | args.device = "gpu" |
| 314 | args.device = torch.device("cuda" if args.device == "gpu" else "cpu") |
| 315 | |
| 316 | logger.info("Args: {}".format(args)) |
| 317 | |
| 318 | if args.conf is not None: |
| 319 | exp.test_conf = args.conf |
| 320 | if args.nms is not None: |
| 321 | exp.nmsthre = args.nms |
| 322 | if args.tsize is not None: |
| 323 | exp.test_size = (args.tsize, args.tsize) |
| 324 | |
| 325 | model = exp.get_model().to(args.device) |
| 326 | logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size))) |
| 327 | model.eval() |
| 328 | |
| 329 | if not args.trt: |
| 330 | if args.ckpt is None: |
| 331 | ckpt_file = osp.join(output_dir, "best_ckpt.pth.tar") |
| 332 | else: |
| 333 | ckpt_file = args.ckpt |
| 334 | logger.info("loading checkpoint") |
| 335 | ckpt = torch.load(ckpt_file, map_location="cpu") |
| 336 | # load the model state dict |
| 337 | model.load_state_dict(ckpt["model"]) |
| 338 | logger.info("loaded checkpoint done.") |
| 339 | |
| 340 | if args.fuse: |
| 341 | logger.info("\tFusing model...") |
| 342 | model = fuse_model(model) |
| 343 | |
| 344 | if args.fp16: |
| 345 | model = model.half() # to FP16 |
| 346 | |
| 347 | if args.trt: |
| 348 | assert not args.fuse, "TensorRT model is not support model fusing!" |
| 349 | trt_file = osp.join(output_dir, "model_trt.pth") |
| 350 | assert osp.exists( |
| 351 | trt_file |
| 352 | ), "TensorRT model is not found!\n Run python3 tools/trt.py first!" |
| 353 | model.head.decode_in_inference = False |
| 354 | decoder = model.head.decode_outputs |
| 355 | logger.info("Using TensorRT to inference") |
| 356 | else: |
| 357 | trt_file = None |
| 358 | decoder = None |
no test coverage detected