MCPcopy Index your code
hub / github.com/zai-org/CodeGeeX / load_model

Function load_model

codegeex/mindspore/generation_batch.py:40–172  ·  view source on GitHub ↗

r""" The main function for load model

(args_opt)

Source from the content-addressed store, hash-verified

38
39
40def load_model(args_opt):
41 r"""
42 The main function for load model
43 """
44 # Set execution mode
45 context.set_context(save_graphs=False,
46 mode=context.GRAPH_MODE,
47 device_target=args_opt.device_target)
48 context.set_context(variable_memory_max_size="30GB")
49 # Set parallel context
50 if args_opt.distribute == "true":
51 D.init()
52 device_num = D.get_group_size()
53 rank = D.get_rank()
54 print("rank_id is {}, device_num is {}".format(rank, device_num))
55 context.reset_auto_parallel_context()
56 context.set_auto_parallel_context(
57 parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
58 gradients_mean=False,
59 full_batch=True,
60 loss_repeated_mean=True,
61 enable_parallel_optimizer=False,
62 pipeline_stages=args_opt.stage_num)
63 set_algo_parameters(elementwise_op_strategy_follow=True)
64 _set_multi_subgraphs()
65
66 else:
67 rank = 0
68 device_num = 1
69 context.reset_auto_parallel_context()
70 context.set_auto_parallel_context(
71 strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path)
72 context.set_context(
73 save_graphs=False,
74 save_graphs_path="/cache/graphs_of_device_id_" + str(rank),
75 )
76 use_past = (args_opt.use_past == "true")
77 print('local_rank:{}, start to run...'.format(rank), flush=True)
78 if args_opt.export:
79 use_past = True
80 # Set model property
81 model_parallel_num = args_opt.op_level_model_parallel_num
82 data_parallel_num = int(device_num / model_parallel_num)
83
84 parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num,
85 model_parallel=model_parallel_num,
86 pipeline_stage=args_opt.stage_num,
87 micro_batch_num=args_opt.micro_size,
88 optimizer_shard=False,
89 vocab_emb_dp=bool(args_opt.word_emb_dp),
90 recompute=True)
91
92 per_batch_size = args_opt.per_batch_size
93 batch_size = per_batch_size * data_parallel_num
94 config = PanguAlphaConfig(
95 batch_size=batch_size,
96 seq_length=args_opt.seq_length,
97 vocab_size=args_opt.vocab_size,

Callers 1

main (Function) · 0.70

Calls 6

PanguAlphaConfig (Class) · 0.90
PanguAlphaModel (Class) · 0.90
EvalNet (Class) · 0.90
load_checkpoint (Function) · 0.85
exists (Method) · 0.45
get (Method) · 0.45

Tested by

no test coverage detected