(args)
| 68 | |
| 69 | |
| 70 | def main(args): |
| 71 | # init setting |
| 72 | skip_batches = gpc.config.data.skip_batches |
| 73 | total_steps = gpc.config.data.total_steps |
| 74 | valid_every = gpc.config.data.valid_every |
| 75 | label_smoothing = gpc.config.loss.label_smoothing |
| 76 | lr = gpc.config.adam.lr |
| 77 | |
| 78 | get_tflops_func = partial( |
| 79 | get_megatron_flops, |
| 80 | checkpoint=gpc.config.model.checkpoint, |
| 81 | seq_len=gpc.config.SEQ_LEN, |
| 82 | hidden_size=gpc.config.model.hidden_size, |
| 83 | num_layers=gpc.config.model.num_layers, |
| 84 | vocab_size=gpc.config.model.vocab_size, |
| 85 | global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA), |
| 86 | global_world_size=gpc.get_world_size(ParallelMode.GLOBAL), |
| 87 | mlp_ratio=gpc.config.MLP_RATIO, |
| 88 | ) |
| 89 | |
| 90 | # get and broadcast current time |
| 91 | current_time = launch_time() |
| 92 | objs = [current_time] |
| 93 | dist.broadcast_object_list(objs, src=0) |
| 94 | current_time = objs[0] |
| 95 | |
| 96 | # initialize customized llm logger |
| 97 | uniscale_logger = initialize_llm_logger(start_time=current_time) |
| 98 | |
| 99 | # initialize and resume train state |
| 100 | train_state = TrainState(gpc.config) |
| 101 | |
| 102 | # initialize model |
| 103 | model = initialize_model() |
| 104 | |
| 105 | with open(args.config, "r") as f: |
| 106 | config_lines = f.readlines() |
| 107 | ckpt_manager = CheckpointManager( |
| 108 | ckpt_config=gpc.config.ckpt, |
| 109 | model=model, |
| 110 | model_config=gpc.config.model, |
| 111 | model_config_file="".join(config_lines), |
| 112 | feishu_address=gpc.config.alert_address, |
| 113 | ) |
| 114 | |
| 115 | # initialize loss function |
| 116 | criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing) |
| 117 | |
| 118 | # initialize the train and validation data loader |
| 119 | train_dl, dataset_types = get_train_data_loader(num_worker=4) |
| 120 | val_dls = get_validation_data_loader() |
| 121 | train_state.init_batch_sampler(train_dl) |
| 122 | |
| 123 | # Loading model weights must be done before zero is initialized. |
| 124 | ckpt_manager.try_load_model(current_time) |
| 125 | |
| 126 | optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) |
| 127 |
no test coverage detected