()
| 201 | |
| 202 | |
| 203 | def train(): |
| 204 | deepspeed.init_distributed(dist_backend="nccl") |
| 205 | RANK = dist.get_rank() |
| 206 | WORLD_SIZE = dist.get_world_size() |
| 207 | |
| 208 | # Args |
| 209 | args = parse_args() |
| 210 | |
| 211 | hub_upload_check(args.push_to_hub) |
| 212 | |
| 213 | # Dataset |
| 214 | train_dataset, train_loader = create_dataset_and_dataloader(args, 0) |
| 215 | |
| 216 | if train_dataset is None: |
| 217 | raise RuntimeError("Training data not found.") |
| 218 | |
| 219 | # Load model type |
| 220 | args.model_type = train_dataset.metadata["model_type"] |
| 221 | |
| 222 | train_total_steps = args.epochs * train_dataset.estimate_num_batches() |
| 223 | |
| 224 | # Hyperparams |
| 225 | args.lr = calculate_auto_lr(args.lr, args.batch_max_len, args.model_type, train_dataset) |
| 226 | |
| 227 | # Model |
| 228 | model_engine, optimizer = create_model(args) |
| 229 | |
| 230 | # LR Scheduler |
| 231 | lr_scheduler = create_lr_scheduler(args, train_total_steps) |
| 232 | |
| 233 | # Progress bar and logger |
| 234 | progress_bar = None |
| 235 | if RANK == 0: |
| 236 | progress_bar = tqdm.tqdm(total=train_total_steps) |
| 237 | |
| 238 | wandb.init(project=args.wandb_project or os.path.basename(args.model_path), entity=args.wandb_entity, config=args) |
| 239 | |
| 240 | # Training Loop |
| 241 | step = 0 |
| 242 | lr_this_step = None |
| 243 | for epoch in range(args.epochs): |
| 244 | print (f"[rank {RANK} of {WORLD_SIZE}]: Epoch {epoch}") |
| 245 | |
| 246 | ############ Load Dataset |
| 247 | if epoch != 0: |
| 248 | del train_dataset, train_loader |
| 249 | |
| 250 | train_dataset, train_loader = create_dataset_and_dataloader(args, epoch) |
| 251 | |
| 252 | ############ Train Epoch |
| 253 | model_engine.train() |
| 254 | for (batch_tensor, batch_info), all_numseq, cur_numseq in train_loader: |
| 255 | step += 1 |
| 256 | if step > train_total_steps: # At most train_total_steps |
| 257 | break |
| 258 | |
| 259 | # To device |
| 260 | batch_tensor = {k: (v.to(args.device) if v is not None else None) for k, v in batch_tensor.items()} |
no test coverage detected