Function _sample_with_torch

Repository: github.com/FoundationVision/LlamaGen
Source: autoregressive/serve/sampler.py, lines 410–517

```python
def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        _, sampling_params = seq_group
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    multinomial_samples = {}

    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
                                               1,
                                               dtype=torch.long,
                                               device=logprobs.device)
    else:
        sampled_token_ids_tensor = None

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue
        seq_group_ids = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                          is_prompts, sample_indices)
        long_sample_indices = sample_indices.long()

        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)

            if include_gpu_probs_tensor:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)
    # ... (excerpt ends at line 466; the function continues to line 517)
```
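The loop over `sampling_metadata.seq_groups` at the top of the function buckets sequence-group positions by sampling type, so each type can later be handled with one batched call. Below is a minimal standalone sketch of that bucketing. It assumes a simplified `SamplingType` enum and treats each sequence group as a `(seq_ids, params)` tuple, as the listing's `_, sampling_params = seq_group` unpacking suggests; the enum members and the `Params` class are hypothetical stand-ins, not LlamaGen's definitions.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Dict, List, Tuple


class SamplingType(Enum):
    # Hypothetical members; the real enum may define more (e.g. seeded random).
    GREEDY = auto()
    RANDOM = auto()
    BEAM = auto()


@dataclass
class Params:
    # Hypothetical stand-in for the real sampling-params object.
    sampling_type: SamplingType


def categorize(seq_groups: List[Tuple[List[int], Params]]) -> Dict[SamplingType, List[int]]:
    # One (possibly empty) bucket per sampling type.
    buckets: Dict[SamplingType, List[int]] = {t: [] for t in SamplingType}
    for i, (_, params) in enumerate(seq_groups):
        # Record the group's position in the batch, not the group itself.
        buckets[params.sampling_type].append(i)
    return buckets


groups = [
    ([0], Params(SamplingType.GREEDY)),
    ([1, 2], Params(SamplingType.RANDOM)),
    ([3], Params(SamplingType.GREEDY)),
]
print(categorize(groups))
```

Here groups 0 and 2 land in the GREEDY bucket, group 1 in RANDOM, and BEAM stays empty; an empty type is the case the main loop skips via its `num_tokens == 0` check.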

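The greedy branch reduces to two tensor operations: a row-wise `argmax` over the selected rows of `logprobs`, and a scatter of the results into a preallocated `(num_rows, 1)` `long` tensor. The sketch below reproduces those steps on CPU with toy shapes; the function name `greedy_first_pass` and the inputs are hypothetical. It also illustrates the idea behind the comment about two loops, under the assumption that it refers to deferring host transfers: the first pass only enqueues tensor ops, and `.tolist()`, which on a CUDA tensor copies to the host and therefore waits for the GPU, is left to a later pass.

```python
import torch


def greedy_first_pass(logprobs: torch.Tensor, sample_indices: torch.Tensor):
    # Preallocate one slot per row; rows not sampled greedily stay
    # uninitialized, as with torch.empty in the listing.
    out = torch.empty(logprobs.shape[0], 1, dtype=torch.long,
                      device=logprobs.device)
    idx = sample_indices.long()
    # Row-wise argmax over the vocabulary dimension; no host transfer here.
    greedy = torch.argmax(logprobs[idx], dim=-1)
    # Scatter into the output tensor; the (k,) result becomes (k, 1).
    out[idx] = greedy.unsqueeze(-1)
    return out, greedy


# Toy batch: 3 rows, vocabulary of 4 tokens.
logprobs = torch.log_softmax(torch.tensor([
    [0.1, 2.0, 0.3, 0.0],
    [1.5, 0.2, 0.1, 0.0],
    [0.0, 0.1, 0.2, 3.0],
]), dim=-1)
rows = torch.tensor([0, 2])  # only rows 0 and 2 use greedy sampling

out, greedy = greedy_first_pass(logprobs, rows)
# Later pass: convert to Python ints only after all tensor work is queued.
print(greedy.tolist())
```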
Callers: _sample (function)

Calls: _multinomial, _greedy_sample, _random_sample, _beam_search_sample (functions); update, empty (methods)

Tested by: no test coverage detected.
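The body of `_modify_greedy_probs_inplace` is not part of this listing, but the comment in the `modify_greedy_probs` branch states its contract: sampling from the modified distribution must always return the argmax token. A minimal sketch of one way to meet that contract (zero each selected row of `probs`, then put all mass on its greedy token), followed by a property check of the kind a test could make. `one_hot_greedy_inplace` is a hypothetical stand-in; unlike the real helper, it does not receive or modify `logprobs`.

```python
import torch


def one_hot_greedy_inplace(probs: torch.Tensor, rows: torch.Tensor,
                           greedy: torch.Tensor) -> None:
    # Clear the whole distribution for each greedily sampled row...
    probs[rows, :] = 0.0
    # ...then place all probability mass on that row's argmax token.
    probs[rows, greedy] = 1.0


torch.manual_seed(0)
probs = torch.softmax(torch.randn(4, 6), dim=-1)
rows = torch.tensor([1, 3])
greedy = torch.argmax(probs[rows], dim=-1)
one_hot_greedy_inplace(probs, rows, greedy)

# Property check: repeated multinomial draws from modified rows always hit the argmax.
draws = torch.multinomial(probs[rows], num_samples=50, replacement=True)
print(bool((draws == greedy.unsqueeze(-1)).all()))
```

In this sketch the update is in place, so `probs` keeps its shape and device, and any later code that samples from it needs no special case for greedy rows.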