MCPcopy
hub / github.com/zai-org/CogView / load_checkpoint

Function load_checkpoint

utils.py:289–380  ·  view source on GitHub ↗

Load a model checkpoint.

(model, optimizer, lr_scheduler, args, load_optimizer_states=True)

Source from the content-addressed store, hash-verified

287 return position_embeddings
288
289def load_checkpoint(model, optimizer, lr_scheduler, args, load_optimizer_states=True):
290 """Load a model checkpoint."""
291
292 iteration, release, success = get_checkpoint_iteration(args)
293
294 if not success:
295 return 0
296
297 if args.deepspeed:
298
299 checkpoint_name, sd = model.load_checkpoint(args.load, iteration, load_optimizer_states=not args.no_load_optim)
300 if args.fp16 and args.no_load_optim:
301 model.optimizer.refresh_fp32_params()
302
303 if "client_lr_scheduler" in sd:
304 lr_scheduler.load_state_dict(sd["client_lr_scheduler"])
305 print_rank_0("Load lr scheduler state")
306 if checkpoint_name is None:
307 if mpu.get_data_parallel_rank() == 0:
308 print("Unable to load checkpoint.")
309 return iteration
310
311 else:
312
313 # Checkpoint.
314 checkpoint_name = get_checkpoint_name(args.load, iteration, release)
315
316 if mpu.get_data_parallel_rank() == 0:
317 print('global rank {} is loading checkpoint {}'.format(
318 torch.distributed.get_rank(), checkpoint_name))
319
320 # Load the checkpoint.
321 sd = torch.load(checkpoint_name, map_location='cpu')
322
323 if isinstance(model, torchDDP):
324 model = model.module
325
326 # Model.
327 try:
328 model.load_state_dict(sd['module'])
329 except KeyError:
330 print_rank_0('A metadata file exists but unable to load model '
331 'from checkpoint {}, exiting'.format(checkpoint_name))
332 exit()
333
334 # Optimizer.
335 if not release and not args.finetune and not args.no_load_optim:
336 try:
337 if optimizer is not None and load_optimizer_states:
338 optimizer.load_state_dict(sd['optimizer'])
339 if lr_scheduler is not None:
340 lr_scheduler.load_state_dict(sd['lr_scheduler'])
341 except KeyError:
342 print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
343 'Specify --no-load-optim or --finetune to prevent '
344 'attempting to load the optimizer '
345 'state.'.format(checkpoint_name))
346 exit()

Callers 2

setup_model · Function · 0.90
main · Function · 0.90

Calls 5

get_checkpoint_iteration · Function · 0.85
print_rank_0 · Function · 0.85
get_checkpoint_name · Function · 0.85
set_states · Method · 0.80
load_state_dict · Method · 0.45

Tested by

no test coverage detected