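Add requests. Args: request_ids (Union[List[int], int], optional): The request ID(s); a single int is wrapped in a list. Defaults to None. prompts (Union[List[str], str], optional): Input prompts; a single string is wrapped in a list. Defaults to None. prompts_token_ids (Union[List[List[int]], List[torch.Tensor], torch.Tensor, np.ndarray], optional): Token IDs of the input prompts; when None, `prompts` must be provided and is tokenized. Defaults to None.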
(
self,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
**kwargs,
)
| 578 | raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.") |
| 579 | |
| 580 | def add_request( |
| 581 | self, |
| 582 | request_ids: Union[List[int], int] = None, |
| 583 | prompts: Union[List[str], str] = None, |
| 584 | prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, |
| 585 | **kwargs, |
| 586 | ) -> None: |
| 587 | """ |
| 588 | Add requests. |
| 589 | |
| 590 | Args: |
| 591 | request_ids (List[int], optional): The request ID. Defaults to None. |
| 592 | prompts (Union[List[str], optional): Input prompts. Defaults to None. |
| 593 | prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None. |
| 594 | """ |
| 595 | |
| 596 | # apply the prompt template to the input prompts |
| 597 | |
| 598 | if self.has_prompt_template and prompts is not None: |
| 599 | prompts = self.format_prompt(prompts) |
| 600 | |
| 601 | block_size = self.inference_config.block_size |
| 602 | |
| 603 | if request_ids is not None and not isinstance(request_ids, list): |
| 604 | request_ids = [request_ids] |
| 605 | |
| 606 | if prompts is not None and not isinstance(prompts, list): |
| 607 | prompts = [prompts] |
| 608 | |
| 609 | if prompts_token_ids is None: |
| 610 | assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided." |
| 611 | prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[ |
| 612 | "input_ids" |
| 613 | ] |
| 614 | |
| 615 | # list of torch Tensor |
| 616 | if isinstance(prompts_token_ids, list): |
| 617 | if isinstance(prompts_token_ids[0], torch.Tensor): |
| 618 | prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids] |
| 619 | elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray): |
| 620 | prompts_token_ids = prompts_token_ids.tolist() |
| 621 | else: |
| 622 | raise TypeError( |
| 623 | f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}." |
| 624 | ) |
| 625 | |
| 626 | assert ( |
| 627 | len(prompts_token_ids[0]) <= self.inference_config.max_input_len |
| 628 | ), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}." |
| 629 | |
| 630 | prompts_num = len(prompts_token_ids) |
| 631 | |
| 632 | for i in range(prompts_num): |
| 633 | if request_ids: |
| 634 | assert isinstance( |
| 635 | request_ids[0], int |
| 636 | ), f"The request_id type must be int, but got {type(request_ids[0])}" |
| 637 | assert len(request_ids) == prompts_num |
no test coverage detected