(
model: Transformer,
batch: MiniBatch,
tokenizer: Tokenizer,
max_gen_len: int,
num_answer_per_question: int,
reward_function: Callable,
device: torch.device,
dtype: torch.dtype,
)
| 14 | |
| 15 | @torch.no_grad() |
| 16 | def rollout( |
| 17 | model: Transformer, |
| 18 | batch: MiniBatch, |
| 19 | tokenizer: Tokenizer, |
| 20 | max_gen_len: int, |
| 21 | num_answer_per_question: int, |
| 22 | reward_function: Callable, |
| 23 | device: torch.device, |
| 24 | dtype: torch.dtype, |
| 25 | ) -> List[Episode]: |
| 26 | end_token = tokenizer.eos_token |
| 27 | end_token_id = tokenizer.eos_token_id |
| 28 | pad_token_id = tokenizer.pad_token_id |
| 29 | prefix_token_ids = batch.prefix_token_ids |
| 30 | bsz = len(batch.prefix) * num_answer_per_question |
| 31 | min_prompt_len = min(len(t) for t in prefix_token_ids) |
| 32 | max_prompt_len = max(len(t) for t in prefix_token_ids) |
| 33 | total_len = max_gen_len + max_prompt_len |
| 34 | model.init_kv_cache( |
| 35 | max_batch_size=bsz, |
| 36 | max_seq_len=total_len, |
| 37 | device=device, |
| 38 | dtype=dtype, |
| 39 | ) |
| 40 | tokens = torch.full((bsz, total_len), pad_token_id, dtype=torch.long, device=device) |
| 41 | for k, t in enumerate(prefix_token_ids): |
| 42 | offset = k * num_answer_per_question |
| 43 | for i in range(num_answer_per_question): |
| 44 | tokens[offset + i, : len(t)] = torch.tensor( |
| 45 | t, dtype=torch.long, device=device |
| 46 | ) |
| 47 | |
| 48 | prev_pos = 0 |
| 49 | input_text_mask = tokens != pad_token_id |
| 50 | assert min_prompt_len < total_len |
| 51 | is_finished = torch.zeros((bsz,), dtype=torch.bool, device=device) |
| 52 | |
| 53 | for cur_pos in range(min_prompt_len, total_len): |
| 54 | print( |
| 55 | f"\r* Generating trajectories: {cur_pos-min_prompt_len:>4d}/{total_len-min_prompt_len:>4d}", |
| 56 | flush=True, |
| 57 | end="", |
| 58 | ) |
| 59 | with torch.autocast(device_type=device.type, dtype=dtype): |
| 60 | logits = model.inference(tokens[:, prev_pos:cur_pos], prev_pos) |
| 61 | probs = torch.softmax(logits[:, -1], dim=-1) |
| 62 | next_token = torch.multinomial(probs, num_samples=1) |
| 63 | next_token = next_token.reshape(-1) |
| 64 | next_token = torch.where( |
| 65 | input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token |
| 66 | ) |
| 67 | # if a rollout is finished, we fill the rest of the tokens with pad_token_id |
| 68 | next_token = torch.where(is_finished, pad_token_id, next_token) |
| 69 | tokens[:, cur_pos] = next_token |
| 70 | if end_token_id is not None: |
| 71 | is_end_token = next_token == end_token_id |
| 72 | is_generated_token = ~input_text_mask[:, cur_pos] |
| 73 | is_finished = is_finished | (is_end_token & is_generated_token) |
no test coverage detected