MCPcopy
hub / github.com/zai-org/CogView / save_checkpoint

Function save_checkpoint

utils.py:188–234  ·  view source on GitHub ↗

Save a model checkpoint.

(iteration, model, optimizer,
                    lr_scheduler, args)

Source from the content-addressed store, hash-verified

186
187
def save_checkpoint(iteration, model, optimizer,
                    lr_scheduler, args):
    """Write a training checkpoint and record its iteration as the latest.

    When ``args.deepspeed`` is set, saving is handed to
    ``save_ds_checkpoint``. Otherwise the processes whose data-parallel
    rank is 0 assemble a state dict and write it with ``torch.save`` to
    the path returned by ``get_checkpoint_name(args.save, iteration)``.

    The state dict always holds ``iteration`` and ``module``. Optimizer
    and LR-scheduler state are added unless ``args.no_save_optim`` is set
    (and only for objects that are not ``None``); RNG states are added
    unless ``args.no_save_rng`` is set.

    Afterwards every process synchronizes on a barrier, global rank 0
    writes ``iteration`` to the tracker file, and a final barrier follows.
    """
    if args.deepspeed:
        save_ds_checkpoint(iteration, model, lr_scheduler, args)
    else:
        # Save the wrapped module's weights, not the DDP wrapper's.
        unwrapped = model.module if isinstance(model, torchDDP) else model

        # Only data-parallel rank zero writes to disk.
        if mpu.get_data_parallel_rank() == 0:
            path = get_checkpoint_name(args.save, iteration)
            print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
                  format(torch.distributed.get_rank(), iteration, path))

            state = {
                'iteration': iteration,
                'module': unwrapped.state_dict(),
            }

            # Optimizer and LR-scheduler state.
            keep_optim = not args.no_save_optim
            if keep_optim and optimizer is not None:
                state['optimizer'] = optimizer.state_dict()
            if keep_optim and lr_scheduler is not None:
                state['lr_scheduler'] = lr_scheduler.state_dict()

            # RNG states.
            if not args.no_save_rng:
                state.update({
                    'random_rng_state': random.getstate(),
                    'np_rng_state': np.random.get_state(),
                    'torch_rng_state': torch.get_rng_state(),
                    'cuda_rng_state': torch.cuda.get_rng_state(),
                    'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states(),
                })

            ensure_directory_exists(path)
            torch.save(state, path)
            print(' successfully saved {}'.format(path))

    # All ranks must finish writing before the tracker is updated (necessary).
    torch.distributed.barrier()

    # Global rank zero records this iteration as the latest checkpoint.
    if torch.distributed.get_rank() == 0:
        tracker_path = get_checkpoint_tracker_filename(args.save)
        with open(tracker_path, 'w') as tracker_file:
            tracker_file.write(str(iteration))

    # Final sync so every rank leaves together (not strictly necessary).
    torch.distributed.barrier()
235
236
237def save_ds_checkpoint(iteration, model, lr_scheduler, args):

Callers 3

trainFunction · 0.90
save_on_exitFunction · 0.90
mainFunction · 0.90

Calls 6

save_ds_checkpointFunction · 0.85
get_checkpoint_nameFunction · 0.85
ensure_directory_existsFunction · 0.85
get_statesMethod · 0.80
state_dictMethod · 0.45

Tested by

no test coverage detected