Load a model checkpoint.
(model, optimizer, lr_scheduler, args, load_optimizer_states=True)
| 287 | return position_embeddings |
| 288 | |
| 289 | def load_checkpoint(model, optimizer, lr_scheduler, args, load_optimizer_states=True): |
| 290 | """Load a model checkpoint.""" |
| 291 | |
| 292 | iteration, release, success = get_checkpoint_iteration(args) |
| 293 | |
| 294 | if not success: |
| 295 | return 0 |
| 296 | |
| 297 | if args.deepspeed: |
| 298 | |
| 299 | checkpoint_name, sd = model.load_checkpoint(args.load, iteration, load_optimizer_states=not args.no_load_optim) |
| 300 | if args.fp16 and args.no_load_optim: |
| 301 | model.optimizer.refresh_fp32_params() |
| 302 | |
| 303 | if "client_lr_scheduler" in sd: |
| 304 | lr_scheduler.load_state_dict(sd["client_lr_scheduler"]) |
| 305 | print_rank_0("Load lr scheduler state") |
| 306 | if checkpoint_name is None: |
| 307 | if mpu.get_data_parallel_rank() == 0: |
| 308 | print("Unable to load checkpoint.") |
| 309 | return iteration |
| 310 | |
| 311 | else: |
| 312 | |
| 313 | # Checkpoint. |
| 314 | checkpoint_name = get_checkpoint_name(args.load, iteration, release) |
| 315 | |
| 316 | if mpu.get_data_parallel_rank() == 0: |
| 317 | print('global rank {} is loading checkpoint {}'.format( |
| 318 | torch.distributed.get_rank(), checkpoint_name)) |
| 319 | |
| 320 | # Load the checkpoint. |
| 321 | sd = torch.load(checkpoint_name, map_location='cpu') |
| 322 | |
| 323 | if isinstance(model, torchDDP): |
| 324 | model = model.module |
| 325 | |
| 326 | # Model. |
| 327 | try: |
| 328 | model.load_state_dict(sd['module']) |
| 329 | except KeyError: |
| 330 | print_rank_0('A metadata file exists but unable to load model ' |
| 331 | 'from checkpoint {}, exiting'.format(checkpoint_name)) |
| 332 | exit() |
| 333 | |
| 334 | # Optimizer. |
| 335 | if not release and not args.finetune and not args.no_load_optim: |
| 336 | try: |
| 337 | if optimizer is not None and load_optimizer_states: |
| 338 | optimizer.load_state_dict(sd['optimizer']) |
| 339 | if lr_scheduler is not None: |
| 340 | lr_scheduler.load_state_dict(sd['lr_scheduler']) |
| 341 | except KeyError: |
| 342 | print_rank_0('Unable to load optimizer from checkpoint {}, exiting. ' |
| 343 | 'Specify --no-load-optim or --finetune to prevent ' |
| 344 | 'attempting to load the optimizer ' |
| 345 | 'state.'.format(checkpoint_name)) |
| 346 | exit() |
no test coverage detected