()
| 41 | |
| 42 | |
| 43 | def main(): |
| 44 | cfg, _ = parse_config_with_json(GRPOConfig, "configs/grpo.json") |
| 45 | ctx = ddp_setup(cfg.device) |
| 46 | set_seed(cfg.seed + ctx.rank) |
| 47 | |
| 48 | policy = load_backbone_from_ckpt(cfg, cfg.sft_ckpt, ctx.device) |
| 49 | ref = make_frozen_copy(policy, device=ctx.device) |
| 50 | policy_ddp = ddp_wrap(policy, ctx) |
| 51 | optimizer = configure_optimizer(unwrap(policy_ddp), cfg.lr, weight_decay=0.0) |
| 52 | |
| 53 | eval_set = None |
| 54 | logger = MetricsLogger("grpo", cfg.log_dir, use_wandb=cfg.use_wandb, wandb_project=cfg.wandb_project) if ctx.is_main else None |
| 55 | if ctx.is_main: |
| 56 | print(f"GRPO from {cfg.sft_ckpt} | group_size={cfg.group_size} | world={ctx.world_size}") |
| 57 | |
| 58 | warm_it = get_prompt_iterator(cfg.curriculum_path, cfg.prompts_per_iter, rank=ctx.rank, |
| 59 | world_size=ctx.world_size, seed=cfg.seed) |
| 60 | main_it = get_prompt_iterator(cfg.prompt_path, cfg.prompts_per_iter, rank=ctx.rank, |
| 61 | world_size=ctx.world_size, seed=cfg.seed) |
| 62 | |
| 63 | G = cfg.group_size |
| 64 | for it in range(cfg.iterations): |
| 65 | rows = next(warm_it if it < cfg.curriculum_iters else main_it) |
| 66 | # Replicate each prompt G times, group-contiguously. |
| 67 | base_prompts = [encode_prompt([{"role": "user", "content": r["prompt"]}]) for r in rows] |
| 68 | prompts = [p for p in base_prompts for _ in range(G)] |
| 69 | golds = [r.get("gold") for r in rows for _ in range(G)] |
| 70 | |
| 71 | policy.eval() |
| 72 | with amp_autocast(cfg.amp_dtype, ctx.device): |
| 73 | seqs, rmask, plens = rollout_prompts(policy, prompts, cfg.rollout_len, device=ctx.device, |
| 74 | temperature=cfg.temperature, top_p=cfg.top_p if cfg.top_p < 1 else None) |
| 75 | resp = rmask[:, 1:] |
| 76 | |
| 77 | seq_lens = seq_lengths_from_mask(rmask, plens) |
| 78 | responses = [decode(seqs[i, plens[i]:seq_lens[i]].tolist()) for i in range(len(prompts))] |
| 79 | rewards = torch.tensor([reward_gsm8k(responses[i], golds[i]) for i in range(len(prompts))], |
| 80 | device=ctx.device, dtype=torch.float32) |
| 81 | adv = group_advantages(rewards, G) |
| 82 | |
| 83 | with torch.no_grad(), amp_autocast(cfg.amp_dtype, ctx.device): |
| 84 | old_logp, _ = compute_logprobs(policy, seqs, rmask, temperature=cfg.temperature, requires_grad=False) |
| 85 | ref_logp, _ = compute_logprobs(ref, seqs, rmask, temperature=cfg.temperature, requires_grad=False) |
| 86 | old_logp, ref_logp = old_logp.float(), ref_logp.float() |
| 87 | |
| 88 | policy.train() |
| 89 | N = seqs.size(0) |
| 90 | agg = {"loss": 0.0, "kl": 0.0, "clipfrac": 0.0, "n": 0} |
| 91 | for _ in range(cfg.grpo_epochs): |
| 92 | perm = torch.randperm(N, device=ctx.device) |
| 93 | for s in range(0, N, max(1, G)): # minibatch ~ one group's worth |
| 94 | mb = perm[s:s + max(1, G)] |
| 95 | with amp_autocast(cfg.amp_dtype, ctx.device): |
| 96 | new_logp, _ = compute_logprobs(policy_ddp, seqs[mb], rmask[mb], temperature=cfg.temperature, requires_grad=True) |
| 97 | loss, st = grpo_loss(new_logp.float(), old_logp[mb], ref_logp[mb], adv[mb], resp[mb], |
| 98 | clip=cfg.clip, kl_coef=cfg.kl_coef) |
| 99 | optimizer.zero_grad(set_to_none=True) |
| 100 | loss.backward() |
no test coverage detected