MCPcopy
hub / github.com/FoundationVision/LlamaGen / _prepare_sample

Method _prepare_sample

autoregressive/serve/model_runner.py:574–674  ·  view source on GitHub ↗
(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
        subquery_lens: Optional[List[int]],
    )

Source from the content-addressed store, hash-verified

572 )
573
574 def _prepare_sample(
575 self,
576 seq_group_metadata_list: List[SequenceGroupMetadata],
577 prompt_lens: List[int],
578 subquery_lens: Optional[List[int]],
579 ) -> SamplingMetadata:
580 seq_groups: List[Tuple[List[int], SamplingParams]] = []
581 selected_token_indices: List[int] = []
582 generators: List[torch.Generator] = []
583 selected_token_start_idx = 0
584 categorized_sample_indices: Dict[SamplingType,
585 List[Tuple[int, int]]] = {
586 t: []
587 for t in SamplingType
588 }
589 categorized_sample_indices_start_idx = 0
590 categorized_sampled_token_indices_start_idx = 0
591
592 for i, seq_group_metadata in enumerate(seq_group_metadata_list):
593 seq_ids = list(seq_group_metadata.seq_data.keys())
594 sampling_params = seq_group_metadata.sampling_params
595 seq_groups.append((seq_ids, sampling_params))
596
597 if seq_group_metadata.is_prompt:
598 assert len(seq_ids) == 1
599 assert subquery_lens is not None
600 subquery_len = subquery_lens[i]
601 if sampling_params.prompt_logprobs is not None:
602 # NOTE: prompt token positions do not need sample, skip
603 categorized_sample_indices_start_idx += subquery_len - 1
604
605 categorized_sample_indices[
606 sampling_params.sampling_type].append(
607 (categorized_sample_indices_start_idx,
608 categorized_sampled_token_indices_start_idx))
609 categorized_sample_indices_start_idx += 1
610 categorized_sampled_token_indices_start_idx += 1
611
612 if sampling_params.prompt_logprobs is not None:
613 selected_token_indices.extend(
614 range(selected_token_start_idx,
615 selected_token_start_idx + subquery_len - 1))
616 selected_token_indices.append(selected_token_start_idx +
617 subquery_len - 1)
618 selected_token_start_idx += subquery_len
619
620 if sampling_params.seed is not None:
621 seq_group_metadata.state.generator = torch.Generator(
622 device=self.device).manual_seed(sampling_params.seed)
623 else:
624 num_seqs = len(seq_ids)
625 selected_token_indices.extend(
626 range(selected_token_start_idx,
627 selected_token_start_idx + num_seqs))
628 selected_token_start_idx += num_seqs
629
630 categorized_sample_indices[
631 sampling_params.sampling_type].extend(

Callers 1

prepare_input_tensors (Method) · 0.95

Calls 1

update (Method) · 0.80

Tested by

no test coverage detected