| 176 | logger.info(f"{save_path} saved !!!") |
| 177 | |
def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, is_best=False, model_ema=None):
    """Write the current training state to ``config.OUTPUT``.

    The checkpoint is a dict holding the ``state_dict()`` of ``model``,
    ``optimizer`` and ``lr_scheduler`` together with ``max_accuracy``,
    ``epoch`` and the ``config`` object itself. It is always written to
    ``ckpt_epoch_{epoch}.pth``; when ``is_best`` is true the same dict is
    written to ``best_ckpt.pth`` first.

    Args:
        config: Run configuration; ``OUTPUT`` (target directory) and
            ``AMP_OPT_LEVEL`` are read here.
        epoch: Epoch index, stored in the checkpoint and used in the file name.
        model: Object whose ``state_dict()`` is stored under ``'model'``.
        max_accuracy: Value stored under ``'max_accuracy'``.
        optimizer: Object whose ``state_dict()`` is stored under ``'optimizer'``.
        lr_scheduler: Object whose ``state_dict()`` is stored under
            ``'lr_scheduler'``.
        logger: Logger used for the progress messages.
        is_best: When true, additionally write ``best_ckpt.pth``.
        model_ema: Optional EMA wrapper; when given, the state dict of
            ``unwrap_model(model_ema)`` is stored under ``'ema'``.
    """
    save_state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'max_accuracy': max_accuracy,
        'epoch': epoch,
        'config': config,
    }
    if config.AMP_OPT_LEVEL != "O0":
        # `amp` is a module-level name defined outside this function
        # (presumably NVIDIA apex's amp -- confirm against the file's imports).
        save_state['amp'] = amp.state_dict()
    if model_ema is not None:
        # `unwrap_model` is a helper defined elsewhere; its result must expose
        # `state_dict()`.
        save_state['ema'] = unwrap_model(model_ema).state_dict()

    # Make sure the target directory exists so a missing folder does not abort
    # the run at checkpoint time. An empty OUTPUT keeps meaning "current
    # directory", exactly as os.path.join treats it below.
    if config.OUTPUT:
        os.makedirs(config.OUTPUT, exist_ok=True)

    if is_best:
        best_path = os.path.join(config.OUTPUT, 'best_ckpt.pth')
        logger.info(f"{best_path} saving......")
        torch.save(save_state, best_path)
        logger.info(f"{best_path} saved !!!")

    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    logger.info(f"{save_path} saving......")
    torch.save(save_state, save_path)
    logger.info(f"{save_path} saved !!!")
| 198 | |
| 199 | |
| 200 | def get_grad_norm(parameters, norm_type=2): |