MCPcopy
hub / github.com/FareedKhan-dev/train-llm-from-scratch / main

Function main

scripts/train_sft.py:50–123  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

48
49
50def main():
51 cfg, _ = parse_config_with_json(SFTConfig, "configs/sft.json")
52 ctx = ddp_setup(cfg.device)
53 set_seed(cfg.seed + ctx.rank)
54
55 model = load_backbone_from_ckpt(cfg, cfg.pretrained_ckpt, ctx.device)
56 if cfg.compile:
57 model = torch.compile(model)
58 model = ddp_wrap(model, ctx)
59 optimizer = configure_optimizer(unwrap(model), cfg.lr, cfg.weight_decay)
60
61 logger = None
62 if ctx.is_main:
63 print(f"SFT from {cfg.pretrained_ckpt} | world_size={ctx.world_size}")
64 logger = MetricsLogger("sft", cfg.log_dir, use_wandb=cfg.use_wandb, wandb_project=cfg.wandb_project)
65
66 train_it = get_sft_batch_iterator(cfg.data_path, cfg.batch_size, device=ctx.device,
67 rank=ctx.rank, world_size=ctx.world_size, shuffle=True, infinite=True)
68
69 # Estimate total steps for the cosine schedule from dataset size.
70 import h5py
71 with h5py.File(cfg.data_path, "r") as f:
72 n_rows = f["tokens"].shape[0]
73 steps_per_epoch = max(1, n_rows // (cfg.batch_size * ctx.world_size))
74 total_steps = cfg.max_steps if cfg.max_steps > 0 else steps_per_epoch * cfg.epochs
75 if ctx.is_main:
76 print(f"{n_rows} packed rows | ~{steps_per_epoch} steps/epoch | total_steps={total_steps}")
77
78 model.train()
79 t0 = time.perf_counter()
80 for step in range(total_steps):
81 lr = cosine_lr(step, warmup_steps=cfg.warmup_steps, max_steps=total_steps, lr=cfg.lr, min_lr=cfg.min_lr)
82 for g in optimizer.param_groups:
83 g["lr"] = lr
84
85 tokens, mask, epoch = next(train_it)
86 if epoch >= cfg.epochs and cfg.max_steps <= 0:
87 break
88 optimizer.zero_grad(set_to_none=True)
89 with amp_autocast(cfg.amp_dtype, ctx.device):
90 logits, _ = model(tokens)
91 loss = sft_loss(logits, tokens, mask)
92 loss.backward()
93 torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
94 optimizer.step()
95
96 if ctx.is_main and step % 20 == 0:
97 dt = time.perf_counter() - t0; t0 = time.perf_counter()
98 print(f"step {step}/{total_steps} | loss {loss.item():.4f} | ppl {math.exp(min(20, loss.item())):.2f} | lr {lr:.2e} | {dt:.1f}s/20")
99 if logger:
100 logger.log(step, {"train_loss": loss.item(), "lr": lr})
101
102 if step > 0 and step % cfg.eval_steps == 0:
103 dev = reduce_scalar(eval_dev(model, cfg, ctx, DEV_PATH), ctx)
104 if ctx.is_main:
105 print(f" [eval] step {step} | dev_loss {dev:.4f} | dev_ppl {math.exp(min(20, dev)):.2f}")
106 if logger:
107 logger.log(step, {"dev_loss": dev})

Callers 1

train_sft.py · File · 0.70

Calls 15

log · Method · 0.95
close · Method · 0.95
parse_config_with_json · Function · 0.90
ddp_setup · Function · 0.90
set_seed · Function · 0.90
load_backbone_from_ckpt · Function · 0.90
ddp_wrap · Function · 0.90
configure_optimizer · Function · 0.90
unwrap · Function · 0.90
MetricsLogger · Class · 0.90
get_sft_batch_iterator · Function · 0.90
cosine_lr · Function · 0.90

Tested by

no test coverage detected