MCPcopy
hub / github.com/FareedKhan-dev/train-llm-from-scratch / main

Function main

scripts/train_grpo.py:43–135  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

41
42
43def main():
44 cfg, _ = parse_config_with_json(GRPOConfig, "configs/grpo.json")
45 ctx = ddp_setup(cfg.device)
46 set_seed(cfg.seed + ctx.rank)
47
48 policy = load_backbone_from_ckpt(cfg, cfg.sft_ckpt, ctx.device)
49 ref = make_frozen_copy(policy, device=ctx.device)
50 policy_ddp = ddp_wrap(policy, ctx)
51 optimizer = configure_optimizer(unwrap(policy_ddp), cfg.lr, weight_decay=0.0)
52
53 eval_set = None
54 logger = MetricsLogger("grpo", cfg.log_dir, use_wandb=cfg.use_wandb, wandb_project=cfg.wandb_project) if ctx.is_main else None
55 if ctx.is_main:
56 print(f"GRPO from {cfg.sft_ckpt} | group_size={cfg.group_size} | world={ctx.world_size}")
57
58 warm_it = get_prompt_iterator(cfg.curriculum_path, cfg.prompts_per_iter, rank=ctx.rank,
59 world_size=ctx.world_size, seed=cfg.seed)
60 main_it = get_prompt_iterator(cfg.prompt_path, cfg.prompts_per_iter, rank=ctx.rank,
61 world_size=ctx.world_size, seed=cfg.seed)
62
63 G = cfg.group_size
64 for it in range(cfg.iterations):
65 rows = next(warm_it if it < cfg.curriculum_iters else main_it)
66 # Replicate each prompt G times, group-contiguously.
67 base_prompts = [encode_prompt([{"role": "user", "content": r["prompt"]}]) for r in rows]
68 prompts = [p for p in base_prompts for _ in range(G)]
69 golds = [r.get("gold") for r in rows for _ in range(G)]
70
71 policy.eval()
72 with amp_autocast(cfg.amp_dtype, ctx.device):
73 seqs, rmask, plens = rollout_prompts(policy, prompts, cfg.rollout_len, device=ctx.device,
74 temperature=cfg.temperature, top_p=cfg.top_p if cfg.top_p < 1 else None)
75 resp = rmask[:, 1:]
76
77 seq_lens = seq_lengths_from_mask(rmask, plens)
78 responses = [decode(seqs[i, plens[i]:seq_lens[i]].tolist()) for i in range(len(prompts))]
79 rewards = torch.tensor([reward_gsm8k(responses[i], golds[i]) for i in range(len(prompts))],
80 device=ctx.device, dtype=torch.float32)
81 adv = group_advantages(rewards, G)
82
83 with torch.no_grad(), amp_autocast(cfg.amp_dtype, ctx.device):
84 old_logp, _ = compute_logprobs(policy, seqs, rmask, temperature=cfg.temperature, requires_grad=False)
85 ref_logp, _ = compute_logprobs(ref, seqs, rmask, temperature=cfg.temperature, requires_grad=False)
86 old_logp, ref_logp = old_logp.float(), ref_logp.float()
87
88 policy.train()
89 N = seqs.size(0)
90 agg = {"loss": 0.0, "kl": 0.0, "clipfrac": 0.0, "n": 0}
91 for _ in range(cfg.grpo_epochs):
92 perm = torch.randperm(N, device=ctx.device)
93 for s in range(0, N, max(1, G)): # minibatch ~ one group's worth
94 mb = perm[s:s + max(1, G)]
95 with amp_autocast(cfg.amp_dtype, ctx.device):
96 new_logp, _ = compute_logprobs(policy_ddp, seqs[mb], rmask[mb], temperature=cfg.temperature, requires_grad=True)
97 loss, st = grpo_loss(new_logp.float(), old_logp[mb], ref_logp[mb], adv[mb], resp[mb],
98 clip=cfg.clip, kl_coef=cfg.kl_coef)
99 optimizer.zero_grad(set_to_none=True)
100 loss.backward()

Callers 1

train_grpo.py · File · 0.70

Calls 15

parse_config_with_json · Function · 0.90
ddp_setup · Function · 0.90
set_seed · Function · 0.90
load_backbone_from_ckpt · Function · 0.90
make_frozen_copy · Function · 0.90
ddp_wrap · Function · 0.90
configure_optimizer · Function · 0.90
unwrap · Function · 0.90
MetricsLogger · Class · 0.90
get_prompt_iterator · Function · 0.90
encode_prompt · Function · 0.90
amp_autocast · Function · 0.90

Tested by

no test coverage detected