MCPcopy
hub / github.com/FoundationVision/LlamaGen / _greedy_sample

Function _greedy_sample

autoregressive/serve/sampler.py:279–295  ·  view source on GitHub ↗
(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    samples: torch.Tensor,
)

Source from the content-addressed store, hash-verified

277
278
def _greedy_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Package greedily chosen token ids into per-group sampling results.

    Args:
        selected_seq_groups: ``(seq_ids, sampling_params)`` pairs, one per
            sequence group routed to greedy sampling. Each group must hold
            exactly one sequence id. The sampling params are not read here.
        samples: Tensor of chosen token ids, consumed in order with one
            entry per group. Presumably a 1-D tensor of argmax indices
            produced by the caller; TODO confirm shape against the caller.

    Returns:
        One ``(next_token_ids, parent_ids)`` tuple per group, in input
        order. Because each group has a single sequence, this is always
        ``([samples[i]], [0])`` for the i-th group.

    Raises:
        ValueError: If any group does not contain exactly one sequence.
        IndexError: If ``samples`` has fewer entries than there are groups.
    """
    # Convert once up front instead of indexing the tensor per group.
    token_ids = samples.tolist()
    results = []
    for sample_idx, (seq_ids, _) in enumerate(selected_seq_groups):
        num_parent_seqs = len(seq_ids)
        # Use an explicit check rather than ``assert``. Asserts are stripped
        # under ``python -O``, which would silently pair one token with
        # several parent ids (or desynchronise the sample index).
        if num_parent_seqs != 1:
            raise ValueError(
                "Greedy sampling should have only one seq, got "
                f"{num_parent_seqs}.")
        # With exactly one sequence per group, the group index is also the
        # sample index, and the only parent is local index 0.
        results.append(([token_ids[sample_idx]], [0]))
    return results
296
297
298def _random_sample(

Callers 2

_sample_with_torch (Function) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected