(
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
samples: torch.Tensor,
)
| 277 | |
| 278 | |
| 279 | def _greedy_sample( |
| 280 | selected_seq_groups: List[Tuple[List[int], SamplingParams]], |
| 281 | samples: torch.Tensor, |
| 282 | ) -> List[Tuple[List[int], List[int]]]: |
| 283 | samples = samples.tolist() |
| 284 | sample_idx = 0 |
| 285 | results = [] |
| 286 | for seq_group in selected_seq_groups: |
| 287 | seq_ids, _ = seq_group |
| 288 | num_parent_seqs = len(seq_ids) |
| 289 | assert num_parent_seqs == 1, ( |
| 290 | "Greedy sampling should have only one seq.") |
| 291 | parent_ids = list(range(num_parent_seqs)) |
| 292 | next_token_ids = [samples[sample_idx]] |
| 293 | results.append((next_token_ids, parent_ids)) |
| 294 | sample_idx += num_parent_seqs |
| 295 | return results |
| 296 | |
| 297 | |
| 298 | def _random_sample( |
no outgoing calls
no test coverage detected