MCPcopy
hub / github.com/zai-org/CogView / save_ds_checkpoint

Function save_ds_checkpoint

utils.py:237–252  ·  view source on GitHub ↗

Save a model checkpoint.

(iteration, model, lr_scheduler, args)

Source from the content-addressed store, hash-verified

235
236
def save_ds_checkpoint(iteration, model, lr_scheduler, args):
    """Save a model checkpoint.

    Assembles a "client state" dictionary and hands it to
    ``model.save_checkpoint`` so it is written alongside the model weights.

    The client state always contains ``'iteration'``. It additionally holds:
      * ``'client_lr_scheduler'``: ``lr_scheduler.state_dict()``, when a
        scheduler is supplied;
      * the Python ``random``, NumPy, torch CPU, torch CUDA and model-parallel
        (``mpu`` CUDA RNG tracker) generator states, unless
        ``args.no_save_rng`` is truthy.

    Args:
        iteration: Training iteration number. It is stored in the client state
            and, converted with ``str``, passed as the checkpoint tag.
        model: Object exposing ``save_checkpoint(save_dir, tag, client_state=...)``.
            Presumably a DeepSpeed engine; confirm against the caller.
        lr_scheduler: Object with a ``state_dict()`` method, or ``None``.
        args: Namespace read for ``save`` (the checkpoint directory) and
            ``no_save_rng``.
    """
    client_state = {'iteration': iteration}

    if lr_scheduler is not None:
        client_state['client_lr_scheduler'] = lr_scheduler.state_dict()

    if not args.no_save_rng:
        # Snapshot every RNG source; presumably restored on load so a resumed
        # run continues the same random streams (TODO confirm in loader).
        client_state.update({
            'random_rng_state': random.getstate(),
            'np_rng_state': np.random.get_state(),
            'torch_rng_state': torch.get_rng_state(),
            'cuda_rng_state': torch.cuda.get_rng_state(),
            'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states(),
        })

    tag = str(iteration)
    model.save_checkpoint(args.save, tag, client_state=client_state)
253
254
255def get_checkpoint_iteration(args):

Callers 1

save_checkpoint — Function · 0.85

Calls 2

get_states — Method · 0.80
state_dict — Method · 0.45

Tested by

no test coverage detected