Restore exception checkpoint. Args: args_param: training job params; sink_size: training job sink size; dataset: dataset for training; model: model; network: pangu_alpha network; epoch: training epoch. Returns: whether loading the exception checkpoint succeeded.
(args_param, sink_size, dataset, model, network, epoch)
| 423 | |
| 424 | |
def restore_exception_checkpoint(args_param, sink_size, dataset, model, network, epoch):
    """
    Restore exception checkpoint.

    Looks up the most recently modified ``*breakpoint.ckpt`` file for this
    rank (or for the rank it is mapped to by ``RESTORE_RANKS_MAP``), loads it,
    and pushes its parameters into ``network``.

    Environment variables read:
        RESTORE_RANKS: when equal to ``"-1"`` the function returns ``False``
            immediately without looking for checkpoints.
        RESTORE_RANKS_MAP: JSON object. Each key is a comma-separated list of
            rank ids; the value is the rank id whose ``rank_<id>`` directory
            under ``args_param.save_checkpoint_path`` is read for the ranks
            listed in the key. A rank that appears in no key reads its own
            directory.

    Args:
        args_param: training job params. ``ckpt_name_prefix`` and
            ``save_checkpoint_path`` are read; ``has_trained_epoches`` and
            ``has_trained_steps`` are set when the checkpoint carries both
            ``epoch_num`` and ``step_num``.
        sink_size: training job sink size, forwarded to ``model.build``.
        dataset: dataset for training, forwarded to ``model.build``.
        model: model whose ``build`` is called before the parameters are loaded.
        network: pangu_alpha network that receives the checkpoint parameters.
        epoch: training epoch, forwarded to ``model.build``.

    Returns:
        bool: ``True`` when the checkpoint was loaded into ``network``;
        ``False`` when restoring is switched off, no usable exception
        checkpoint exists, ``RESTORE_RANKS_MAP`` is unset or empty, no file
        matches the checkpoint pattern, or loading raised ``TypeError``.
    """
    if os.getenv("RESTORE_RANKS") == "-1":
        return False

    ckpt_file_list = get_exception_checkpoints(args_param)
    # The check is only run when at least one candidate file was found.
    if not ckpt_file_list or not check_exception_checkpoints(ckpt_file_list):
        return False

    ckpt_name = args_param.ckpt_name_prefix
    restore_ranks_map = os.getenv("RESTORE_RANKS_MAP")
    if not restore_ranks_map:
        return False

    try:
        print("whether run into load process")
        restore_ranks_map_json = json.loads(restore_ranks_map)

        # Default to this rank's own directory; a matching map entry overrides
        # it. When several keys list this rank, the last one iterated wins.
        current_rank = D.get_rank()
        map_rank_id = current_rank
        for key in restore_ranks_map_json.keys():
            if str(current_rank) in key.split(","):
                map_rank_id = restore_ranks_map_json.get(key)

        print(f"loading map rank id {map_rank_id}")
        ckpt_pattern = os.path.join(args_param.save_checkpoint_path,
                                    f"rank_{map_rank_id}",
                                    f"{ckpt_name}*breakpoint.ckpt")
        ckpt_files = glob.glob(ckpt_pattern)
        if not ckpt_files:
            # Nothing to restore from: report failure instead of raising
            # IndexError on the lookup below.
            print(f"no breakpoint checkpoint matches {ckpt_pattern}")
            return False

        # Newest file first.
        ckpt_files.sort(key=os.path.getmtime, reverse=True)
        print(f" checkpoint files {ckpt_files[0]}")
        param_dict = load_checkpoint(ckpt_files[0])
        print(f" checkpoint param dict epoch num {param_dict.get('epoch_num')}")
        if param_dict.get("epoch_num") and param_dict.get("step_num"):
            args_param.has_trained_epoches = int(
                param_dict["epoch_num"].data.asnumpy())
            args_param.has_trained_steps = int(
                param_dict["step_num"].data.asnumpy())

        # Load checkpoint files
        model.build(train_dataset=dataset, sink_size=sink_size, epoch=epoch)
        load_param_into_net(network, param_dict)
    except TypeError:
        return False

    # Reaching this point means the parameters were loaded without error.
    return True
no test coverage detected