()
| 53 | |
| 54 | |
| 55 | def main(): |
| 56 | cfg, _ = parse_config_with_json(PPOConfig, "configs/ppo.json") |
| 57 | ctx = ddp_setup(cfg.device) |
| 58 | set_seed(cfg.seed + ctx.rank) |
| 59 | |
| 60 | backbone = load_backbone_from_ckpt(cfg, cfg.sft_ckpt, ctx.device) |
| 61 | ref = make_frozen_copy(backbone, device=ctx.device) |
| 62 | actor = TransformerWithValueHead(backbone).to(ctx.device) |
| 63 | actor_ddp = ddp_wrap(actor, ctx) |
| 64 | optimizer = configure_optimizer(unwrap(actor_ddp), cfg.lr, weight_decay=0.0) |
| 65 | |
| 66 | rm = load_reward_model(cfg, cfg.reward_ckpt, ctx.device) if cfg.reward_source == "rm" else None |
| 67 | |
| 68 | eval_set = None # loaded lazily on first eval to keep startup/smoke runs offline |
| 69 | logger = MetricsLogger("ppo", cfg.log_dir, use_wandb=cfg.use_wandb, wandb_project=cfg.wandb_project) if ctx.is_main else None |
| 70 | if ctx.is_main: |
| 71 | print(f"PPO from {cfg.sft_ckpt} | reward={cfg.reward_source} | world={ctx.world_size}") |
| 72 | |
| 73 | prompt_it = get_prompt_iterator(cfg.prompt_path, cfg.prompts_per_iter, rank=ctx.rank, |
| 74 | world_size=ctx.world_size, seed=cfg.seed) |
| 75 | |
| 76 | for it in range(cfg.iterations): |
| 77 | rows = next(prompt_it) |
| 78 | prompts = [encode_prompt([{"role": "user", "content": r["prompt"]}]) for r in rows] |
| 79 | golds = [r.get("gold") for r in rows] |
| 80 | |
| 81 | # --- rollout --- |
| 82 | actor.eval() |
| 83 | with amp_autocast(cfg.amp_dtype, ctx.device): |
| 84 | seqs, rmask, plens = rollout_prompts(actor, prompts, cfg.rollout_len, device=ctx.device, |
| 85 | temperature=cfg.temperature, top_p=cfg.top_p if cfg.top_p < 1 else None) |
| 86 | resp = rmask[:, 1:] # action-frame response mask (N, T-1) |
| 87 | |
| 88 | # --- score --- |
| 89 | seq_lens = seq_lengths_from_mask(rmask, plens) |
| 90 | responses = [decode(seqs[i, plens[i]:seq_lens[i]].tolist()) for i in range(len(rows))] |
| 91 | if cfg.reward_source == "rm": |
| 92 | with torch.no_grad(), amp_autocast(cfg.amp_dtype, ctx.device): |
| 93 | task_r = rm(seqs, seq_lengths=seq_lens).float().tolist() |
| 94 | else: |
| 95 | task_r = [reward_gsm8k(responses[i], golds[i]) for i in range(len(rows))] |
| 96 | |
| 97 | # --- per-token rewards: KL penalty everywhere + task reward at last response token --- |
| 98 | with torch.no_grad(), amp_autocast(cfg.amp_dtype, ctx.device): |
| 99 | old_logp, old_values = actor_logp_values(actor, seqs, cfg.temperature) |
| 100 | ref_logp, _ = compute_logprobs(ref, seqs, rmask, temperature=cfg.temperature, requires_grad=False) |
| 101 | old_logp, old_values, ref_logp = old_logp.float(), old_values.float(), ref_logp.float() |
| 102 | |
| 103 | rewards = -cfg.kl_coef * (old_logp - ref_logp) * resp.float() |
| 104 | last_idx = seq_lens - 2 # action index of the last response token |
| 105 | last_idx = last_idx.clamp(min=0) |
| 106 | task_t = torch.tensor(task_r, device=ctx.device, dtype=torch.float32) |
| 107 | rewards[torch.arange(len(rows), device=ctx.device), last_idx] += task_t |
| 108 | |
| 109 | values_next = torch.cat([old_values[:, 1:], torch.zeros_like(old_values[:, :1])], dim=1) |
| 110 | adv, returns = compute_gae(rewards, old_values, values_next, resp, gamma=cfg.gamma, lam=cfg.gae_lambda) |
| 111 | adv = whiten(adv, resp) |
| 112 |
no test coverage detected