| 159 | print('=================== loaded from', path) |
| 160 | |
def save_latest(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, model_ema=None) -> None:
    """Write the current training state to ``<config.OUTPUT>/latest.pth``.

    The saved dict holds the ``state_dict()`` of ``model``, ``optimizer`` and
    ``lr_scheduler`` plus ``max_accuracy``, ``epoch`` and ``config`` itself.
    Two entries are optional:

    * ``'amp'`` -- added when ``config.AMP_OPT_LEVEL != "O0"``, taken from the
      ``amp`` object in scope of this module
      (NOTE(review): presumably NVIDIA apex's ``amp`` -- confirm).
    * ``'ema'`` -- added when ``model_ema`` is given, taken from
      ``unwrap_model(model_ema).state_dict()``.

    The file is written to ``latest.pth.tmp`` first and then moved over
    ``latest.pth`` with ``os.replace``, so an interrupted save (crash, kill,
    full disk) never leaves a truncated ``latest.pth`` behind and the
    previous checkpoint stays intact.

    Args:
        config: Object exposing ``OUTPUT`` (target directory) and
            ``AMP_OPT_LEVEL``; it is stored in the checkpoint as-is.
        epoch: Value stored under ``'epoch'``.
        model: Object with ``state_dict()``.
        max_accuracy: Value stored under ``'max_accuracy'``.
        optimizer: Object with ``state_dict()``.
        lr_scheduler: Object with ``state_dict()``.
        logger: Object with ``info(msg)``; receives a message before and
            after the save.
        model_ema: Optional EMA wrapper; skipped when ``None``.

    Raises:
        Whatever ``torch.save`` or ``os.replace`` raises; the temporary file
        is removed before the exception propagates.
    """
    save_state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'max_accuracy': max_accuracy,
        'epoch': epoch,
        'config': config,
    }
    if config.AMP_OPT_LEVEL != "O0":
        save_state['amp'] = amp.state_dict()
    if model_ema is not None:
        save_state['ema'] = unwrap_model(model_ema).state_dict()

    save_path = os.path.join(config.OUTPUT, 'latest.pth')
    # Sibling temp file: same directory, so os.replace stays on one filesystem.
    tmp_path = save_path + '.tmp'
    logger.info(f"{save_path} saving......")
    try:
        torch.save(save_state, tmp_path)
        os.replace(tmp_path, save_path)
    finally:
        # After a successful replace the temp file is gone; this only fires
        # when the save or the rename failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    logger.info(f"{save_path} saved !!!")
| 177 | |
| 178 | def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, is_best=False, model_ema=None): |
| 179 | save_state = {'model': model.state_dict(), |