MCPcopy Index your code
hub / github.com/zai-org/CodeGeeX / restore_exception_checkpoint

Function restore_exception_checkpoint

codegeex/mindspore/train.py:425–484  ·  view source on GitHub ↗

Restore exception checkpoint. Args: args_param: training job params; sink_size: training job sink size; dataset: dataset for training; model: model; network: pangu_alpha network; epoch: training epoch. Returns: load exception checkpoint success or not.

(args_param, sink_size, dataset, model, network, epoch)

Source from the content-addressed store, hash-verified

423
424
425def restore_exception_checkpoint(args_param, sink_size, dataset, model, network, epoch):
426 """
427 Restore exception checkpoint.
428 Args:
429 args_param: training job params
430 sink_size: training job sink size
431 dataset: dataset for training
432 model: model
433 network: pangu_alpha network
434 epoch: training epoch
435 Returns: load exception checkpoint success or not.
436 """
437 if os.getenv("RESTORE_RANKS") == "-1":
438 return False
439
440 ckpt_file_list = get_exception_checkpoints(args_param)
441
442 restore_flag = False
443 if ckpt_file_list:
444 restore_flag = check_exception_checkpoints(ckpt_file_list)
445
446 if not restore_flag:
447 return False
448
449 ckpt_name = args_param.ckpt_name_prefix
450 restore_ranks_map = os.getenv("RESTORE_RANKS_MAP")
451 if not restore_ranks_map:
452 return False
453
454 try:
455 print("whether run into load process")
456 restore_ranks_map_json = json.loads(restore_ranks_map)
457 map_rank_id = D.get_rank()
458 for key in restore_ranks_map_json.keys():
459 key_list = list(key.split(","))
460 if str(D.get_rank()) in key_list:
461 map_rank_id = restore_ranks_map_json.get(key)
462
463 print(f"loading map rank id {map_rank_id}")
464 ckpt_pattern = os.path.join(args_param.save_checkpoint_path,
465 f"rank_{map_rank_id}",
466 f"{ckpt_name}*breakpoint.ckpt")
467 ckpt_files = glob.glob(ckpt_pattern)
468 ckpt_files.sort(key=os.path.getmtime, reverse=True)
469 print(f" checkpoint files {ckpt_files[0]}")
470 param_dict = load_checkpoint(ckpt_files[0])
471 print(f" checkpoint param dict epoch num {param_dict.get('epoch_num')}")
472 if param_dict.get("epoch_num") and param_dict.get("step_num"):
473 args_param.has_trained_epoches = int(
474 param_dict["epoch_num"].data.asnumpy())
475 args_param.has_trained_steps = int(
476 param_dict["step_num"].data.asnumpy())
477
478 # Load checkpoint files
479 model.build(train_dataset=dataset, sink_size=sink_size, epoch=epoch)
480 load_param_into_net(network, param_dict)
481 except TypeError:
482 return False

Callers 1

run_train_pipeline (Function) · 0.85

Calls 4

load_checkpoint (Function) · 0.85
get (Method) · 0.45

Tested by

no test coverage detected