MCPcopy
hub / github.com/deepspeedai/DeepSpeedExamples / load_checkpoint

Function load_checkpoint

Megatron-LM/utils.py:271–358  ·  view source on GitHub ↗

Load a model checkpoint.

(model, optimizer, lr_scheduler, args)

Source from the content-addressed store, hash-verified

269 return iteration, release, True
270
271def load_checkpoint(model, optimizer, lr_scheduler, args):
272 """Load a model checkpoint."""
273
274 iteration, release, success = get_checkpoint_iteration(args)
275
276 if not success:
277 return 0
278
279 if args.deepspeed:
280
281 checkpoint_name, sd = model.load_checkpoint(args.load, iteration)
282
283 if checkpoint_name is None:
284 if mpu.get_data_parallel_rank() == 0:
285 print("Unable to load checkpoint.")
286 return iteration
287
288 else:
289
290 # Checkpoint.
291 checkpoint_name = get_checkpoint_name(args.load, iteration, release)
292
293 if mpu.get_data_parallel_rank() == 0:
294 print('global rank {} is loading checkpoint {}'.format(
295 torch.distributed.get_rank(), checkpoint_name))
296
297 # Load the checkpoint.
298 sd = torch.load(checkpoint_name, map_location='cpu')
299
300 if isinstance(model, torchDDP):
301 model = model.module
302
303 # Model.
304 try:
305 model.load_state_dict(sd['model'])
306 except KeyError:
307 print_rank_0('A metadata file exists but unable to load model '
308 'from checkpoint {}, exiting'.format(checkpoint_name))
309 exit()
310
311 # Optimizer.
312 if not release and not args.finetune and not args.no_load_optim:
313 try:
314 if optimizer is not None:
315 optimizer.load_state_dict(sd['optimizer'])
316 if lr_scheduler is not None:
317 lr_scheduler.load_state_dict(sd['lr_scheduler'])
318 except KeyError:
319 print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
320 'Specify --no-load-optim or --finetune to prevent '
321 'attempting to load the optimizer '
322 'state.'.format(checkpoint_name))
323 exit()
324
325 # Iterations.
326 if args.finetune or release:
327 iteration = 0
328 else:

Callers 4

setup_model · Function · 0.90
setup_model · Function · 0.90

Calls 6

get_checkpoint_iteration · Function · 0.85
get_checkpoint_name · Function · 0.85
set_states · Method · 0.80
print_rank_0 · Function · 0.70
load · Method · 0.45
load_state_dict · Method · 0.45

Tested by

no test coverage detected