MCPcopy
hub / github.com/FareedKhan-dev/train-llm-from-scratch / main

Function main

scripts/train_ppo.py:55–163  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

53
54
55def main():
56 cfg, _ = parse_config_with_json(PPOConfig, "configs/ppo.json")
57 ctx = ddp_setup(cfg.device)
58 set_seed(cfg.seed + ctx.rank)
59
60 backbone = load_backbone_from_ckpt(cfg, cfg.sft_ckpt, ctx.device)
61 ref = make_frozen_copy(backbone, device=ctx.device)
62 actor = TransformerWithValueHead(backbone).to(ctx.device)
63 actor_ddp = ddp_wrap(actor, ctx)
64 optimizer = configure_optimizer(unwrap(actor_ddp), cfg.lr, weight_decay=0.0)
65
66 rm = load_reward_model(cfg, cfg.reward_ckpt, ctx.device) if cfg.reward_source == "rm" else None
67
68 eval_set = None # loaded lazily on first eval to keep startup/smoke runs offline
69 logger = MetricsLogger("ppo", cfg.log_dir, use_wandb=cfg.use_wandb, wandb_project=cfg.wandb_project) if ctx.is_main else None
70 if ctx.is_main:
71 print(f"PPO from {cfg.sft_ckpt} | reward={cfg.reward_source} | world={ctx.world_size}")
72
73 prompt_it = get_prompt_iterator(cfg.prompt_path, cfg.prompts_per_iter, rank=ctx.rank,
74 world_size=ctx.world_size, seed=cfg.seed)
75
76 for it in range(cfg.iterations):
77 rows = next(prompt_it)
78 prompts = [encode_prompt([{"role": "user", "content": r["prompt"]}]) for r in rows]
79 golds = [r.get("gold") for r in rows]
80
81 # --- rollout ---
82 actor.eval()
83 with amp_autocast(cfg.amp_dtype, ctx.device):
84 seqs, rmask, plens = rollout_prompts(actor, prompts, cfg.rollout_len, device=ctx.device,
85 temperature=cfg.temperature, top_p=cfg.top_p if cfg.top_p < 1 else None)
86 resp = rmask[:, 1:] # action-frame response mask (N, T-1)
87
88 # --- score ---
89 seq_lens = seq_lengths_from_mask(rmask, plens)
90 responses = [decode(seqs[i, plens[i]:seq_lens[i]].tolist()) for i in range(len(rows))]
91 if cfg.reward_source == "rm":
92 with torch.no_grad(), amp_autocast(cfg.amp_dtype, ctx.device):
93 task_r = rm(seqs, seq_lengths=seq_lens).float().tolist()
94 else:
95 task_r = [reward_gsm8k(responses[i], golds[i]) for i in range(len(rows))]
96
97 # --- per-token rewards: KL penalty everywhere + task reward at last response token ---
98 with torch.no_grad(), amp_autocast(cfg.amp_dtype, ctx.device):
99 old_logp, old_values = actor_logp_values(actor, seqs, cfg.temperature)
100 ref_logp, _ = compute_logprobs(ref, seqs, rmask, temperature=cfg.temperature, requires_grad=False)
101 old_logp, old_values, ref_logp = old_logp.float(), old_values.float(), ref_logp.float()
102
103 rewards = -cfg.kl_coef * (old_logp - ref_logp) * resp.float()
104 last_idx = seq_lens - 2 # action index of the last response token
105 last_idx = last_idx.clamp(min=0)
106 task_t = torch.tensor(task_r, device=ctx.device, dtype=torch.float32)
107 rewards[torch.arange(len(rows), device=ctx.device), last_idx] += task_t
108
109 values_next = torch.cat([old_values[:, 1:], torch.zeros_like(old_values[:, :1])], dim=1)
110 adv, returns = compute_gae(rewards, old_values, values_next, resp, gamma=cfg.gamma, lam=cfg.gae_lambda)
111 adv = whiten(adv, resp)
112

Callers 1

train_ppo.py — File · 0.70

Calls 15

parse_config_with_json — Function · 0.90
ddp_setup — Function · 0.90
set_seed — Function · 0.90
load_backbone_from_ckpt — Function · 0.90
make_frozen_copy — Function · 0.90
ddp_wrap — Function · 0.90
configure_optimizer — Function · 0.90
unwrap — Function · 0.90
load_reward_model — Function · 0.90
MetricsLogger — Class · 0.90
get_prompt_iterator — Function · 0.90

Tested by

no test coverage detected