Repository: github.com/policy-gradient/GRPO-Zero

Function rollout

Defined in grpo.py, lines 16–114
rollout(
    model: Transformer,
    batch: MiniBatch,
    tokenizer: Tokenizer,
    max_gen_len: int,
    num_answer_per_question: int,
    reward_function: Callable,
    device: torch.device,
    dtype: torch.dtype,
) -> List[Episode]
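The parameters fix the shape of the token buffer that `rollout` fills. It has `bsz = len(batch.prefix) * num_answer_per_question` rows and `total_len = max_gen_len + max_prompt_len` columns, and each prompt is copied left-aligned into `num_answer_per_question` consecutive rows, with the remainder filled with `pad_token_id`. Below is a minimal sketch of that layout, using plain lists in place of `batch.prefix_token_ids`. It assumes `len(batch.prefix)` equals the number of tokenized prompts. The helper name `build_prompt_buffer` is hypothetical and does not come from the repository.

```python
from typing import List

import torch


def build_prompt_buffer(
    prefix_token_ids: List[List[int]],
    num_answer_per_question: int,
    max_gen_len: int,
    pad_token_id: int,
) -> torch.Tensor:
    """Hypothetical helper: replicate each prompt into a pad-filled
    (bsz, total_len) buffer, following the layout used by rollout."""
    bsz = len(prefix_token_ids) * num_answer_per_question
    max_prompt_len = max(len(t) for t in prefix_token_ids)
    total_len = max_gen_len + max_prompt_len
    tokens = torch.full((bsz, total_len), pad_token_id, dtype=torch.long)
    for k, t in enumerate(prefix_token_ids):
        offset = k * num_answer_per_question  # rows for prompt k are contiguous
        for i in range(num_answer_per_question):
            tokens[offset + i, : len(t)] = torch.tensor(t, dtype=torch.long)
    return tokens


# Two prompts (lengths 3 and 2), two answers each, room for 2 generated tokens.
buf = build_prompt_buffer([[5, 6, 7], [8, 9]], num_answer_per_question=2,
                          max_gen_len=2, pad_token_id=0)
print(buf.shape)     # torch.Size([4, 5])
print(buf.tolist())  # [[5, 6, 7, 0, 0], [5, 6, 7, 0, 0], [8, 9, 0, 0, 0], [8, 9, 0, 0, 0]]
```

Because the generation region is sized from the longest prompt, shorter prompts leave extra pad columns. Those columns are filled by generation as well.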


# Imports this excerpt needs; Transformer, MiniBatch, Tokenizer and Episode
# are defined elsewhere in the GRPO-Zero repository.
from typing import Callable, List

import torch


@torch.no_grad()
def rollout(
    model: Transformer,
    batch: MiniBatch,
    tokenizer: Tokenizer,
    max_gen_len: int,
    num_answer_per_question: int,
    reward_function: Callable,
    device: torch.device,
    dtype: torch.dtype,
) -> List[Episode]:
    end_token = tokenizer.eos_token
    end_token_id = tokenizer.eos_token_id
    pad_token_id = tokenizer.pad_token_id
    prefix_token_ids = batch.prefix_token_ids
    # Every prompt is sampled num_answer_per_question times.
    bsz = len(batch.prefix) * num_answer_per_question
    min_prompt_len = min(len(t) for t in prefix_token_ids)
    max_prompt_len = max(len(t) for t in prefix_token_ids)
    total_len = max_gen_len + max_prompt_len
    model.init_kv_cache(
        max_batch_size=bsz,
        max_seq_len=total_len,
        device=device,
        dtype=dtype,
    )
    # Pad-filled buffer; rows k*G .. k*G+G-1 (G = num_answer_per_question)
    # each hold a left-aligned copy of prompt k.
    tokens = torch.full((bsz, total_len), pad_token_id, dtype=torch.long, device=device)
    for k, t in enumerate(prefix_token_ids):
        offset = k * num_answer_per_question
        for i in range(num_answer_per_question):
            tokens[offset + i, : len(t)] = torch.tensor(
                t, dtype=torch.long, device=device
            )

    prev_pos = 0
    # True wherever the buffer holds a prompt token.
    input_text_mask = tokens != pad_token_id
    assert min_prompt_len < total_len
    is_finished = torch.zeros((bsz,), dtype=torch.bool, device=device)

    # Decoding starts where the shortest prompt ends; rows with longer
    # prompts keep their prompt tokens until their own prompt is consumed.
    for cur_pos in range(min_prompt_len, total_len):
        print(
            f"\r* Generating trajectories: {cur_pos-min_prompt_len:>4d}/{total_len-min_prompt_len:>4d}",
            flush=True,
            end="",
        )
        with torch.autocast(device_type=device.type, dtype=dtype):
            logits = model.inference(tokens[:, prev_pos:cur_pos], prev_pos)
        probs = torch.softmax(logits[:, -1], dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_token = next_token.reshape(-1)
        # Where a prompt still occupies cur_pos, keep the prompt token instead of the sample.
        next_token = torch.where(
            input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
        )
        # If a rollout is finished, fill the rest of its tokens with pad_token_id.
        next_token = torch.where(is_finished, pad_token_id, next_token)
        tokens[:, cur_pos] = next_token
        if end_token_id is not None:
            # Only a generated end token (not one inside the prompt) finishes a row.
            is_end_token = next_token == end_token_id
            is_generated_token = ~input_text_mask[:, cur_pos]
            is_finished = is_finished | (is_end_token & is_generated_token)
        # ... (excerpt ends here; the rest of grpo.py:16–114 is not shown)
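The loop above advances all rows in lockstep and uses two masks. `input_text_mask` preserves prompt tokens in rows whose prompt is longer than `min_prompt_len`, and `is_finished` pads a row once it has generated the end token. The sketch below reproduces only that masking logic, without the KV cache. A hypothetical `logits_fn` receives the full prefix `tokens[:, :cur_pos]` at each step, whereas `rollout` passes only `tokens[:, prev_pos:cur_pos]` and `prev_pos` to `model.inference`. The toy model always predicts the end token, which makes the result deterministic. This is an illustration, not the repository's code.

```python
from typing import Callable

import torch


def masked_decode(
    tokens: torch.Tensor,
    pad_token_id: int,
    end_token_id: int,
    min_prompt_len: int,
    logits_fn: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Fill the generation region of `tokens` in place, one column at a time.

    logits_fn is a hypothetical stand-in for the model: it takes the prefix
    tokens[:, :cur_pos] and returns next-token logits of shape (bsz, vocab).
    """
    bsz, total_len = tokens.shape
    input_text_mask = tokens != pad_token_id  # True where a prompt token sits
    is_finished = torch.zeros(bsz, dtype=torch.bool)
    for cur_pos in range(min_prompt_len, total_len):
        probs = torch.softmax(logits_fn(tokens[:, :cur_pos]), dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).reshape(-1)
        # Rows whose prompt still covers cur_pos keep their prompt token.
        next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
        # Rows that already produced the end token are padded from here on.
        next_token = torch.where(is_finished, torch.full_like(next_token, pad_token_id), next_token)
        tokens[:, cur_pos] = next_token
        # Only a generated end token finishes a row.
        is_generated = ~input_text_mask[:, cur_pos]
        is_finished = is_finished | ((next_token == end_token_id) & is_generated)
    return tokens


VOCAB, PAD, EOS = 10, 0, 1


def always_eos(prefix: torch.Tensor) -> torch.Tensor:
    # Toy "model": all probability mass on the end token.
    logits = torch.full((prefix.shape[0], VOCAB), -1e9)
    logits[:, EOS] = 0.0
    return logits


buf = torch.tensor([[5, 6, 7, 0, 0],
                    [5, 6, 7, 0, 0],
                    [8, 9, 0, 0, 0],
                    [8, 9, 0, 0, 0]])
out = masked_decode(buf, PAD, EOS, min_prompt_len=2, logits_fn=always_eos)
print(out.tolist())
# [[5, 6, 7, 1, 0], [5, 6, 7, 1, 0], [8, 9, 1, 0, 0], [8, 9, 1, 0, 0]]
```

At position 2 the first two rows keep the prompt token `7`, while the short-prompt rows emit the end token and are padded afterwards. Using `torch.where` instead of per-row branching keeps the whole batch in one tensor operation per step.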

Callers (2)

evaluate (function) · 0.90
main (function) · 0.90

Calls (6)

Episode (class) · 0.90
reward_function (function) · 0.85
inference (method) · 0.80
detokenize (method) · 0.80
init_kv_cache (method) · 0.45
del_kv_cache (method) · 0.45
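The calls to `detokenize`, `Episode` and `reward_function` happen in the part of grpo.py:16–114 that the listing does not show. Before detokenizing, each buffer row has to be split into its prompt and its generated tokens. The sketch below shows one way to do that with the layout described above. `split_generations` is a hypothetical helper, not the repository's implementation. It also assumes the model never samples `pad_token_id` inside an answer, because any such token would be dropped.

```python
from typing import List, Tuple

import torch


def split_generations(
    tokens: torch.Tensor,
    prompt_lens: List[int],
    pad_token_id: int,
) -> List[Tuple[List[int], List[int]]]:
    """Hypothetical helper: return (prompt_ids, generated_ids) per row,
    dropping the pad tokens that fill unused and post-end-token positions."""
    pairs = []
    for row, plen in zip(tokens.tolist(), prompt_lens):
        prompt_ids = row[:plen]
        generated_ids = [t for t in row[plen:] if t != pad_token_id]
        pairs.append((prompt_ids, generated_ids))
    return pairs


buf = torch.tensor([[5, 6, 7, 1, 0],
                    [8, 9, 4, 1, 0]])
pairs_out = split_generations(buf, prompt_lens=[3, 2], pad_token_id=0)
print(pairs_out)  # [([5, 6, 7], [1]), ([8, 9], [4, 1])]
```

With prompts replicated as in `rollout`, the per-row prompt lengths would be `[len(t) for t in prefix_token_ids for _ in range(num_answer_per_question)]`.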

Tested by

no test coverage detected