Save a model checkpoint.
(iteration, model, optimizer,
lr_scheduler, args)
| 186 | |
| 187 | |
def save_checkpoint(iteration, model, optimizer,
                    lr_scheduler, args):
    """Write a training checkpoint for ``iteration`` and mark it as latest.

    When ``args.deepspeed`` is set, the write is delegated to
    ``save_ds_checkpoint``. Otherwise every process whose data-parallel rank
    is 0 serializes a state dictionary with ``torch.save``. Afterwards all
    processes meet at a barrier, global rank 0 writes ``iteration`` into the
    tracker file under ``args.save``, and a second barrier follows.

    Args:
        iteration: Iteration number stored in the checkpoint and tracker file.
        model: Model to save; a ``torchDDP`` wrapper is unwrapped via
            ``.module`` on the non-DeepSpeed path.
        optimizer: Optimizer whose state is saved (non-DeepSpeed path only),
            or ``None`` to skip it.
        lr_scheduler: Scheduler whose state is saved, or ``None`` to skip it.
        args: Namespace providing ``deepspeed``, ``save``, ``no_save_optim``
            and ``no_save_rng``.
    """
    if args.deepspeed:
        save_ds_checkpoint(iteration, model, lr_scheduler, args)
    else:
        # Save the underlying module rather than the DDP wrapper.
        unwrapped = model.module if isinstance(model, torchDDP) else model

        # Only rank zero of each data-parallel group writes to disk.
        if mpu.get_data_parallel_rank() == 0:
            checkpoint_path = get_checkpoint_name(args.save, iteration)
            announcement = (
                'global rank {} is saving checkpoint at iteration {:7d} to {}'
            ).format(torch.distributed.get_rank(), iteration, checkpoint_path)
            print(announcement)

            state = _collect_checkpoint_state(
                iteration, unwrapped, optimizer, lr_scheduler, args)

            ensure_directory_exists(checkpoint_path)
            torch.save(state, checkpoint_path)
            print(' successfully saved {}'.format(checkpoint_path))

    # NOTE(review): the barrier/tracker steps below are assumed to run for the
    # DeepSpeed path as well (function-level nesting) -- confirm.
    # Required: nobody may advertise the checkpoint before all writers finish.
    torch.distributed.barrier()

    # Record this iteration as the most recent checkpoint.
    if torch.distributed.get_rank() == 0:
        tracker_path = get_checkpoint_tracker_filename(args.save)
        with open(tracker_path, 'w') as tracker_file:
            tracker_file.write(str(iteration))

    # Optional: keeps ranks aligned once the tracker has been updated.
    torch.distributed.barrier()


def _collect_checkpoint_state(iteration, model, optimizer, lr_scheduler, args):
    """Build the dictionary that ``save_checkpoint`` passes to ``torch.save``."""
    state = {
        'iteration': iteration,
        'module': model.state_dict(),
    }

    # Optimizer and scheduler state, unless disabled by --no-save-optim.
    if not args.no_save_optim:
        if optimizer is not None:
            state['optimizer'] = optimizer.state_dict()
        if lr_scheduler is not None:
            state['lr_scheduler'] = lr_scheduler.state_dict()

    # Random-number-generator states, unless disabled by --no-save-rng.
    if not args.no_save_rng:
        state['random_rng_state'] = random.getstate()
        state['np_rng_state'] = np.random.get_state()
        state['torch_rng_state'] = torch.get_rng_state()
        state['cuda_rng_state'] = torch.cuda.get_rng_state()
        state['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()

    return state
| 235 | |
| 236 | |
| 237 | def save_ds_checkpoint(iteration, model, lr_scheduler, args): |
no test coverage detected