(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
subquery_lens: Optional[List[int]],
)
| 572 | ) |
| 573 | |
| 574 | def _prepare_sample( |
| 575 | self, |
| 576 | seq_group_metadata_list: List[SequenceGroupMetadata], |
| 577 | prompt_lens: List[int], |
| 578 | subquery_lens: Optional[List[int]], |
| 579 | ) -> SamplingMetadata: |
| 580 | seq_groups: List[Tuple[List[int], SamplingParams]] = [] |
| 581 | selected_token_indices: List[int] = [] |
| 582 | generators: List[torch.Generator] = [] |
| 583 | selected_token_start_idx = 0 |
| 584 | categorized_sample_indices: Dict[SamplingType, |
| 585 | List[Tuple[int, int]]] = { |
| 586 | t: [] |
| 587 | for t in SamplingType |
| 588 | } |
| 589 | categorized_sample_indices_start_idx = 0 |
| 590 | categorized_sampled_token_indices_start_idx = 0 |
| 591 | |
| 592 | for i, seq_group_metadata in enumerate(seq_group_metadata_list): |
| 593 | seq_ids = list(seq_group_metadata.seq_data.keys()) |
| 594 | sampling_params = seq_group_metadata.sampling_params |
| 595 | seq_groups.append((seq_ids, sampling_params)) |
| 596 | |
| 597 | if seq_group_metadata.is_prompt: |
| 598 | assert len(seq_ids) == 1 |
| 599 | assert subquery_lens is not None |
| 600 | subquery_len = subquery_lens[i] |
| 601 | if sampling_params.prompt_logprobs is not None: |
| 602 | # NOTE: prompt token positions do not need sample, skip |
| 603 | categorized_sample_indices_start_idx += subquery_len - 1 |
| 604 | |
| 605 | categorized_sample_indices[ |
| 606 | sampling_params.sampling_type].append( |
| 607 | (categorized_sample_indices_start_idx, |
| 608 | categorized_sampled_token_indices_start_idx)) |
| 609 | categorized_sample_indices_start_idx += 1 |
| 610 | categorized_sampled_token_indices_start_idx += 1 |
| 611 | |
| 612 | if sampling_params.prompt_logprobs is not None: |
| 613 | selected_token_indices.extend( |
| 614 | range(selected_token_start_idx, |
| 615 | selected_token_start_idx + subquery_len - 1)) |
| 616 | selected_token_indices.append(selected_token_start_idx + |
| 617 | subquery_len - 1) |
| 618 | selected_token_start_idx += subquery_len |
| 619 | |
| 620 | if sampling_params.seed is not None: |
| 621 | seq_group_metadata.state.generator = torch.Generator( |
| 622 | device=self.device).manual_seed(sampling_params.seed) |
| 623 | else: |
| 624 | num_seqs = len(seq_ids) |
| 625 | selected_token_indices.extend( |
| 626 | range(selected_token_start_idx, |
| 627 | selected_token_start_idx + num_seqs)) |
| 628 | selected_token_start_idx += num_seqs |
| 629 | |
| 630 | categorized_sample_indices[ |
| 631 | sampling_params.sampling_type].extend( |
no test coverage detected