InferenceEngine, which manages the inference process. Args: model_or_path (nn.Module or DiffusionPipeline or str): the model itself (an nn.Module or DiffusionPipeline) or a path to it. tokenizer (Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]): the tokenizer to use.
| 14 | |
| 15 | |
| 16 | class InferenceEngine: |
| 17 | """ |
| 18 | InferenceEngine which manages the inference process.. |
| 19 | |
| 20 | Args: |
| 21 | model_or_path (nn.Module or DiffusionPipeline or str): Path or nn.Module or DiffusionPipeline of this model. |
| 22 | tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use. |
| 23 | inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference. |
| 24 | verbose (bool): Determine whether or not to log the generation process. |
| 25 | model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided. |
| 26 | """ |
| 27 | |
| 28 | def __init__( |
| 29 | self, |
| 30 | model_or_path: Union[nn.Module, str, DiffusionPipeline], |
| 31 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None, |
| 32 | inference_config: InferenceConfig = None, |
| 33 | verbose: bool = False, |
| 34 | model_policy: Union[Policy, Type[Policy]] = None, |
| 35 | ) -> None: |
| 36 | self.__dict__["_initialized"] = False # use __dict__ directly to avoid calling __setattr__ |
| 37 | self.model_type = get_model_type(model_or_path=model_or_path) |
| 38 | self.engine = None |
| 39 | if self.model_type == ModelType.LLM: |
| 40 | from .llm_engine import LLMEngine |
| 41 | |
| 42 | self.engine = LLMEngine( |
| 43 | model_or_path=model_or_path, |
| 44 | tokenizer=tokenizer, |
| 45 | inference_config=inference_config, |
| 46 | verbose=verbose, |
| 47 | model_policy=model_policy, |
| 48 | ) |
| 49 | elif self.model_type == ModelType.DIFFUSION_MODEL: |
| 50 | from .diffusion_engine import DiffusionEngine |
| 51 | |
| 52 | self.engine = DiffusionEngine( |
| 53 | model_or_path=model_or_path, |
| 54 | inference_config=inference_config, |
| 55 | verbose=verbose, |
| 56 | model_policy=model_policy, |
| 57 | ) |
| 58 | elif self.model_type == ModelType.UNKNOWN: |
| 59 | self.logger.error(f"Model Type either Difffusion or LLM!") |
| 60 | |
| 61 | self._initialized = True |
| 62 | self._verify_args() |
| 63 | |
| 64 | def _verify_args(self) -> None: |
| 65 | """Verify the input args""" |
| 66 | assert self.engine is not None, "Please init Engine first" |
| 67 | assert self._initialized, "Engine must be initialized" |
| 68 | |
| 69 | def generate( |
| 70 | self, |
| 71 | request_ids: Union[List[int], int] = None, |
| 72 | prompts: Union[List[str], str] = None, |
| 73 | *args, |
no outgoing calls
searching dependent graphs…