Save a model checkpoint.
save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
| 177 | print(' successfully saved {}'.format(zero_checkpoint_name)) |
| 178 | |
def save_checkpoint(iteration, model, optimizer,
                    lr_scheduler, args):
    """Save a model checkpoint and record its iteration in the tracker file.

    Args:
        iteration: training iteration being checkpointed. It is passed to
            ``get_checkpoint_name`` and written to the tracker file.
        model: model to save. On the non-DeepSpeed path a ``torchDDP``
            wrapper is unwrapped so the inner module's ``state_dict()`` is
            stored.
        optimizer: its ``state_dict()`` is stored unless
            ``args.no_save_optim`` is set; may be ``None``.
        lr_scheduler: its ``state_dict()`` is stored unless
            ``args.no_save_optim`` is set; may be ``None``.
        args: run configuration; this function reads ``deepspeed``,
            ``save``, ``no_save_optim`` and ``no_save_rng``.

    When ``args.deepspeed`` is truthy, saving is delegated to
    ``save_ds_checkpoint``. Otherwise only processes whose data-parallel
    rank is 0 assemble a state dict and write it with ``torch.save``.

    Both paths then go through two unconditional
    ``torch.distributed.barrier()`` calls. Between them, global rank 0
    writes ``iteration`` to the tracker file under ``args.save``. Every rank
    is therefore expected to call this function.
    """
    if not args.deepspeed:
        # Store the inner module's weights, not the DDP container.
        unwrapped = model.module if isinstance(model, torchDDP) else model

        # Only data-parallel rank 0 writes the checkpoint file.
        if mpu.get_data_parallel_rank() == 0:
            ckpt_path = get_checkpoint_name(args.save, iteration)
            print('global rank {} is saving checkpoint at iteration {:7d} to {}'
                  .format(torch.distributed.get_rank(), iteration, ckpt_path))

            state = {
                'iteration': iteration,
                'model': unwrapped.state_dict(),
            }

            # Optimizer / LR-scheduler state, unless disabled by the flag.
            if not args.no_save_optim:
                if optimizer is not None:
                    state['optimizer'] = optimizer.state_dict()
                if lr_scheduler is not None:
                    state['lr_scheduler'] = lr_scheduler.state_dict()

            # Python, NumPy, torch CPU, CUDA and model-parallel tracker RNG
            # states, unless disabled by the flag.
            if not args.no_save_rng:
                state.update({
                    'random_rng_state': random.getstate(),
                    'np_rng_state': np.random.get_state(),
                    'torch_rng_state': torch.get_rng_state(),
                    'cuda_rng_state': torch.cuda.get_rng_state(),
                    'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states(),
                })

            ensure_directory_exists(ckpt_path)
            torch.save(state, ckpt_path)
            print(' successfully saved {}'.format(ckpt_path))
    else:
        save_ds_checkpoint(iteration, model, args)

    # Every rank waits here, so the checkpoint writers above have returned
    # from their save before the tracker is touched (necessary).
    torch.distributed.barrier()
    # Record the latest saved iteration; only global rank 0 does this.
    if torch.distributed.get_rank() == 0:
        tracker_path = get_checkpoint_tracker_filename(args.save)
        with open(tracker_path, 'w') as tracker:
            tracker.write(str(iteration))
    # Hold all ranks until the tracker has been written (not strictly necessary).
    torch.distributed.barrier()
| 227 | |
| 228 | def save_ds_checkpoint(iteration, model, args): |
| 229 | """Save a model checkpoint.""" |
no test coverage detected