Save a model checkpoint.
(iteration, model, args)
| 226 | torch.distributed.barrier() |
| 227 | |
def save_ds_checkpoint(iteration, model, args):
    """Save a model checkpoint through ``model.save_checkpoint``.

    A client-state dict is assembled and handed to the model's own
    ``save_checkpoint`` method (presumably a DeepSpeed engine, judging by the
    function name -- confirm against callers). The dict always contains the
    iteration number. Unless ``args.no_save_rng`` is truthy, it also holds
    snapshots of the following RNG states:

    * Python ``random``
    * NumPy
    * torch CPU
    * torch CUDA
    * the ``mpu`` CUDA RNG tracker

    Args:
        iteration: Training iteration. Stored under ``'iteration'`` and also
            passed as the second positional argument to
            ``model.save_checkpoint``.
        model: Object exposing ``save_checkpoint(save_dir, iteration,
            client_state=...)``.
        args: Namespace read for ``no_save_rng`` (skip RNG capture when
            truthy) and ``save`` (destination passed to the model).
    """
    client_state = {'iteration': iteration}

    if not args.no_save_rng:
        # Snapshot every RNG source so a resumed run can restore the exact
        # random streams; keyword order keeps the original capture/key order.
        client_state.update(
            random_rng_state=random.getstate(),
            np_rng_state=np.random.get_state(),
            torch_rng_state=torch.get_rng_state(),
            cuda_rng_state=torch.cuda.get_rng_state(),
            rng_tracker_states=mpu.get_cuda_rng_tracker().get_states(),
        )

    model.save_checkpoint(args.save, iteration, client_state=client_state)
| 242 | |
| 243 | |
| 244 | def get_checkpoint_iteration(args): |
no test coverage detected