()
| 72 | |
| 73 | |
| 74 | def main(): |
| 75 | cfg, _ = parse_config_with_json(DPOConfig, "configs/dpo.json") |
| 76 | ctx = ddp_setup(cfg.device) |
| 77 | set_seed(cfg.seed + ctx.rank) |
| 78 | |
| 79 | policy = load_backbone_from_ckpt(cfg, cfg.sft_ckpt, ctx.device) |
| 80 | ref = make_frozen_copy(policy, device=ctx.device) if cfg.loss_type != "orpo" else None |
| 81 | policy = ddp_wrap(policy, ctx) |
| 82 | optimizer = configure_optimizer(unwrap(policy), cfg.lr, cfg.weight_decay) |
| 83 | |
| 84 | with open(cfg.pref_path) as f: |
| 85 | n_rows = sum(1 for line in f if line.strip()) |
| 86 | total_steps = max(1, (n_rows // (cfg.batch_size * ctx.world_size)) * cfg.epochs) |
| 87 | |
| 88 | logger = None |
| 89 | if ctx.is_main: |
| 90 | print(f"DPO[{cfg.loss_type}] from {cfg.sft_ckpt} | {n_rows} pairs | total_steps={total_steps} | beta={cfg.beta}") |
| 91 | logger = MetricsLogger(f"dpo_{cfg.loss_type}", cfg.log_dir, use_wandb=cfg.use_wandb, wandb_project=cfg.wandb_project) |
| 92 | |
| 93 | train_it = get_preference_iterator(cfg.pref_path, cfg.batch_size, cfg.max_len, device=ctx.device, |
| 94 | rank=ctx.rank, world_size=ctx.world_size, shuffle=True, infinite=True) |
| 95 | |
| 96 | policy.train() |
| 97 | t0 = time.perf_counter() |
| 98 | for step in range(total_steps): |
| 99 | lr = cosine_lr(step, warmup_steps=cfg.warmup_steps, max_steps=total_steps, lr=cfg.lr, min_lr=cfg.lr * 0.1) |
| 100 | for g in optimizer.param_groups: |
| 101 | g["lr"] = lr |
| 102 | |
| 103 | batch = next(train_it) |
| 104 | loss, cr, rr = _compute_losses(policy, ref, batch, cfg, ctx) |
| 105 | optimizer.zero_grad(set_to_none=True) |
| 106 | loss.backward() |
| 107 | torch.nn.utils.clip_grad_norm_(policy.parameters(), cfg.grad_clip) |
| 108 | optimizer.step() |
| 109 | |
| 110 | if ctx.is_main and step % 20 == 0: |
| 111 | acc = implicit_accuracy(cr, rr).item() |
| 112 | dt = time.perf_counter() - t0; t0 = time.perf_counter() |
| 113 | print(f"step {step}/{total_steps} | loss {loss.item():.4f} | acc {acc:.3f} | " |
| 114 | f"r_chosen {cr.mean().item():.3f} r_rejected {rr.mean().item():.3f} | {dt:.1f}s/20") |
| 115 | if logger: |
| 116 | logger.log(step, {"train_loss": loss.item(), "train_acc": acc, |
| 117 | "r_chosen": cr.mean().item(), "r_rejected": rr.mean().item(), "lr": lr}) |
| 118 | |
| 119 | if step > 0 and step % cfg.eval_steps == 0: |
| 120 | acc, marg = eval_implicit_acc(policy, ref, cfg, ctx) |
| 121 | acc, marg = reduce_scalar(acc, ctx), reduce_scalar(marg, ctx) |
| 122 | if ctx.is_main: |
| 123 | print(f" [eval] step {step} | test_acc {acc:.3f} | margin {marg:.3f}") |
| 124 | if logger: |
| 125 | logger.log(step, {"test_acc": acc, "test_margin": marg}) |
| 126 | |
| 127 | if ctx.is_main and step > 0 and step % cfg.save_every == 0: |
| 128 | save_stage_ckpt(cfg.out_ckpt, policy, optimizer, stage=f"dpo_{cfg.loss_type}", cfg=cfg, step=step, |
| 129 | metrics={"train_loss": loss.item()}) |
| 130 | |
| 131 | if ctx.is_main: |
no test coverage detected