(
probs: torch.Tensor,
num_samples: int,
seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
generators: Optional[List[torch.Generator]] = None,
)
| 381 | # probs will be modified in place, but this is fine, as we pass |
| 382 | # in a copy already. |
| 383 | def _multinomial( |
| 384 | probs: torch.Tensor, |
| 385 | num_samples: int, |
| 386 | seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None, |
| 387 | generators: Optional[List[torch.Generator]] = None, |
| 388 | ) -> torch.Tensor: |
| 389 | if num_samples > 1: |
| 390 | # This is equivalent to torch.repeat_interleaved (which also |
| 391 | # forces a GPU<->CPU sync). |
| 392 | # This allows us to do sampling with replacement by creating |
| 393 | # num_samples copies of each row in the tensor, and then |
| 394 | # batch sampling the resulting tensor. |
| 395 | probs = probs[:, None, :].expand(probs.shape[0], num_samples, |
| 396 | probs.shape[1]).contiguous().view( |
| 397 | -1, probs.shape[1]) |
| 398 | q = torch.empty_like(probs) |
| 399 | if seq_groups is None: |
| 400 | q.exponential_() |
| 401 | else: |
| 402 | sample_idx = 0 |
| 403 | for (seq_ids, _), generator in zip(seq_groups, generators): |
| 404 | next_sample_idx = sample_idx + len(seq_ids) * num_samples |
| 405 | q[sample_idx:next_sample_idx].exponential_(generator=generator) |
| 406 | sample_idx = next_sample_idx |
| 407 | return probs.div_(q).argmax(dim=1).view(-1, num_samples) |
| 408 | |
| 409 | |
| 410 | def _sample_with_torch( |
no outgoing calls
no test coverage detected