()
| 57 | |
| 58 | |
| 59 | def main(): |
| 60 | cfg, _ = parse_config_with_json(RewardConfig, "configs/reward.json") |
| 61 | ctx = ddp_setup(cfg.device) |
| 62 | set_seed(cfg.seed + ctx.rank) |
| 63 | |
| 64 | backbone = load_backbone_from_ckpt(cfg, cfg.sft_ckpt, ctx.device) |
| 65 | rm = RewardModel(backbone).to(ctx.device) |
| 66 | # find_unused_parameters=True: the reward model uses the backbone's forward_hidden + a |
| 67 | # reward head and never its lm_head, so lm_head params get no gradient. Without this flag |
| 68 | # DDP errors on the first backward. |
| 69 | rm = ddp_wrap(rm, ctx, find_unused_parameters=True) |
| 70 | optimizer = configure_optimizer(unwrap(rm), cfg.lr, cfg.weight_decay) |
| 71 | |
| 72 | import json |
| 73 | with open(cfg.pref_path) as f: |
| 74 | n_rows = sum(1 for line in f if line.strip()) |
| 75 | total_steps = max(1, (n_rows // (cfg.batch_size * ctx.world_size)) * cfg.epochs) |
| 76 | |
| 77 | logger = None |
| 78 | if ctx.is_main: |
| 79 | print(f"Reward model from {cfg.sft_ckpt} | {n_rows} pairs | total_steps={total_steps}") |
| 80 | logger = MetricsLogger("reward", cfg.log_dir, use_wandb=cfg.use_wandb, wandb_project=cfg.wandb_project) |
| 81 | |
| 82 | train_it = get_preference_iterator(cfg.pref_path, cfg.batch_size, cfg.max_len, device=ctx.device, |
| 83 | rank=ctx.rank, world_size=ctx.world_size, shuffle=True, infinite=True) |
| 84 | |
| 85 | rm.train() |
| 86 | t0 = time.perf_counter() |
| 87 | for step in range(total_steps): |
| 88 | lr = cosine_lr(step, warmup_steps=cfg.warmup_steps, max_steps=total_steps, lr=cfg.lr, min_lr=cfg.lr * 0.1) |
| 89 | for g in optimizer.param_groups: |
| 90 | g["lr"] = lr |
| 91 | |
| 92 | batch = next(train_it) |
| 93 | cr, rr = _pair_rewards(rm, batch, cfg, ctx) |
| 94 | loss = bradley_terry_loss(cr, rr) |
| 95 | optimizer.zero_grad(set_to_none=True) |
| 96 | loss.backward() |
| 97 | torch.nn.utils.clip_grad_norm_(rm.parameters(), cfg.grad_clip) |
| 98 | optimizer.step() |
| 99 | |
| 100 | if ctx.is_main and step % 20 == 0: |
| 101 | acc = preference_accuracy(cr, rr).item() |
| 102 | dt = time.perf_counter() - t0; t0 = time.perf_counter() |
| 103 | print(f"step {step}/{total_steps} | loss {loss.item():.4f} | train_acc {acc:.3f} | lr {lr:.2e} | {dt:.1f}s/20") |
| 104 | if logger: |
| 105 | logger.log(step, {"train_loss": loss.item(), "train_acc": acc, "lr": lr}) |
| 106 | |
| 107 | if step > 0 and step % cfg.eval_steps == 0: |
| 108 | acc, marg = eval_accuracy(rm, cfg, ctx) |
| 109 | acc, marg = reduce_scalar(acc, ctx), reduce_scalar(marg, ctx) |
| 110 | if ctx.is_main: |
| 111 | print(f" [eval] step {step} | test_acc {acc:.3f} | margin {marg:.3f}") |
| 112 | if logger: |
| 113 | logger.log(step, {"test_acc": acc, "test_margin": marg}) |
| 114 | |
| 115 | if ctx.is_main and step > 0 and step % cfg.save_every == 0: |
| 116 | save_stage_ckpt(cfg.out_ckpt, rm, optimizer, stage="reward", cfg=cfg, step=step, |
no test coverage detected