| 16 | |
| 17 | |
def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger):
    """Load the checkpoint named by ``config.MODEL.RESUME`` into ``model``.

    The checkpoint is fetched with ``torch.hub.load_state_dict_from_url`` when
    the resume string starts with ``'https'`` and read with ``torch.load``
    otherwise; in both cases tensors are mapped to the CPU. Model weights are
    taken from ``checkpoint['model']`` and loaded with ``strict=False``; the
    value returned by ``model.load_state_dict`` is logged.

    Training state is restored only when ``config.EVAL_MODE`` is falsy and the
    checkpoint contains all of ``'optimizer'``, ``'lr_scheduler'`` and
    ``'epoch'``. In that case the optimizer and LR scheduler states are
    loaded, ``config.TRAIN.START_EPOCH`` is set to ``checkpoint['epoch'] + 1``
    (the config is defrosted for the assignment and frozen again afterwards),
    and ``loss_scaler`` is restored if the checkpoint has a ``'scaler'`` entry.

    Args:
        config: Configuration object exposing ``MODEL.RESUME``, ``EVAL_MODE``,
            ``TRAIN.START_EPOCH``, ``defrost()`` and ``freeze()``.
        model: Object whose ``load_state_dict(state, strict=False)`` receives
            ``checkpoint['model']``.
        optimizer: Object whose ``load_state_dict`` receives
            ``checkpoint['optimizer']`` when training state is restored.
        lr_scheduler: Object whose ``load_state_dict`` receives
            ``checkpoint['lr_scheduler']`` when training state is restored.
        loss_scaler: Object whose ``load_state_dict`` receives
            ``checkpoint['scaler']`` when that key is present.
        logger: Object with an ``info(message)`` method.

    Returns:
        ``checkpoint['max_accuracy']`` when training state was restored and
        the key is present, otherwise ``0.0``.

    Raises:
        KeyError: If the checkpoint has no ``'model'`` entry.
    """
    resume_path = config.MODEL.RESUME
    logger.info(f"==============> Resuming from {resume_path}....................")

    if resume_path.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            resume_path, map_location='cpu', check_hash=True)
    else:
        # SECURITY: torch.load unpickles the file; only resume from
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(resume_path, map_location='cpu')

    # strict=False tolerates missing/unexpected keys; the returned report is
    # logged so that any mismatch is visible.
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)

    max_accuracy = 0.0
    has_train_state = (
        'optimizer' in checkpoint
        and 'lr_scheduler' in checkpoint
        and 'epoch' in checkpoint
    )
    # NOTE(review): the 'scaler' and 'max_accuracy' handling is kept inside
    # this guard, next to the log line that needs checkpoint['epoch'] --
    # confirm against the intended nesting.
    if not config.EVAL_MODE and has_train_state:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

        # The config is frozen; unlock it just long enough to move the
        # starting epoch past the one stored in the checkpoint.
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()

        if 'scaler' in checkpoint:
            loss_scaler.load_state_dict(checkpoint['scaler'])
        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']

    # Drop the reference to the loaded checkpoint before returning and ask
    # the CUDA caching allocator to release unused blocks.
    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy
| 43 | |
| 44 | |
| 45 | def load_pretrained(config, model, logger): |