r""" The main function for load model
(args_opt)
| 38 | |
| 39 | |
| 40 | def load_model(args_opt): |
| 41 | r""" |
| 42 | The main function for load model |
| 43 | """ |
| 44 | # Set execution mode |
| 45 | context.set_context(save_graphs=False, |
| 46 | mode=context.GRAPH_MODE, |
| 47 | device_target=args_opt.device_target) |
| 48 | context.set_context(variable_memory_max_size="30GB") |
| 49 | # Set parallel context |
| 50 | if args_opt.distribute == "true": |
| 51 | D.init() |
| 52 | device_num = D.get_group_size() |
| 53 | rank = D.get_rank() |
| 54 | print("rank_id is {}, device_num is {}".format(rank, device_num)) |
| 55 | context.reset_auto_parallel_context() |
| 56 | context.set_auto_parallel_context( |
| 57 | parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, |
| 58 | gradients_mean=False, |
| 59 | full_batch=True, |
| 60 | loss_repeated_mean=True, |
| 61 | enable_parallel_optimizer=False, |
| 62 | pipeline_stages=args_opt.stage_num) |
| 63 | set_algo_parameters(elementwise_op_strategy_follow=True) |
| 64 | _set_multi_subgraphs() |
| 65 | |
| 66 | else: |
| 67 | rank = 0 |
| 68 | device_num = 1 |
| 69 | context.reset_auto_parallel_context() |
| 70 | context.set_auto_parallel_context( |
| 71 | strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path) |
| 72 | context.set_context( |
| 73 | save_graphs=False, |
| 74 | save_graphs_path="/cache/graphs_of_device_id_" + str(rank), |
| 75 | ) |
| 76 | use_past = (args_opt.use_past == "true") |
| 77 | print('local_rank:{}, start to run...'.format(rank), flush=True) |
| 78 | if args_opt.export: |
| 79 | use_past = True |
| 80 | # Set model property |
| 81 | model_parallel_num = args_opt.op_level_model_parallel_num |
| 82 | data_parallel_num = int(device_num / model_parallel_num) |
| 83 | |
| 84 | parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, |
| 85 | model_parallel=model_parallel_num, |
| 86 | pipeline_stage=args_opt.stage_num, |
| 87 | micro_batch_num=args_opt.micro_size, |
| 88 | optimizer_shard=False, |
| 89 | vocab_emb_dp=bool(args_opt.word_emb_dp), |
| 90 | recompute=True) |
| 91 | |
| 92 | per_batch_size = args_opt.per_batch_size |
| 93 | batch_size = per_batch_size * data_parallel_num |
| 94 | config = PanguAlphaConfig( |
| 95 | batch_size=batch_size, |
| 96 | seq_length=args_opt.seq_length, |
| 97 | vocab_size=args_opt.vocab_size, |
no test coverage detected