MCPcopy
hub / github.com/deepspeedai/DeepSpeedExamples / save_checkpoint

Function save_checkpoint

Megatron-LM/utils.py:179–226  ·  view source on GitHub ↗

Save a model checkpoint.

save_checkpoint(iteration, model, optimizer, lr_scheduler, args)

Source from the content-addressed store, hash-verified

177 print(' successfully saved {}'.format(zero_checkpoint_name))
178
def save_checkpoint(iteration, model, optimizer,
                    lr_scheduler, args):
    """Save a model checkpoint.

    With ``args.deepspeed`` set, saving is delegated to
    ``save_ds_checkpoint``. Otherwise, data-parallel rank 0 builds a state
    dict holding the iteration and the model weights, plus the optimizer,
    LR-scheduler and RNG states unless disabled. It writes that dict with
    ``torch.save`` to ``get_checkpoint_name(args.save, iteration)``.

    In both cases, global rank 0 then records ``iteration`` in the tracker
    file given by ``get_checkpoint_tracker_filename(args.save)``. The
    function calls ``torch.distributed.barrier()``, which is a collective
    operation, so every rank must call this function.

    Args:
        iteration: training iteration being saved. It is formatted with
            ``{:7d}``, so an int is expected.
        model: the model to save. A ``torchDDP`` wrapper is unwrapped
            before ``state_dict()`` is taken (non-DeepSpeed path only).
        optimizer: optimizer whose ``state_dict()`` is saved, or None.
        lr_scheduler: LR scheduler whose ``state_dict()`` is saved, or None.
        args: parsed arguments. Reads ``deepspeed``, ``save``,
            ``no_save_optim`` and ``no_save_rng``.
    """
    import os  # needed for the atomic rename of the tracker file below

    if args.deepspeed:
        save_ds_checkpoint(iteration, model, args)
    else:
        # Unwrap DDP so the saved state dict is the inner module's.
        if isinstance(model, torchDDP):
            model = model.module

        # Only rank zero of the data parallel group writes to the disk.
        if mpu.get_data_parallel_rank() == 0:
            checkpoint_name = get_checkpoint_name(args.save, iteration)
            print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
                  format(torch.distributed.get_rank(), iteration, checkpoint_name))

            sd = {}
            sd['iteration'] = iteration
            sd['model'] = model.state_dict()

            # Optimizer stuff.
            if not args.no_save_optim:
                if optimizer is not None:
                    sd['optimizer'] = optimizer.state_dict()
                if lr_scheduler is not None:
                    sd['lr_scheduler'] = lr_scheduler.state_dict()

            # RNG states (Python, NumPy, torch CPU/CUDA, and the mpu CUDA
            # RNG tracker) so a resumed run can restore them.
            if not args.no_save_rng:
                sd['random_rng_state'] = random.getstate()
                sd['np_rng_state'] = np.random.get_state()
                sd['torch_rng_state'] = torch.get_rng_state()
                sd['cuda_rng_state'] = torch.cuda.get_rng_state()
                sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()

            ensure_directory_exists(checkpoint_name)
            torch.save(sd, checkpoint_name)
            print(' successfully saved {}'.format(checkpoint_name))

    # Wait so everyone is done (necessary): the tracker must not point at
    # this iteration before every writer has finished saving it.
    torch.distributed.barrier()
    # And update the latest iteration.
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        # Write to a temporary file first, then atomically replace the
        # tracker. A job killed mid-write then leaves the previous tracker
        # intact instead of an empty or truncated file.
        tmp_filename = '{}.tmp'.format(tracker_filename)
        with open(tmp_filename, 'w') as f:
            f.write(str(iteration))
        os.replace(tmp_filename, tracker_filename)
    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()
227
228def save_ds_checkpoint(iteration, model, args):
229 """Save a model checkpoint."""

Callers 4

train · Function · 0.90
main · Function · 0.90
train · Function · 0.90
main · Function · 0.90

Calls 8

save_ds_checkpoint · Function · 0.85
get_checkpoint_name · Function · 0.85
ensure_directory_exists · Function · 0.85
get_states · Method · 0.80
state_dict · Method · 0.45
save · Method · 0.45
write · Method · 0.45

Tested by

no test coverage detected