Save a model checkpoint.
(iteration, model, lr_scheduler, args, tag)
| 276 | |
| 277 | |
def save_ds_checkpoint(iteration, model, lr_scheduler, args, tag):
    """Save a model checkpoint together with a client-state dict.

    The client state always records ``iteration``. It also records:
    - the LR scheduler's ``state_dict()``, when ``lr_scheduler`` is not None;
    - unless ``args.no_save_rng`` is set, the RNG states of Python
      ``random``, NumPy, torch (CPU), torch CUDA, and the ``mpu`` CUDA RNG
      tracker.

    The dict is passed to ``model.save_checkpoint(args.save, tag,
    client_state=...)``. ``model`` presumably is a DeepSpeed engine (the
    "ds" in the name); this is not confirmed from this snippet.
    """
    client_state = {'iteration': iteration}

    if lr_scheduler is not None:
        client_state['client_lr_scheduler'] = lr_scheduler.state_dict()

    # RNG states are stored under fixed keys; skipped entirely when the
    # no_save_rng flag is set on args.
    if not args.no_save_rng:
        client_state.update(
            random_rng_state=random.getstate(),
            np_rng_state=np.random.get_state(),
            torch_rng_state=torch.get_rng_state(),
            cuda_rng_state=torch.cuda.get_rng_state(),
            rng_tracker_states=mpu.get_cuda_rng_tracker().get_states(),
        )

    model.save_checkpoint(args.save, tag, client_state=client_state)
| 293 | |
| 294 | |
| 295 | def get_checkpoint_iteration(load_path): |
no test coverage detected