()
| 48 | |
| 49 | |
def main():
    """Fine-tune a pretrained backbone on packed SFT data, optionally under DDP.

    Reads an ``SFTConfig`` from ``configs/sft.json`` and loads the backbone
    from ``cfg.pretrained_ckpt``. It then trains with a cosine LR schedule:
    for ``cfg.max_steps`` steps when that is positive, otherwise for
    ``cfg.epochs`` passes over ``cfg.data_path``.

    The main rank prints and logs training metrics every 20 steps. Every
    ``cfg.eval_steps`` steps (only when ``cfg.eval_steps > 0``), all ranks
    evaluate on ``DEV_PATH`` and the main rank reports the reduced dev loss.
    """
    cfg, _ = parse_config_with_json(SFTConfig, "configs/sft.json")
    ctx = ddp_setup(cfg.device)
    # Offset the seed by rank so ranks do not draw identical random streams.
    set_seed(cfg.seed + ctx.rank)

    model = load_backbone_from_ckpt(cfg, cfg.pretrained_ckpt, ctx.device)
    if cfg.compile:
        model = torch.compile(model)
    model = ddp_wrap(model, ctx)
    # The optimizer is built on unwrap(model); presumably this strips the
    # DDP/compile wrappers so parameter groups see the raw module.
    optimizer = configure_optimizer(unwrap(model), cfg.lr, cfg.weight_decay)

    logger = None
    if ctx.is_main:
        print(f"SFT from {cfg.pretrained_ckpt} | world_size={ctx.world_size}")
        logger = MetricsLogger("sft", cfg.log_dir, use_wandb=cfg.use_wandb, wandb_project=cfg.wandb_project)

    train_it = get_sft_batch_iterator(
        cfg.data_path, cfg.batch_size, device=ctx.device,
        rank=ctx.rank, world_size=ctx.world_size, shuffle=True, infinite=True,
    )

    # Estimate total steps for the cosine schedule from dataset size.
    # The divisor assumes cfg.batch_size rows per rank per optimizer step.
    import h5py
    with h5py.File(cfg.data_path, "r") as f:
        n_rows = f["tokens"].shape[0]
    steps_per_epoch = max(1, n_rows // (cfg.batch_size * ctx.world_size))
    total_steps = cfg.max_steps if cfg.max_steps > 0 else steps_per_epoch * cfg.epochs
    if ctx.is_main:
        print(f"{n_rows} packed rows | ~{steps_per_epoch} steps/epoch | total_steps={total_steps}")

    log_every = 20  # step interval for main-rank train logging / timing

    model.train()
    t0 = time.perf_counter()
    for step in range(total_steps):
        lr = cosine_lr(step, warmup_steps=cfg.warmup_steps, max_steps=total_steps, lr=cfg.lr, min_lr=cfg.min_lr)
        for g in optimizer.param_groups:
            g["lr"] = lr

        tokens, mask, epoch = next(train_it)
        # The iterator is infinite. In epoch mode (max_steps <= 0), stop as
        # soon as it reports an epoch index >= cfg.epochs.
        if epoch >= cfg.epochs and cfg.max_steps <= 0:
            break
        optimizer.zero_grad(set_to_none=True)
        with amp_autocast(cfg.amp_dtype, ctx.device):
            logits, _ = model(tokens)
            loss = sft_loss(logits, tokens, mask)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        optimizer.step()

        if ctx.is_main and step % log_every == 0:
            # The first report (step 0) covers a single step, including any
            # torch.compile warm-up when cfg.compile is set.
            dt = time.perf_counter() - t0
            t0 = time.perf_counter()
            # .item() synchronizes with the device, so read it once. This is
            # the main rank's local loss, not an average over ranks.
            train_loss = loss.item()
            # Clamp the exponent so the perplexity stays finite on large losses.
            print(f"step {step}/{total_steps} | loss {train_loss:.4f} | ppl {math.exp(min(20, train_loss)):.2f} | lr {lr:.2e} | {dt:.1f}s/{log_every}")
            if logger is not None:
                logger.log(step, {"train_loss": train_loss, "lr": lr})

        # eval_steps <= 0 disables periodic eval (previously a ZeroDivisionError
        # at eval_steps == 0).
        if cfg.eval_steps > 0 and step > 0 and step % cfg.eval_steps == 0:
            # Runs on every rank, not gated on is_main. reduce_scalar takes
            # ctx and presumably performs a cross-rank reduction, so every
            # rank must reach it.
            dev = reduce_scalar(eval_dev(model, cfg, ctx, DEV_PATH), ctx)
            # eval_dev may leave the model in eval mode (its body is not
            # visible here). Switch back so training-only behaviour such as
            # dropout stays active; this is harmless if already training.
            model.train()
            if ctx.is_main:
                print(f" [eval] step {step} | dev_loss {dev:.4f} | dev_ppl {math.exp(min(20, dev)):.2f}")
                if logger is not None:
                    logger.log(step, {"dev_loss": dev})
no test coverage detected