MCPcopy Index your code
hub / github.com/THUDM/GLM / save_checkpoint

Function save_checkpoint

utils.py:224–275  ·  view source on GitHub ↗

Save a model checkpoint.

(iteration, model, optimizer, lr_scheduler, args, tag=None, barrier=True,
                    only_changed_parameters=False, no_deepspeed=False, no_save_optim=False)

Source from the content-addressed store, hash-verified

222
223
def save_checkpoint(iteration, model, optimizer, lr_scheduler, args, tag=None, barrier=True,
                    only_changed_parameters=False, no_deepspeed=False, no_save_optim=False):
    """Save a model checkpoint.

    When ``args.deepspeed`` is set and ``no_deepspeed`` is False, saving is
    delegated to ``save_ds_checkpoint``. Otherwise, rank 0 of the data-parallel
    group builds a state dict and writes it with ``torch.save`` to
    ``get_checkpoint_name(args.save, tag)``.

    In both cases, all ranks then optionally wait on a barrier, and global
    rank 0 writes ``tag`` into the tracker file given by
    ``get_checkpoint_tracker_filename(args.save)``.

    Args:
        iteration: Training iteration; stored under ``'iteration'`` and used as
            the default tag. It is formatted with ``{:7d}``, so an int is expected.
        model: Model to save. On the non-DeepSpeed path with ``args.deepspeed``
            set, ``model.module`` is saved instead.
        optimizer: Optimizer whose ``state_dict()`` is saved, or None.
        lr_scheduler: LR scheduler whose ``state_dict()`` is saved, or None.
        args: Namespace; reads ``deepspeed``, ``save``, ``no_save_optim`` and
            ``no_save_rng``.
        tag: Checkpoint tag. Defaults to ``str(iteration)``.
        barrier: If True, call ``torch.distributed.barrier()`` before the
            tracker file is updated.
        only_changed_parameters: If True, keep only state-dict entries that are
            parameters with ``requires_grad=True``. Buffers are omitted.
        no_deepspeed: Force the non-DeepSpeed save path even if
            ``args.deepspeed`` is set.
        no_save_optim: Skip optimizer/LR-scheduler state. This is in addition
            to ``args.no_save_optim``.
    """
    if tag is None:
        tag = str(iteration)
    if args.deepspeed and not no_deepspeed:
        save_ds_checkpoint(iteration, model, lr_scheduler, args, tag=tag)
    else:
        # Only rank zero of the data parallel group writes to the disk.
        if mpu.get_data_parallel_rank() == 0:
            checkpoint_name = get_checkpoint_name(args.save, tag)
            print('global rank {} is saving checkpoint at iteration {:7d} to {}'.format(
                torch.distributed.get_rank(), iteration, checkpoint_name))
            sd = {'iteration': iteration}
            if args.deepspeed:
                # On this path the model is still wrapped; save the inner module.
                model = model.module
            state_dict = model.state_dict()
            if only_changed_parameters:
                # state_dict() also holds buffers, and may hold extra aliases of
                # tied weights, that named_parameters() does not yield. Indexing
                # a name->requires_grad dict with every state-dict key raised
                # KeyError for those entries, so test set membership instead.
                trainable = {name for name, parameter in model.named_parameters()
                             if parameter.requires_grad}
                state_dict = {key: value for key, value in state_dict.items()
                              if key in trainable}
            sd['module'] = state_dict

            # Optimizer stuff.
            if not args.no_save_optim and not no_save_optim:
                if optimizer is not None:
                    sd['optimizer'] = optimizer.state_dict()
                if lr_scheduler is not None:
                    sd['lr_scheduler'] = lr_scheduler.state_dict()

            # rng states.
            if not args.no_save_rng:
                sd['random_rng_state'] = random.getstate()
                sd['np_rng_state'] = np.random.get_state()
                sd['torch_rng_state'] = torch.get_rng_state()
                sd['cuda_rng_state'] = torch.cuda.get_rng_state()
                sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()

            ensure_directory_exists(checkpoint_name)
            torch.save(sd, checkpoint_name)
            print(' successfully saved {}'.format(checkpoint_name))

    # Wait so everyone is done (necessary)
    if barrier:
        torch.distributed.barrier()
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            # str() so a caller-supplied non-str tag cannot fail here, after
            # the checkpoint itself has already been written.
            f.write(str(tag))
276
277
278def save_ds_checkpoint(iteration, model, lr_scheduler, args, tag):

Callers 4

_train (function) · 0.90
train (function) · 0.90
save_on_exit (function) · 0.90
main (function) · 0.90

Calls 8

save_ds_checkpoint (function) · 0.85
get_checkpoint_name (function) · 0.85
ensure_directory_exists (function) · 0.85
get_states (method) · 0.80
state_dict (method) · 0.45
named_parameters (method) · 0.45
write (method) · 0.45

Tested by

no test coverage detected