| 29 | |
| 30 | |
def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger):
    """Restore model (and optionally training) state from ``config.MODEL.RESUME``.

    If ``config.MODEL.RESUME`` ends with ``.pth``, the file actually read is a
    derived name:
    - ``<RESUME>.global`` when ``config.TRAIN.MOE.SAVE_MASTER`` is set;
    - otherwise ``<RESUME>.rank<global_rank>``.
    Any other path is loaded as given.

    Model weights are always loaded with ``strict=False``, and the
    missing/unexpected-key report returned by ``load_state_dict`` is logged.

    Training state is restored only when ``config.EVAL_MODE`` is false AND the
    checkpoint holds ``'optimizer'``, ``'lr_scheduler'`` and ``'epoch'``. In
    that case:
    - the optimizer and LR scheduler are restored;
    - ``config.TRAIN.START_EPOCH`` is set to ``epoch + 1``;
    - the loss scaler (``'scaler'``) and the stored ``'max_accuracy'`` are
      restored when present.

    Args:
        config: Config node supporting ``defrost()``/``freeze()`` (presumably
            a yacs ``CfgNode`` -- TODO confirm).
        model: Module that receives ``checkpoint['model']``.
        optimizer: Restored from ``checkpoint['optimizer']``.
        lr_scheduler: Restored from ``checkpoint['lr_scheduler']``.
        loss_scaler: Restored from ``checkpoint['scaler']`` when present.
        logger: Logger used for progress messages.

    Returns:
        ``checkpoint['max_accuracy']`` if training state was restored and the
        key exists, otherwise ``0.0``.
    """
    # dist.get_rank() raises when no default process group is initialised.
    # Fall back to rank 0 so single-process runs can still resume.
    if dist.is_available() and dist.is_initialized():
        global_rank = dist.get_rank()
    else:
        global_rank = 0
    logger.info(f"==============> Rank[{global_rank}] Resuming form {config.MODEL.RESUME}....................")

    resume_path = config.MODEL.RESUME
    if resume_path.endswith('.pth'):
        # Map the nominal '.pth' name onto the file that is actually read:
        # one '.global' file when SAVE_MASTER is set, else a per-rank file.
        # (Presumably matches the naming used by the save routine -- confirm.)
        suffix = '.global' if config.TRAIN.MOE.SAVE_MASTER else f'.rank{global_rank}'
        resume_path = config.MODEL.RESUME + suffix
        logger.info(f"===> Rank[{global_rank}] Re-formatting checkpoint name to {resume_path}......")

    # NOTE(review): torch.load unpickles the file (unless weights_only is in
    # effect for the installed torch version). Only resume from trusted
    # checkpoints.
    checkpoint = torch.load(resume_path, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)

    max_accuracy = 0.0
    if not config.EVAL_MODE and all(key in checkpoint for key in ('optimizer', 'lr_scheduler', 'epoch')):
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        # The config is frozen; unfreeze it briefly so the resume epoch can be
        # written.
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        if 'scaler' in checkpoint:
            loss_scaler.load_state_dict(checkpoint['scaler'])
        logger.info(f"=>Rank[{global_rank}] loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']

    # Drop the checkpoint reference before training continues.
    # empty_cache() releases blocks cached by the CUDA allocator.
    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy
| 62 | |
| 63 | |
| 64 | def load_pretrained(config, model, logger): |