| 243 | |
| 244 | |
def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger):
    """Serialize the full training state to ``<config.OUTPUT>/ckpt_epoch_<epoch>.pth``.

    The saved dict holds the ``state_dict()`` of the model, optimizer,
    LR scheduler and loss scaler, plus ``max_accuracy``, ``epoch`` and the
    ``config`` object itself.

    The checkpoint is first written to a temporary file and then moved into
    place with ``os.replace``. An interrupted or failed save therefore never
    leaves a truncated file under the final checkpoint name.

    Args:
        config: Object exposing ``OUTPUT`` (the output directory). It is also
            stored in the checkpoint, so it must be picklable.
        epoch: Epoch number; used in the file name and stored in the state.
        model: Object with ``state_dict()``.
        max_accuracy: Best accuracy so far; stored as-is.
        optimizer: Object with ``state_dict()``.
        lr_scheduler: Object with ``state_dict()``.
        loss_scaler: Object with ``state_dict()``.
        logger: Logger with an ``info`` method.

    Raises:
        Any exception from ``torch.save`` or the filesystem. It is re-raised
        after the temporary file has been cleaned up.
    """
    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'lr_scheduler': lr_scheduler.state_dict(),
                  'max_accuracy': max_accuracy,
                  'scaler': loss_scaler.state_dict(),
                  'epoch': epoch,
                  'config': config}

    # Saving on a fresh run must not fail just because the directory is missing.
    os.makedirs(config.OUTPUT, exist_ok=True)
    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    # The temp name does not end in ".pth", so directory scans for finished
    # checkpoints will not pick up a half-written file.
    tmp_path = save_path + '.tmp'

    logger.info(f"{save_path} saving......")
    try:
        torch.save(save_state, tmp_path)
        # Swap the file into place in one step; readers see either the old
        # file or the complete new one.
        os.replace(tmp_path, save_path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup; the original error is re-raised below
        raise
    logger.info(f"{save_path} saved !!!")
| 258 | |
| 259 | |
| 260 | def auto_resume_helper(output_dir): |