Repository: github.com/deepspeedai/DeepSpeedExamples

Function save_ds_checkpoint

Defined in Megatron-LM/utils.py, lines 228–241.

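Save a model checkpoint through DeepSpeed. The iteration number and, unless args.no_save_rng is set, the Python, NumPy, PyTorch CPU and CUDA, and Megatron model-parallel CUDA RNG states are handed to the engine as client state, so they are written alongside the model checkpoint.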

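Signature: save_ds_checkpoint(iteration, model, args)

- iteration: the current training iteration. It is stored in the client state and also passed as the checkpoint tag.
- model: the DeepSpeed engine wrapping the model. Its save_checkpoint method writes the checkpoint.
- args: the parsed command-line arguments. args.save is the checkpoint directory; args.no_save_rng disables saving the RNG states.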
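To make the contract with the engine concrete, here is a minimal, stdlib-only sketch. RecordingEngine and save_ds_checkpoint_sketch are hypothetical names introduced for illustration: the stand-in engine only records how save_checkpoint is called, and the replica captures just Python's random state (the NumPy, torch, and mpu states are omitted so it runs without those packages). It shows that the iteration is used both as the tag and as a client-state entry, and that args.no_save_rng shrinks the client state to the iteration alone.

```python
import argparse
import random


class RecordingEngine:
    """Hypothetical stand-in for a DeepSpeed engine: records save_checkpoint calls."""

    def __init__(self):
        self.calls = []

    def save_checkpoint(self, save_dir, tag=None, client_state=None):
        # Mirrors the argument order used by save_ds_checkpoint:
        # (save directory, tag, client_state=...).
        self.calls.append((save_dir, tag, dict(client_state or {})))


def save_ds_checkpoint_sketch(iteration, model, args):
    """Simplified replica: only Python's `random` state is captured."""
    sd = {"iteration": iteration}
    if not args.no_save_rng:
        sd["random_rng_state"] = random.getstate()
    # The iteration doubles as the checkpoint tag, as in the original.
    model.save_checkpoint(args.save, iteration, client_state=sd)


args = argparse.Namespace(save="checkpoints", no_save_rng=False)
engine = RecordingEngine()
save_ds_checkpoint_sketch(1000, engine, args)

save_dir, tag, client_state = engine.calls[0]
print(save_dir, tag, sorted(client_state))
# checkpoints 1000 ['iteration', 'random_rng_state']
```

Passing an argparse.Namespace keeps the sketch close to how the original receives its arguments; any object with save and no_save_rng attributes would do.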


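```python
import random

import numpy as np
import torch

import mpu  # Megatron-LM's model-parallel utilities


def save_ds_checkpoint(iteration, model, args):
    """Save a model checkpoint."""

    # Client state that DeepSpeed stores alongside the model checkpoint.
    sd = {}
    sd['iteration'] = iteration
    # RNG states, so a resumed run can continue the same random streams.
    if not args.no_save_rng:
        sd['random_rng_state'] = random.getstate()
        sd['np_rng_state'] = np.random.get_state()
        sd['torch_rng_state'] = torch.get_rng_state()
        sd['cuda_rng_state'] = torch.cuda.get_rng_state()
        # Named CUDA RNG states kept by Megatron's model-parallel tracker.
        sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()

    # Delegate to the DeepSpeed engine; the iteration is used as the checkpoint tag.
    model.save_checkpoint(args.save, iteration, client_state=sd)
```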
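Saving RNG states only pays off if a resumed run restores them before drawing new numbers. The sketch below illustrates that round trip with Python's random module alone; capture_rng_state and restore_rng_state are hypothetical helpers, and pickle stands in for the checkpoint file, not for DeepSpeed's actual on-disk format. The NumPy and torch entries in the original would follow the same pattern with np.random.set_state, torch.set_rng_state, and torch.cuda.set_rng_state.

```python
import pickle
import random


def capture_rng_state():
    """Hypothetical helper: collect the Python RNG state into a client-state dict."""
    # Only Python's `random` module is covered; the original also saves
    # NumPy, torch CPU/CUDA and Megatron's CUDA RNG tracker states.
    return {"random_rng_state": random.getstate()}


def restore_rng_state(client_state):
    """Hypothetical helper: restore the Python RNG state if it was saved."""
    if "random_rng_state" in client_state:
        random.setstate(client_state["random_rng_state"])


random.seed(1234)
random.random()  # advance the stream as if training had run for a while

# "Save": snapshot the state and serialize it, standing in for the checkpoint file.
blob = pickle.dumps(capture_rng_state())
draws_after_save = [random.random() for _ in range(3)]

# "Resume": perturb the RNG, then load the snapshot and restore it.
random.seed(999)
restore_rng_state(pickle.loads(blob))
draws_after_resume = [random.random() for _ in range(3)]

print(draws_after_save == draws_after_resume)  # True
```

Restoring is guarded by a key check because the original omits every RNG entry when args.no_save_rng is set, so a loader cannot assume they are present.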

Callers (1): save_checkpoint (function, score 0.85)

Calls (1): get_states (method, score 0.80)
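The get_states call comes from mpu.get_cuda_rng_tracker(), which, as far as we understand Megatron's mpu package, returns a tracker that keeps named CUDA RNG states so that model-parallel regions can draw from a stream separate from the default one. The sketch below is a hypothetical, CPU-only analogue built on random.Random instances; NamedRNGTracker, add, draw, and the stream name "model-parallel-rng" are assumptions for illustration, not Megatron's API. It shows why the tracker's states need their own client-state entry: each named stream must be snapshotted and restored independently.

```python
import random


class NamedRNGTracker:
    """Hypothetical, CPU-only analogue of a tracker of named RNG streams."""

    def __init__(self):
        self._streams = {}

    def add(self, name, seed):
        if name in self._streams:
            raise ValueError(f"RNG stream {name!r} already exists")
        self._streams[name] = random.Random(seed)

    def draw(self, name):
        return self._streams[name].random()

    def get_states(self):
        # Return a new dict so callers can't mutate the tracker's bookkeeping.
        return {name: rng.getstate() for name, rng in self._streams.items()}

    def set_states(self, states):
        # Create missing streams on demand, then restore each saved state.
        for name, state in states.items():
            self._streams.setdefault(name, random.Random()).setstate(state)


tracker = NamedRNGTracker()
tracker.add("model-parallel-rng", seed=42)
snapshot = tracker.get_states()
first = tracker.draw("model-parallel-rng")
tracker.set_states(snapshot)
assert tracker.draw("model-parallel-rng") == first
```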

Tested by: no test coverage detected