Save a model checkpoint.
(iteration, model, optimizer, lr_scheduler, args, tag=None, barrier=True,
only_changed_parameters=False, no_deepspeed=False, no_save_optim=False)
| 222 | |
| 223 | |
def save_checkpoint(iteration, model, optimizer, lr_scheduler, args, tag=None, barrier=True,
                    only_changed_parameters=False, no_deepspeed=False, no_save_optim=False):
    """Save a model checkpoint.

    When ``args.deepspeed`` is set and ``no_deepspeed`` is False, saving is
    delegated to ``save_ds_checkpoint``. Otherwise only data-parallel rank 0
    writes a single ``torch.save`` file at ``get_checkpoint_name(args.save, tag)``
    containing:

    * ``'iteration'`` -- the ``iteration`` argument;
    * ``'module'`` -- the model state dict (restricted to trainable
      parameters when ``only_changed_parameters`` is True);
    * ``'optimizer'`` / ``'lr_scheduler'`` -- their state dicts, when the
      object is not None and neither ``args.no_save_optim`` nor
      ``no_save_optim`` is set;
    * Python ``random``, NumPy, torch CPU, CUDA and model-parallel RNG
      tracker states, unless ``args.no_save_rng`` is set.

    In both paths, every rank then optionally waits on a barrier and global
    rank 0 writes ``tag`` to the checkpoint tracker file.

    Args:
        iteration: Training iteration being saved; also the default tag.
        model: Model to save. When ``args.deepspeed`` is set but the plain
            path is forced via ``no_deepspeed``, ``model.module`` is saved
            instead (presumably unwrapping the DeepSpeed engine).
        optimizer: Optimizer whose state dict is saved, or None.
        lr_scheduler: LR scheduler whose state dict is saved, or None.
        args: Namespace read for ``deepspeed``, ``save``, ``no_save_optim``
            and ``no_save_rng``.
        tag: Checkpoint tag; defaults to ``str(iteration)``.
        barrier: If True, call ``torch.distributed.barrier()`` before the
            tracker file is updated.
        only_changed_parameters: If True, keep only state-dict entries of
            parameters with ``requires_grad=True``; buffers and frozen
            parameters are omitted.
        no_deepspeed: Force the plain ``torch.save`` path even when
            ``args.deepspeed`` is set.
        no_save_optim: Skip optimizer and LR-scheduler state for this call.
    """
    if tag is None:
        tag = str(iteration)
    if args.deepspeed and not no_deepspeed:
        save_ds_checkpoint(iteration, model, lr_scheduler, args, tag=tag)
    else:
        # Only rank zero of the data parallel group writes to the disk.
        if mpu.get_data_parallel_rank() == 0:
            checkpoint_name = get_checkpoint_name(args.save, tag)
            print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
                  format(torch.distributed.get_rank(), iteration, checkpoint_name))
            sd = {'iteration': iteration}
            if args.deepspeed:
                # Only reached when no_deepspeed=True: save the wrapped model
                # rather than the wrapper.
                model = model.module
            state_dict = model.state_dict()
            if only_changed_parameters:
                # state_dict() also contains buffers, and may contain aliases
                # of shared parameters, that named_parameters() does not
                # yield. A plain dict lookup on those keys raised KeyError.
                # Treat any key that is not a trainable parameter as unchanged.
                trainable = {name for name, parameter in model.named_parameters()
                             if parameter.requires_grad}
                state_dict = {key: value for key, value in state_dict.items()
                              if key in trainable}
            sd['module'] = state_dict

            # Optimizer stuff.
            if not args.no_save_optim and not no_save_optim:
                if optimizer is not None:
                    sd['optimizer'] = optimizer.state_dict()
                if lr_scheduler is not None:
                    sd['lr_scheduler'] = lr_scheduler.state_dict()

            # rng states.
            if not args.no_save_rng:
                sd['random_rng_state'] = random.getstate()
                sd['np_rng_state'] = np.random.get_state()
                sd['torch_rng_state'] = torch.get_rng_state()
                sd['cuda_rng_state'] = torch.cuda.get_rng_state()
                sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()

            ensure_directory_exists(checkpoint_name)
            torch.save(sd, checkpoint_name)
            print(' successfully saved {}'.format(checkpoint_name))

    # Wait so everyone is done (necessary). The barrier is a collective, so it
    # must be reached by every rank, not only the data-parallel writers above.
    if barrier:
        torch.distributed.barrier()
    # And update the latest iteration.
    # NOTE: with barrier=False the tracker may be written before the other
    # ranks have finished writing their checkpoint files.
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(tag)
| 277 | |
| 278 | def save_ds_checkpoint(iteration, model, lr_scheduler, args, tag): |
no test coverage detected