MCPcopy Index your code
hub / github.com/THUDM/GLM / load_checkpoint

Function load_checkpoint

utils.py:327–411  ·  view source on GitHub ↗

Load a model checkpoint.

(model, optimizer, lr_scheduler, args, no_deepspeed=False, no_load_optim=False, no_load_rng=False)

Source from the content-addressed store, hash-verified

325
326
327def load_checkpoint(model, optimizer, lr_scheduler, args, no_deepspeed=False, no_load_optim=False, no_load_rng=False):
328 """Load a model checkpoint."""
329
330 load_dir, tag, release, success = get_checkpoint_iteration(args.load)
331
332 if not success:
333 return 0
334
335 if args.deepspeed and not no_deepspeed:
336
337 checkpoint_name, sd = model.load_checkpoint(load_dir, tag,
338 load_optimizer_states=not args.no_load_optim and not no_load_optim,
339 load_lr_scheduler_states=not args.no_load_lr_scheduler)
340 if not args.no_load_lr_scheduler and "client_lr_scheduler" in sd:
341 lr_scheduler.load_state_dict(sd["client_lr_scheduler"])
342 print_rank_0("Load lr scheduler state")
343 if checkpoint_name is None:
344 if mpu.get_data_parallel_rank() == 0:
345 print("Unable to load checkpoint.")
346 return tag
347
348 else:
349
350 # Checkpoint.
351 checkpoint_name = get_checkpoint_name(load_dir, tag, release)
352
353 if mpu.get_data_parallel_rank() == 0:
354 print('global rank {} is loading checkpoint {}'.format(
355 torch.distributed.get_rank(), checkpoint_name))
356
357 # Load the checkpoint.
358 sd = torch.load(checkpoint_name, map_location='cpu')
359
360 # Model.
361 if args.deepspeed:
362 model = model.module
363 missing_keys, unexpected_keys = model.load_state_dict(sd['module'], strict=False)
364 if missing_keys or unexpected_keys:
365 print_rank_0(f"Missing keys {missing_keys}, unexpected keys {unexpected_keys}")
366
367 # Optimizer.
368 if not release and not args.finetune and not args.no_load_optim and not no_load_optim:
369 try:
370 if optimizer is not None:
371 optimizer.load_state_dict(sd['optimizer'])
372 if lr_scheduler is not None:
373 lr_scheduler.load_state_dict(sd['lr_scheduler'])
374 except KeyError:
375 print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
376 'Specify --no-load-optim or --finetune to prevent '
377 'attempting to load the optimizer '
378 'state.'.format(checkpoint_name))
379
380 # Iterations.
381 if args.finetune or release:
382 iteration = 0
383 else:
384 try:

Callers 3

finetune — Function · 0.90
main — Function · 0.90
setup_model — Function · 0.90

Calls 6

get_checkpoint_iteration — Function · 0.85
print_rank_0 — Function · 0.85
get_checkpoint_name — Function · 0.85
load — Method · 0.80
set_states — Method · 0.80
load_state_dict — Method · 0.45

Tested by

no test coverage detected