Load a model checkpoint.
load_checkpoint(model, optimizer, lr_scheduler, args, no_deepspeed=False, no_load_optim=False, no_load_rng=False)
| 325 | |
| 326 | |
| 327 | def load_checkpoint(model, optimizer, lr_scheduler, args, no_deepspeed=False, no_load_optim=False, no_load_rng=False): |
| 328 | """Load a model checkpoint.""" |
| 329 | |
| 330 | load_dir, tag, release, success = get_checkpoint_iteration(args.load) |
| 331 | |
| 332 | if not success: |
| 333 | return 0 |
| 334 | |
| 335 | if args.deepspeed and not no_deepspeed: |
| 336 | |
| 337 | checkpoint_name, sd = model.load_checkpoint(load_dir, tag, |
| 338 | load_optimizer_states=not args.no_load_optim and not no_load_optim, |
| 339 | load_lr_scheduler_states=not args.no_load_lr_scheduler) |
| 340 | if not args.no_load_lr_scheduler and "client_lr_scheduler" in sd: |
| 341 | lr_scheduler.load_state_dict(sd["client_lr_scheduler"]) |
| 342 | print_rank_0("Load lr scheduler state") |
| 343 | if checkpoint_name is None: |
| 344 | if mpu.get_data_parallel_rank() == 0: |
| 345 | print("Unable to load checkpoint.") |
| 346 | return tag |
| 347 | |
| 348 | else: |
| 349 | |
| 350 | # Checkpoint. |
| 351 | checkpoint_name = get_checkpoint_name(load_dir, tag, release) |
| 352 | |
| 353 | if mpu.get_data_parallel_rank() == 0: |
| 354 | print('global rank {} is loading checkpoint {}'.format( |
| 355 | torch.distributed.get_rank(), checkpoint_name)) |
| 356 | |
| 357 | # Load the checkpoint. |
| 358 | sd = torch.load(checkpoint_name, map_location='cpu') |
| 359 | |
| 360 | # Model. |
| 361 | if args.deepspeed: |
| 362 | model = model.module |
| 363 | missing_keys, unexpected_keys = model.load_state_dict(sd['module'], strict=False) |
| 364 | if missing_keys or unexpected_keys: |
| 365 | print_rank_0(f"Missing keys {missing_keys}, unexpected keys {unexpected_keys}") |
| 366 | |
| 367 | # Optimizer. |
| 368 | if not release and not args.finetune and not args.no_load_optim and not no_load_optim: |
| 369 | try: |
| 370 | if optimizer is not None: |
| 371 | optimizer.load_state_dict(sd['optimizer']) |
| 372 | if lr_scheduler is not None: |
| 373 | lr_scheduler.load_state_dict(sd['lr_scheduler']) |
| 374 | except KeyError: |
| 375 | print_rank_0('Unable to load optimizer from checkpoint {}, exiting. ' |
| 376 | 'Specify --no-load-optim or --finetune to prevent ' |
| 377 | 'attempting to load the optimizer ' |
| 378 | 'state.'.format(checkpoint_name)) |
| 379 | |
| 380 | # Iterations. |
| 381 | if args.finetune or release: |
| 382 | iteration = 0 |
| 383 | else: |
| 384 | try: |
No test coverage detected.