Utility function for checkpointing the model and optimizer state dictionaries. Its main purpose is to make it possible to resume training from that exact point later.
(PATH, ckpt_id, model, epoch, last_global_step, last_global_data_samples, **kwargs)
# Module-level global step counter, starting at 0.
# NOTE(review): presumably overwritten when training state is restored from a
# checkpoint (see the loader) — TODO confirm against the code that assigns it.
last_global_step_from_restore: int = 0
| 39 | |
def checkpoint_model(PATH, ckpt_id, model, epoch, last_global_step, last_global_data_samples, **kwargs):
    """Checkpoint the model + optimizer dictionaries so training can be resumed later.

    The training-progress counters, plus any extra keyword arguments, are packed
    into one client-state dict. That dict is handed to
    ``model.network.save_checkpoint`` together with ``PATH`` and ``ckpt_id``.

    Args:
        PATH: Save location, forwarded unchanged to ``save_checkpoint``
            (presumably a directory — TODO confirm).
        ckpt_id: Checkpoint identifier, forwarded unchanged to ``save_checkpoint``.
        model: Object exposing
            ``network.save_checkpoint(PATH, ckpt_id, client_state)``.
        epoch: Stored in the client state under ``'epoch'``.
        last_global_step: Stored under ``'last_global_step'``.
        last_global_data_samples: Stored under ``'last_global_data_samples'``.
        **kwargs: Extra entries merged into the client-state dict. They cannot
            overwrite the three counters above, because those names always bind
            to the named parameters rather than landing in ``kwargs``.

    Returns:
        bool: ``True`` if ``save_checkpoint`` returned a truthy value, ``False``
        otherwise. The same condition decides whether a success or a failure
        message is logged.
    """
    checkpoint_state_dict = {
        'epoch': epoch,
        'last_global_step': last_global_step,
        'last_global_data_samples': last_global_data_samples,
    }
    # Add extra kwargs too (no key collisions possible; see docstring).
    checkpoint_state_dict.update(kwargs)

    success = model.network.save_checkpoint(PATH, ckpt_id, checkpoint_state_dict)

    # Lazy %-style arguments: the message is only formatted if the record is
    # actually emitted. The output text matches the previous str.format version.
    status_fmt = 'checkpointing: PATH=%s, ckpt_id=%s'
    if success:
        logging.info('Success ' + status_fmt, PATH, ckpt_id)
    else:
        logging.warning('Failure ' + status_fmt, PATH, ckpt_id)

    # Surface the outcome so callers can react to a failed save instead of
    # relying on the log line alone.
    return bool(success)
| 57 | |
| 58 | |
| 59 | def load_training_checkpoint(args, model, PATH, ckpt_id): |