MCPcopy
hub / github.com/FareedKhan-dev/train-llm-from-scratch / main

Function main

scripts/train_reward.py:59–128  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

57
58
59def main():
60 cfg, _ = parse_config_with_json(RewardConfig, "configs/reward.json")
61 ctx = ddp_setup(cfg.device)
62 set_seed(cfg.seed + ctx.rank)
63
64 backbone = load_backbone_from_ckpt(cfg, cfg.sft_ckpt, ctx.device)
65 rm = RewardModel(backbone).to(ctx.device)
66 # find_unused_parameters=True: the reward model uses the backbone's forward_hidden + a
67 # reward head and never its lm_head, so lm_head params get no gradient. Without this flag
68 # DDP errors on the first backward.
69 rm = ddp_wrap(rm, ctx, find_unused_parameters=True)
70 optimizer = configure_optimizer(unwrap(rm), cfg.lr, cfg.weight_decay)
71
72 import json
73 with open(cfg.pref_path) as f:
74 n_rows = sum(1 for line in f if line.strip())
75 total_steps = max(1, (n_rows // (cfg.batch_size * ctx.world_size)) * cfg.epochs)
76
77 logger = None
78 if ctx.is_main:
79 print(f"Reward model from {cfg.sft_ckpt} | {n_rows} pairs | total_steps={total_steps}")
80 logger = MetricsLogger("reward", cfg.log_dir, use_wandb=cfg.use_wandb, wandb_project=cfg.wandb_project)
81
82 train_it = get_preference_iterator(cfg.pref_path, cfg.batch_size, cfg.max_len, device=ctx.device,
83 rank=ctx.rank, world_size=ctx.world_size, shuffle=True, infinite=True)
84
85 rm.train()
86 t0 = time.perf_counter()
87 for step in range(total_steps):
88 lr = cosine_lr(step, warmup_steps=cfg.warmup_steps, max_steps=total_steps, lr=cfg.lr, min_lr=cfg.lr * 0.1)
89 for g in optimizer.param_groups:
90 g["lr"] = lr
91
92 batch = next(train_it)
93 cr, rr = _pair_rewards(rm, batch, cfg, ctx)
94 loss = bradley_terry_loss(cr, rr)
95 optimizer.zero_grad(set_to_none=True)
96 loss.backward()
97 torch.nn.utils.clip_grad_norm_(rm.parameters(), cfg.grad_clip)
98 optimizer.step()
99
100 if ctx.is_main and step % 20 == 0:
101 acc = preference_accuracy(cr, rr).item()
102 dt = time.perf_counter() - t0; t0 = time.perf_counter()
103 print(f"step {step}/{total_steps} | loss {loss.item():.4f} | train_acc {acc:.3f} | lr {lr:.2e} | {dt:.1f}s/20")
104 if logger:
105 logger.log(step, {"train_loss": loss.item(), "train_acc": acc, "lr": lr})
106
107 if step > 0 and step % cfg.eval_steps == 0:
108 acc, marg = eval_accuracy(rm, cfg, ctx)
109 acc, marg = reduce_scalar(acc, ctx), reduce_scalar(marg, ctx)
110 if ctx.is_main:
111 print(f" [eval] step {step} | test_acc {acc:.3f} | margin {marg:.3f}")
112 if logger:
113 logger.log(step, {"test_acc": acc, "test_margin": marg})
114
115 if ctx.is_main and step > 0 and step % cfg.save_every == 0:
116 save_stage_ckpt(cfg.out_ckpt, rm, optimizer, stage="reward", cfg=cfg, step=step,

Callers 1

train_reward.py · File · 0.70

Calls 15

log · Method · 0.95
close · Method · 0.95
parse_config_with_json · Function · 0.90
ddp_setup · Function · 0.90
set_seed · Function · 0.90
load_backbone_from_ckpt · Function · 0.90
RewardModel · Class · 0.90
ddp_wrap · Function · 0.90
configure_optimizer · Function · 0.90
unwrap · Function · 0.90
MetricsLogger · Class · 0.90
get_preference_iterator · Function · 0.90

Tested by

no test coverage detected