(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
)
| 518 | |
| 519 | |
| 520 | def _sample_with_triton_kernel( |
| 521 | probs: torch.Tensor, |
| 522 | logprobs: torch.Tensor, |
| 523 | sampling_metadata: SamplingMetadata, |
| 524 | sampling_tensors: SamplingTensors, |
| 525 | ) -> List[Tuple[List[int], List[int]]]: |
| 526 | categorized_seq_group_ids = {t: [] for t in SamplingType} |
| 527 | categorized_sample_indices = sampling_metadata.categorized_sample_indices |
| 528 | for i, seq_group in enumerate(sampling_metadata.seq_groups): |
| 529 | _, sampling_params = seq_group |
| 530 | sampling_type = sampling_params.sampling_type |
| 531 | categorized_seq_group_ids[sampling_type].append(i) |
| 532 | |
| 533 | sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} |
| 534 | sample_metadata = {} |
| 535 | max_best_of_in_batch = 1 |
| 536 | |
| 537 | # Counterintuitively, having two loops here is actually faster. |
| 538 | # The first loop can run without waiting on GPU<->CPU sync. |
| 539 | for sampling_type in SamplingType: |
| 540 | sample_indices = categorized_sample_indices[sampling_type][:, 0] |
| 541 | sampled_token_indices = categorized_sample_indices[sampling_type][:, 1] |
| 542 | num_tokens = len(sample_indices) |
| 543 | if num_tokens == 0: |
| 544 | continue |
| 545 | seq_group_ids = categorized_seq_group_ids[sampling_type] |
| 546 | seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] |
| 547 | is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] |
| 548 | sample_metadata[sampling_type] = (seq_group_ids, seq_groups, |
| 549 | is_prompts, sample_indices, |
| 550 | sampled_token_indices) |
| 551 | if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM, |
| 552 | SamplingType.RANDOM_SEED): |
| 553 | for seq_group, is_prompt in zip(seq_groups, is_prompts): |
| 554 | if is_prompt: |
| 555 | _, sampling_params = seq_group |
| 556 | max_best_of_in_batch = max(max_best_of_in_batch, |
| 557 | sampling_params.best_of) |
| 558 | elif sampling_type == SamplingType.BEAM: |
| 559 | beam_search_logprobs = logprobs[sample_indices] |
| 560 | else: |
| 561 | raise ValueError(f"Unsupported sampling type: {sampling_type}") |
| 562 | |
| 563 | sampled_tokens, _, _ = sample_triton( |
| 564 | probs=probs, |
| 565 | seeds=sampling_tensors.sampling_seeds, |
| 566 | max_best_of=max_best_of_in_batch, |
| 567 | sample_indices=sampling_tensors.sample_indices, |
| 568 | logprobs=logprobs, |
| 569 | # don't save logprobs because we have logic for that below |
| 570 | # TODO: use this instead of the CPU-based logic below |
| 571 | save_logprobs=False, |
| 572 | ) |
| 573 | |
| 574 | # GPU<->CPU sync happens in the loop below. |
| 575 | |
| 576 | for sampling_type in SamplingType: |
| 577 | if sampling_type not in sample_metadata: |
nothing calls this directly
no test coverage detected