MCPcopy
hub / github.com/FoundationVision/LlamaGen / _sample_with_triton_kernel

Function _sample_with_triton_kernel

autoregressive/serve/sampler.py:520–597  ·  view source on GitHub ↗
(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
)

Source from the content-addressed store, hash-verified

518
519
520def _sample_with_triton_kernel(
521 probs: torch.Tensor,
522 logprobs: torch.Tensor,
523 sampling_metadata: SamplingMetadata,
524 sampling_tensors: SamplingTensors,
525) -> List[Tuple[List[int], List[int]]]:
526 categorized_seq_group_ids = {t: [] for t in SamplingType}
527 categorized_sample_indices = sampling_metadata.categorized_sample_indices
528 for i, seq_group in enumerate(sampling_metadata.seq_groups):
529 _, sampling_params = seq_group
530 sampling_type = sampling_params.sampling_type
531 categorized_seq_group_ids[sampling_type].append(i)
532
533 sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
534 sample_metadata = {}
535 max_best_of_in_batch = 1
536
537 # Counterintuitively, having two loops here is actually faster.
538 # The first loop can run without waiting on GPU<->CPU sync.
539 for sampling_type in SamplingType:
540 sample_indices = categorized_sample_indices[sampling_type][:, 0]
541 sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
542 num_tokens = len(sample_indices)
543 if num_tokens == 0:
544 continue
545 seq_group_ids = categorized_seq_group_ids[sampling_type]
546 seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
547 is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
548 sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
549 is_prompts, sample_indices,
550 sampled_token_indices)
551 if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
552 SamplingType.RANDOM_SEED):
553 for seq_group, is_prompt in zip(seq_groups, is_prompts):
554 if is_prompt:
555 _, sampling_params = seq_group
556 max_best_of_in_batch = max(max_best_of_in_batch,
557 sampling_params.best_of)
558 elif sampling_type == SamplingType.BEAM:
559 beam_search_logprobs = logprobs[sample_indices]
560 else:
561 raise ValueError(f"Unsupported sampling type: {sampling_type}")
562
563 sampled_tokens, _, _ = sample_triton(
564 probs=probs,
565 seeds=sampling_tensors.sampling_seeds,
566 max_best_of=max_best_of_in_batch,
567 sample_indices=sampling_tensors.sample_indices,
568 logprobs=logprobs,
569 # don't save logprobs because we have logic for that below
570 # TODO: use this instead of the CPU-based logic below
571 save_logprobs=False,
572 )
573
574 # GPU<->CPU sync happens in the loop below.
575
576 for sampling_type in SamplingType:
577 if sampling_type not in sample_metadata:

Callers

nothing calls this directly

Calls 4

_greedy_sampleFunction · 0.85
_random_sampleFunction · 0.85
_beam_search_sampleFunction · 0.85
updateMethod · 0.80

Tested by

no test coverage detected