(config, model, optimizer, lr_scheduler, scaler, logger)
| 14 | |
| 15 | |
def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger):
    """Restore model weights and, when resuming training, the training state.

    The checkpoint location is ``config.MODEL.RESUME``. An ``http://`` or
    ``https://`` URL is fetched with ``torch.hub.load_state_dict_from_url``
    (with hash checking); anything else is treated as a local file path and
    read with ``torch.load``. In both cases tensors are mapped to CPU.

    Model weights are loaded non-strictly (the result of
    ``load_state_dict`` is logged) after renaming legacy ``rpe_mlp`` keys to
    ``cpb_mlp``. Optimizer, LR-scheduler and scaler state are restored only
    when ``config.EVAL_MODE`` is false and the checkpoint contains
    ``optimizer``, ``lr_scheduler``, ``scaler`` and ``epoch``. In that case
    ``config.TRAIN.START_EPOCH`` is set to ``epoch + 1``.

    Args:
        config: config node exposing ``MODEL.RESUME``, ``EVAL_MODE``,
            ``TRAIN.START_EPOCH`` and ``defrost()`` / ``freeze()``.
        model: module whose ``load_state_dict`` receives ``checkpoint['model']``.
        optimizer: object with ``load_state_dict``; restored when resuming.
        lr_scheduler: object with ``load_state_dict``; restored when resuming.
        scaler: object with ``load_state_dict``; restored when resuming.
        logger: logger used for progress messages.

    Returns:
        ``checkpoint['max_accuracy']`` when training state was restored and
        that key is present, otherwise ``0.0``.

    Raises:
        KeyError: if the checkpoint dict has no ``'model'`` entry.
    """
    resume_path = config.MODEL.RESUME
    logger.info(f">>>>>>>>>> Resuming from {resume_path} ..........")

    # Dispatch on the URL scheme, not a bare 'https' prefix. The old check sent
    # local files named e.g. 'https_ckpt.pth' to the downloader and treated
    # plain 'http://' URLs as local paths.
    if resume_path.startswith(('http://', 'https://')):
        checkpoint = torch.hub.load_state_dict_from_url(
            resume_path, map_location='cpu', check_hash=True)
    else:
        # NOTE(security): torch.load unpickles the file, so only resume from
        # checkpoints you trust.
        checkpoint = torch.load(resume_path, map_location='cpu')

    # Re-map keys due to a module rename (only needed for the provided
    # pretrained models). The key list is materialised first, so popping
    # while renaming is safe.
    state_dict = checkpoint['model']
    for old_key in [k for k in state_dict.keys() if "rpe_mlp" in k]:
        state_dict[old_key.replace('rpe_mlp', 'cpb_mlp')] = state_dict.pop(old_key)

    msg = model.load_state_dict(state_dict, strict=False)
    logger.info(msg)

    max_accuracy = 0.0
    has_training_state = all(
        key in checkpoint for key in ('optimizer', 'lr_scheduler', 'scaler', 'epoch'))
    if not config.EVAL_MODE and has_training_state:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        scaler.load_state_dict(checkpoint['scaler'])

        # The config is frozen; unfreeze just long enough to set the resume epoch.
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()

        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        max_accuracy = checkpoint.get('max_accuracy', 0.0)

    # Drop the (possibly large) checkpoint before returning to training.
    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy
| 51 | |
| 52 | |
| 53 | def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, scaler, logger): |
no test coverage detected