MCPcopy
hub / github.com/FoundationVision/LlamaGen / _random_sample

Function _random_sample

autoregressive/serve/sampler.py:298–322  ·  view source on GitHub ↗
(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    random_samples: torch.Tensor,
)

Source from the content-addressed store, hash-verified

296
297
298def _random_sample(
299 selected_seq_groups: List[Tuple[List[int], SamplingParams]],
300 is_prompts: List[bool],
301 random_samples: torch.Tensor,
302) -> List[Tuple[List[int], List[int]]]:
303 # Find the maximum best_of value of the prompt phase requests.
304 random_samples = random_samples.cpu()
305 sample_idx = 0
306 results = []
307 for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
308 seq_ids, sampling_params = seq_group
309 num_parent_seqs = len(seq_ids)
310 if is_prompt:
311 # Prompt phase.
312 parent_ids = [0] * sampling_params.best_of
313 next_token_ids = random_samples[
314 sample_idx, :sampling_params.best_of].tolist()
315 else:
316 # Generation phase.
317 parent_ids = list(range(num_parent_seqs))
318 next_token_ids = random_samples[sample_idx:sample_idx +
319 num_parent_seqs, 0].tolist()
320 results.append((next_token_ids, parent_ids))
321 sample_idx += num_parent_seqs
322 return results
323
324
325def _beam_search_sample(

Callers 2

_sample_with_torch (Function) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected