(
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
is_prompts: List[bool],
random_samples: torch.Tensor,
)
| 296 | |
| 297 | |
def _random_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    random_samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Read each sequence group's next tokens out of pre-drawn samples.

    `random_samples` is indexed as a 2-D tensor. Groups are walked in
    order and a running row offset is advanced by `len(seq_ids)` after
    every group.

    Args:
        selected_seq_groups: `(seq_ids, sampling_params)` pairs, one per
            sequence group.
        is_prompts: Per-group flag, paired positionally with
            `selected_seq_groups`; True selects the prompt-phase branch.
        random_samples: Tensor holding the already-sampled token ids.

    Returns:
        One `(next_token_ids, parent_ids)` tuple per group, in input order.
    """
    # Bring the samples to host memory once; every group reads this copy.
    host_samples = random_samples.cpu()
    outputs = []
    row = 0
    for (seq_ids, sampling_params), in_prompt_phase in zip(
            selected_seq_groups, is_prompts):
        group_size = len(seq_ids)
        if not in_prompt_phase:
            # Generation phase: one token per sequence, taken from column 0
            # of that sequence's own row; sequence i is its own parent.
            token_ids = host_samples[row:row + group_size, 0].tolist()
            parents = list(range(group_size))
        else:
            # Prompt phase: the first `best_of` entries of a single row,
            # all attributed to parent 0.
            # NOTE(review): the offset still advances by len(seq_ids), so
            # this presumably expects one sequence per prompt group -
            # confirm against the caller.
            best_of = sampling_params.best_of
            token_ids = host_samples[row, :best_of].tolist()
            parents = [0] * best_of
        outputs.append((token_ids, parents))
        row += group_size
    return outputs
| 323 | |
| 324 | |
| 325 | def _beam_search_sample( |
no outgoing calls
no test coverage detected