MCPcopy
hub / github.com/FareedKhan-dev/train-llm-from-scratch / main

Function main

scripts/train_dpo.py:74–140  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

72
73
74def main():
75 cfg, _ = parse_config_with_json(DPOConfig, "configs/dpo.json")
76 ctx = ddp_setup(cfg.device)
77 set_seed(cfg.seed + ctx.rank)
78
79 policy = load_backbone_from_ckpt(cfg, cfg.sft_ckpt, ctx.device)
80 ref = make_frozen_copy(policy, device=ctx.device) if cfg.loss_type != "orpo" else None
81 policy = ddp_wrap(policy, ctx)
82 optimizer = configure_optimizer(unwrap(policy), cfg.lr, cfg.weight_decay)
83
84 with open(cfg.pref_path) as f:
85 n_rows = sum(1 for line in f if line.strip())
86 total_steps = max(1, (n_rows // (cfg.batch_size * ctx.world_size)) * cfg.epochs)
87
88 logger = None
89 if ctx.is_main:
90 print(f"DPO[{cfg.loss_type}] from {cfg.sft_ckpt} | {n_rows} pairs | total_steps={total_steps} | beta={cfg.beta}")
91 logger = MetricsLogger(f"dpo_{cfg.loss_type}", cfg.log_dir, use_wandb=cfg.use_wandb, wandb_project=cfg.wandb_project)
92
93 train_it = get_preference_iterator(cfg.pref_path, cfg.batch_size, cfg.max_len, device=ctx.device,
94 rank=ctx.rank, world_size=ctx.world_size, shuffle=True, infinite=True)
95
96 policy.train()
97 t0 = time.perf_counter()
98 for step in range(total_steps):
99 lr = cosine_lr(step, warmup_steps=cfg.warmup_steps, max_steps=total_steps, lr=cfg.lr, min_lr=cfg.lr * 0.1)
100 for g in optimizer.param_groups:
101 g["lr"] = lr
102
103 batch = next(train_it)
104 loss, cr, rr = _compute_losses(policy, ref, batch, cfg, ctx)
105 optimizer.zero_grad(set_to_none=True)
106 loss.backward()
107 torch.nn.utils.clip_grad_norm_(policy.parameters(), cfg.grad_clip)
108 optimizer.step()
109
110 if ctx.is_main and step % 20 == 0:
111 acc = implicit_accuracy(cr, rr).item()
112 dt = time.perf_counter() - t0; t0 = time.perf_counter()
113 print(f"step {step}/{total_steps} | loss {loss.item():.4f} | acc {acc:.3f} | "
114 f"r_chosen {cr.mean().item():.3f} r_rejected {rr.mean().item():.3f} | {dt:.1f}s/20")
115 if logger:
116 logger.log(step, {"train_loss": loss.item(), "train_acc": acc,
117 "r_chosen": cr.mean().item(), "r_rejected": rr.mean().item(), "lr": lr})
118
119 if step > 0 and step % cfg.eval_steps == 0:
120 acc, marg = eval_implicit_acc(policy, ref, cfg, ctx)
121 acc, marg = reduce_scalar(acc, ctx), reduce_scalar(marg, ctx)
122 if ctx.is_main:
123 print(f" [eval] step {step} | test_acc {acc:.3f} | margin {marg:.3f}")
124 if logger:
125 logger.log(step, {"test_acc": acc, "test_margin": marg})
126
127 if ctx.is_main and step > 0 and step % cfg.save_every == 0:
128 save_stage_ckpt(cfg.out_ckpt, policy, optimizer, stage=f"dpo_{cfg.loss_type}", cfg=cfg, step=step,
129 metrics={"train_loss": loss.item()})
130
131 if ctx.is_main:

Callers 1

train_dpo.py · File · 0.70

Calls 15

log · Method · 0.95
close · Method · 0.95
parse_config_with_json · Function · 0.90
ddp_setup · Function · 0.90
set_seed · Function · 0.90
load_backbone_from_ckpt · Function · 0.90
make_frozen_copy · Function · 0.90
ddp_wrap · Function · 0.90
configure_optimizer · Function · 0.90
unwrap · Function · 0.90
MetricsLogger · Class · 0.90
get_preference_iterator · Function · 0.90

Tested by

no test coverage detected