(config, epoch, model, max_accuracy, optimizer, lr_scheduler, scaler, logger)
| 51 | |
| 52 | |
def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, scaler, logger):
    """Serialize the training state for ``epoch`` into ``config.OUTPUT``.

    The checkpoint is written to ``<config.OUTPUT>/ckpt_epoch_<epoch>.pth`` and
    holds the ``state_dict()`` of ``model``, ``optimizer``, ``lr_scheduler`` and
    ``scaler``, plus ``max_accuracy``, ``epoch`` and the ``config`` object itself.

    The data is first written to ``<save_path>.tmp`` and only then moved onto the
    final name with ``os.replace``. The temp file is in the same directory, so on
    POSIX this rename is atomic. As a result, an interrupted or failed save never
    leaves a truncated ``ckpt_epoch_*.pth`` behind. It also never clobbers an
    existing file of that name with partial data.

    Args:
        config: Run configuration; only ``config.OUTPUT`` (the target directory)
            is read here. The object is also stored inside the checkpoint.
        epoch: Epoch number, used in the file name and stored in the checkpoint.
        model: Object exposing ``state_dict()``.
        max_accuracy: Best accuracy so far, stored as-is.
        optimizer: Object exposing ``state_dict()``.
        lr_scheduler: Object exposing ``state_dict()``.
        scaler: Object exposing ``state_dict()`` (e.g. a grad scaler -- assumed
            from the name, not shown here).
        logger: Logger used for the "saving"/"saved" progress messages.

    Raises:
        Any exception raised by the ``state_dict()`` calls, ``torch.save`` or
        ``os.replace`` propagates unchanged. If the failure happens during the
        write, the partial temp file is removed first (best-effort).
    """
    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'lr_scheduler': lr_scheduler.state_dict(),
                  'scaler': scaler.state_dict(),
                  'max_accuracy': max_accuracy,
                  'epoch': epoch,
                  'config': config}

    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    # Sibling temp file: same directory => same filesystem => os.replace is a rename.
    tmp_path = f'{save_path}.tmp'
    logger.info(f"{save_path} saving......")
    try:
        torch.save(save_state, tmp_path)
        os.replace(tmp_path, save_path)
    except BaseException:
        # Also covers KeyboardInterrupt: drop the partial temp file so it can't be
        # mistaken for a checkpoint, then re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    logger.info(f"{save_path} saved !!!")
| 66 | |
| 67 | |
| 68 | def get_grad_norm(parameters, norm_type=2): |
no test coverage detected