Load a model checkpoint.
(model, optimizer, lr_scheduler, args)
| 269 | return iteration, release, True |
| 270 | |
| 271 | def load_checkpoint(model, optimizer, lr_scheduler, args): |
| 272 | """Load a model checkpoint.""" |
| 273 | |
| 274 | iteration, release, success = get_checkpoint_iteration(args) |
| 275 | |
| 276 | if not success: |
| 277 | return 0 |
| 278 | |
| 279 | if args.deepspeed: |
| 280 | |
| 281 | checkpoint_name, sd = model.load_checkpoint(args.load, iteration) |
| 282 | |
| 283 | if checkpoint_name is None: |
| 284 | if mpu.get_data_parallel_rank() == 0: |
| 285 | print("Unable to load checkpoint.") |
| 286 | return iteration |
| 287 | |
| 288 | else: |
| 289 | |
| 290 | # Checkpoint. |
| 291 | checkpoint_name = get_checkpoint_name(args.load, iteration, release) |
| 292 | |
| 293 | if mpu.get_data_parallel_rank() == 0: |
| 294 | print('global rank {} is loading checkpoint {}'.format( |
| 295 | torch.distributed.get_rank(), checkpoint_name)) |
| 296 | |
| 297 | # Load the checkpoint. |
| 298 | sd = torch.load(checkpoint_name, map_location='cpu') |
| 299 | |
| 300 | if isinstance(model, torchDDP): |
| 301 | model = model.module |
| 302 | |
| 303 | # Model. |
| 304 | try: |
| 305 | model.load_state_dict(sd['model']) |
| 306 | except KeyError: |
| 307 | print_rank_0('A metadata file exists but unable to load model ' |
| 308 | 'from checkpoint {}, exiting'.format(checkpoint_name)) |
| 309 | exit() |
| 310 | |
| 311 | # Optimizer. |
| 312 | if not release and not args.finetune and not args.no_load_optim: |
| 313 | try: |
| 314 | if optimizer is not None: |
| 315 | optimizer.load_state_dict(sd['optimizer']) |
| 316 | if lr_scheduler is not None: |
| 317 | lr_scheduler.load_state_dict(sd['lr_scheduler']) |
| 318 | except KeyError: |
| 319 | print_rank_0('Unable to load optimizer from checkpoint {}, exiting. ' |
| 320 | 'Specify --no-load-optim or --finetune to prevent ' |
| 321 | 'attempting to load the optimizer ' |
| 322 | 'state.'.format(checkpoint_name)) |
| 323 | exit() |
| 324 | |
| 325 | # Iterations. |
| 326 | if args.finetune or release: |
| 327 | iteration = 0 |
| 328 | else: |
no test coverage detected