(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
)
| 408 | |
| 409 | |
| 410 | def _sample_with_torch( |
| 411 | probs: torch.Tensor, |
| 412 | logprobs: torch.Tensor, |
| 413 | sampling_metadata: SamplingMetadata, |
| 414 | include_gpu_probs_tensor: bool, |
| 415 | modify_greedy_probs: bool, |
| 416 | ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]: |
| 417 | categorized_seq_group_ids = {t: [] for t in SamplingType} |
| 418 | categorized_sample_indices = sampling_metadata.categorized_sample_indices |
| 419 | for i, seq_group in enumerate(sampling_metadata.seq_groups): |
| 420 | _, sampling_params = seq_group |
| 421 | sampling_type = sampling_params.sampling_type |
| 422 | categorized_seq_group_ids[sampling_type].append(i) |
| 423 | |
| 424 | sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} |
| 425 | sample_metadata = {} |
| 426 | multinomial_samples = {} |
| 427 | |
| 428 | # Create output tensor for sampled token ids. |
| 429 | if include_gpu_probs_tensor: |
| 430 | sampled_token_ids_tensor = torch.empty(logprobs.shape[0], |
| 431 | 1, |
| 432 | dtype=torch.long, |
| 433 | device=logprobs.device) |
| 434 | else: |
| 435 | sampled_token_ids_tensor = None |
| 436 | |
| 437 | # Counterintiutively, having two loops here is actually faster. |
| 438 | # The first loop can run without waiting on GPU<->CPU sync. |
| 439 | for sampling_type in SamplingType: |
| 440 | sample_indices = categorized_sample_indices[sampling_type][:, 0] |
| 441 | num_tokens = len(sample_indices) |
| 442 | if num_tokens == 0: |
| 443 | continue |
| 444 | seq_group_ids = categorized_seq_group_ids[sampling_type] |
| 445 | seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] |
| 446 | is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] |
| 447 | sample_metadata[sampling_type] = (seq_group_ids, seq_groups, |
| 448 | is_prompts, sample_indices) |
| 449 | long_sample_indices = sample_indices.long() |
| 450 | |
| 451 | if sampling_type == SamplingType.GREEDY: |
| 452 | greedy_samples = torch.argmax(logprobs[long_sample_indices], |
| 453 | dim=-1) |
| 454 | |
| 455 | if include_gpu_probs_tensor: |
| 456 | # Store sampled tokens in output tensor. |
| 457 | sampled_token_ids_tensor[ |
| 458 | long_sample_indices] = greedy_samples.unsqueeze(-1) |
| 459 | |
| 460 | if modify_greedy_probs: |
| 461 | # If required, modify the probabilities such that sampling from |
| 462 | # the modified distribution would always sample the argmax |
| 463 | # token id. |
| 464 | _modify_greedy_probs_inplace(logprobs, probs, |
| 465 | long_sample_indices, |
| 466 | greedy_samples) |
| 467 |
no test coverage detected