hub / github.com/imoneoi/openchat / train

Function train

ochat/training_deepspeed/train.py:203–308
()


def train():
    deepspeed.init_distributed(dist_backend="nccl")
    RANK = dist.get_rank()
    WORLD_SIZE = dist.get_world_size()

    # Args
    args = parse_args()

    hub_upload_check(args.push_to_hub)

    # Dataset
    train_dataset, train_loader = create_dataset_and_dataloader(args, 0)

    if train_dataset is None:
        raise RuntimeError("Training data not found.")

    # Load model type
    args.model_type = train_dataset.metadata["model_type"]

    train_total_steps = args.epochs * train_dataset.estimate_num_batches()

    # Hyperparams
    args.lr = calculate_auto_lr(args.lr, args.batch_max_len, args.model_type, train_dataset)

    # Model
    model_engine, optimizer = create_model(args)

    # LR Scheduler
    lr_scheduler = create_lr_scheduler(args, train_total_steps)

    # Progress bar and logger
    progress_bar = None
    if RANK == 0:
        progress_bar = tqdm.tqdm(total=train_total_steps)

        wandb.init(project=args.wandb_project or os.path.basename(args.model_path), entity=args.wandb_entity, config=args)

    # Training Loop
    step = 0
    lr_this_step = None
    for epoch in range(args.epochs):
        print(f"[rank {RANK} of {WORLD_SIZE}]: Epoch {epoch}")

        ############ Load Dataset
        if epoch != 0:
            del train_dataset, train_loader

            train_dataset, train_loader = create_dataset_and_dataloader(args, epoch)

        ############ Train Epoch
        model_engine.train()
        for (batch_tensor, batch_info), all_numseq, cur_numseq in train_loader:
            step += 1
            if step > train_total_steps:  # At most train_total_steps
                break

            # To device
            batch_tensor = {k: (v.to(args.device) if v is not None else None) for k, v in batch_tensor.items()}
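The excerpt above fixes a step budget up front (epochs times an estimated batch count) and stops consuming batches once the counter passes it. The sketch below isolates that pattern in plain Python so the arithmetic can be checked without DeepSpeed or a GPU. The helper name `run_epochs` and its parameters are hypothetical and not part of the repository; the per-epoch batch counts stand in for whatever the real data loader yields.

```python
def run_epochs(epochs, batches_per_epoch, estimated_batches_per_epoch):
    """Count how many batches are processed under a fixed step budget.

    Hypothetical stand-in for the loop in train(): `batches_per_epoch[i]`
    is the number of batches the loader actually yields in epoch i, while
    `estimated_batches_per_epoch` plays the role of estimate_num_batches().
    """
    total_steps = epochs * estimated_batches_per_epoch  # mirrors train_total_steps
    step = 0
    processed = 0
    for epoch in range(epochs):
        for _batch in range(batches_per_epoch[epoch]):
            step += 1
            if step > total_steps:
                # Budget exhausted: this leaves only the inner loop, so any
                # later epoch would still start and then break on its first batch.
                break
            processed += 1
    return total_steps, processed


# The loader yields more batches than estimated: the budget caps the work.
budget, done = run_epochs(epochs=2, batches_per_epoch=[5, 5], estimated_batches_per_epoch=4)

# The loader yields fewer batches than estimated: the budget is never reached.
budget_short, done_short = run_epochs(epochs=2, batches_per_epoch=[3, 3], estimated_batches_per_epoch=4)
```

One consequence visible in the sketch: when the estimate is too high, fewer steps run than the learning-rate scheduler was told to expect, so the schedule may not reach its final value. Whether that matters for the real training run depends on code outside this excerpt.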

Callers 1

train.py · File · 0.70

Calls 11

hub_upload_check · Function · 0.90
hub_upload_model_async · Function · 0.90
parse_args · Function · 0.85
calculate_auto_lr · Function · 0.85
create_model · Function · 0.85
create_lr_scheduler · Function · 0.85
state_dict_to_cpu · Function · 0.85
save_tokenizer · Function · 0.85
save_openchat_metadata · Function · 0.85
estimate_num_batches · Method · 0.45

Tested by

no test coverage detected