(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None)
| 445 | |
| 446 | |
def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
    """Restore training state from the checkpoint named by ``args.resume``.

    If ``args.auto_resume`` is set and no explicit ``args.resume`` was given,
    ``args.output_dir`` is scanned for files named ``checkpoint-<N>.pth``.
    ``N`` must be a plain ASCII integer. ``args.resume`` is pointed at the one
    with the largest ``N``. Other names (e.g. ``checkpoint-best.pth``) are
    ignored.

    If ``args.resume`` is then non-empty, the checkpoint is loaded onto CPU:
    through ``torch.hub.load_state_dict_from_url`` when the string starts with
    ``'https'``, otherwise with ``torch.load``. Then:

    * ``checkpoint['model']`` is loaded into ``model_without_ddp``.
    * Only when the checkpoint contains both ``'optimizer'`` and ``'epoch'``:
      the optimizer state is restored; ``args.start_epoch`` becomes
      ``epoch + 1`` for non-string epochs; ``model_ema.ema`` receives
      ``checkpoint['model_ema']`` (falling back to ``checkpoint['model']``)
      when ``args.model_ema`` is truthy; and ``loss_scaler`` receives
      ``checkpoint['scaler']`` when that key exists.

    Args:
        args: Namespace. Reads ``output_dir``, ``auto_resume``, ``resume``,
            ``eval`` and, if present, ``model_ema``. May assign ``resume``
            and ``start_epoch``.
        model: Not used by this function; kept for interface compatibility.
        model_without_ddp: Module that receives ``checkpoint['model']``.
        optimizer: Receives ``checkpoint['optimizer']``.
        loss_scaler: Receives ``checkpoint['scaler']`` when present.
        model_ema: Object whose ``.ema`` module receives the EMA weights.
            It must not be ``None`` when ``args.model_ema`` is truthy.

    Raises:
        AssertionError: If the checkpoint's epoch is a string (e.g. a "best"
            checkpoint) while ``args.eval`` is falsy. Such checkpoints cannot
            be used to resume training.
    """
    output_dir = Path(args.output_dir)
    # `not args.resume` also tolerates resume=None, where len() would raise.
    if args.auto_resume and not args.resume:
        prefix, suffix = 'checkpoint-', '.pth'
        latest_ckpt = -1
        for ckpt in output_dir.glob(prefix + '*' + suffix):
            # Parse only the text between prefix and suffix of the *file name*.
            # Splitting the whole path on '-' / '.' would read
            # "checkpoint-run-7.pth" or "checkpoint-7.5.pth" as epoch 7 and
            # then point at a checkpoint-7.pth that may not exist. The
            # isascii() check rejects digits like '²', which pass isdigit()
            # but make int() raise.
            epoch_str = ckpt.name[len(prefix):-len(suffix)]
            if epoch_str.isascii() and epoch_str.isdigit():
                latest_ckpt = max(int(epoch_str), latest_ckpt)
        if latest_ckpt >= 0:
            args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
        print("Auto resume checkpoint: %s" % args.resume)

    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Resume checkpoint %s" % args.resume)
        if 'optimizer' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            if not isinstance(checkpoint['epoch'], str):
                args.start_epoch = checkpoint['epoch'] + 1
            elif not args.eval:
                # String epochs ('best', 'best-ema') carry no resumable epoch
                # number. Raise explicitly rather than `assert`, so the guard
                # survives `python -O`; the exception type and message are
                # unchanged.
                raise AssertionError('Does not support resuming with checkpoint-best')
            if getattr(args, 'model_ema', False):
                # Prefer dedicated EMA weights; otherwise seed EMA from the model.
                ema_key = 'model_ema' if 'model_ema' in checkpoint else 'model'
                model_ema.ema.load_state_dict(checkpoint[ema_key])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("With optim & sched!")
| 483 | |
| 484 | |
| 485 |
nothing calls this directly
no test coverage detected