MCPcopy
hub / github.com/InternLM/InternLM / main

Function main

train.py:70–290  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

68
69
70def main(args):
71 # init setting
72 skip_batches = gpc.config.data.skip_batches
73 total_steps = gpc.config.data.total_steps
74 valid_every = gpc.config.data.valid_every
75 label_smoothing = gpc.config.loss.label_smoothing
76 lr = gpc.config.adam.lr
77
78 get_tflops_func = partial(
79 get_megatron_flops,
80 checkpoint=gpc.config.model.checkpoint,
81 seq_len=gpc.config.SEQ_LEN,
82 hidden_size=gpc.config.model.hidden_size,
83 num_layers=gpc.config.model.num_layers,
84 vocab_size=gpc.config.model.vocab_size,
85 global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA),
86 global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
87 mlp_ratio=gpc.config.MLP_RATIO,
88 )
89
90 # get and broadcast current time
91 current_time = launch_time()
92 objs = [current_time]
93 dist.broadcast_object_list(objs, src=0)
94 current_time = objs[0]
95
96 # initialize customed llm logger
97 uniscale_logger = initialize_llm_logger(start_time=current_time)
98
99 # initialize and resume train state
100 train_state = TrainState(gpc.config)
101
102 # initialize model
103 model = initialize_model()
104
105 with open(args.config, "r") as f:
106 config_lines = f.readlines()
107 ckpt_manager = CheckpointManager(
108 ckpt_config=gpc.config.ckpt,
109 model=model,
110 model_config=gpc.config.model,
111 model_config_file="".join(config_lines),
112 feishu_address=gpc.config.alert_address,
113 )
114
115 # initialize loss function
116 criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
117
118 # initialize the train and validation data loader
119 train_dl, dataset_types = get_train_data_loader(num_worker=4)
120 val_dls = get_validation_data_loader()
121 train_state.init_batch_sampler(train_dl)
122
123 # Loading model weights must be done before zero is initialized.
124 ckpt_manager.try_load_model(current_time)
125
126 optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
127

Callers 1

train.py — File · 0.70

Calls 15

init_batch_sampler — Method · 0.95
try_load_model — Method · 0.95
try_resume_training — Method · 0.95
set_current_type_ids — Method · 0.95
try_save_checkpoint — Method · 0.95
step — Method · 0.95
launch_time — Function · 0.90
TrainState — Class · 0.90
initialize_model — Function · 0.90
CheckpointManager — Class · 0.90
FlashGPTLMLoss — Class · 0.90

Tested by

no test coverage detected