Save a model checkpoint.
(iteration, model, lr_scheduler, args)
| 235 | |
| 236 | |
def save_ds_checkpoint(iteration, model, lr_scheduler, args):
    """Save a checkpoint by delegating to ``model.save_checkpoint``.

    A client-state dictionary is assembled and passed along as
    ``client_state``; the checkpoint tag is ``str(iteration)`` and the
    target location is ``args.save``.

    Args:
        iteration: Current iteration; stored under ``'iteration'`` and,
            stringified, used as the checkpoint tag.
        model: Object exposing ``save_checkpoint(save_dir, tag,
            client_state=...)`` (presumably a DeepSpeed engine -- confirm
            against the caller).
        lr_scheduler: Optional scheduler. When not ``None`` its
            ``state_dict()`` is stored under ``'client_lr_scheduler'``.
        args: Namespace providing ``save`` and ``no_save_rng``.
    """
    client_state = {'iteration': iteration}

    if lr_scheduler is not None:
        client_state['client_lr_scheduler'] = lr_scheduler.state_dict()

    # Capture every RNG stream unless the caller opted out via
    # ``args.no_save_rng``.
    if not args.no_save_rng:
        client_state.update({
            'random_rng_state': random.getstate(),
            'np_rng_state': np.random.get_state(),
            'torch_rng_state': torch.get_rng_state(),
            'cuda_rng_state': torch.cuda.get_rng_state(),
            'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states(),
        })

    model.save_checkpoint(
        args.save,
        str(iteration),
        client_state=client_state,
    )
| 253 | |
| 254 | |
| 255 | def get_checkpoint_iteration(args): |
no test coverage detected