(
self,
model_or_path: Union[nn.Module, str, DiffusionPipeline],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
inference_config: InferenceConfig = None,
verbose: bool = False,
model_policy: Union[Policy, Type[Policy]] = None,
)
| 26 | """ |
| 27 | |
def __init__(
    self,
    model_or_path: Union[nn.Module, str, DiffusionPipeline],
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
    inference_config: InferenceConfig = None,
    verbose: bool = False,
    model_policy: Union[Policy, Type[Policy]] = None,
) -> None:
    """Detect the model type and build the matching inference engine.

    The model type is obtained from ``get_model_type(model_or_path=...)``.
    The concrete engine module is imported lazily inside the matching branch.

    Args:
        model_or_path: Model object, model path, or diffusion pipeline to serve.
        tokenizer: Tokenizer forwarded to ``LLMEngine``. It is not passed to
            ``DiffusionEngine``.
        inference_config: Inference configuration forwarded to the engine.
        verbose: Verbosity flag forwarded to the engine.
        model_policy: Policy instance or class forwarded to the engine.

    If the detected type is neither ``ModelType.LLM`` nor
    ``ModelType.DIFFUSION_MODEL``, an error is logged and ``self.engine``
    stays ``None``. Construction still finishes, and ``self._verify_args()``
    is called afterwards.
    """
    # Write through __dict__ so the flag exists before any custom __setattr__
    # runs. Presumably __setattr__ consults `_initialized`; confirm against
    # the class definition, which is not visible here.
    self.__dict__["_initialized"] = False
    self.model_type = get_model_type(model_or_path=model_or_path)
    self.engine = None

    if self.model_type == ModelType.LLM:
        # Lazy import: only load the LLM engine when it is actually needed.
        from .llm_engine import LLMEngine

        self.engine = LLMEngine(
            model_or_path=model_or_path,
            tokenizer=tokenizer,
            inference_config=inference_config,
            verbose=verbose,
            model_policy=model_policy,
        )
    elif self.model_type == ModelType.DIFFUSION_MODEL:
        # Lazy import: only load the diffusion engine when it is needed.
        from .diffusion_engine import DiffusionEngine

        self.engine = DiffusionEngine(
            model_or_path=model_or_path,
            inference_config=inference_config,
            verbose=verbose,
            model_policy=model_policy,
        )
    else:
        # Covers ModelType.UNKNOWN and any other unhandled type. Previously an
        # unhandled, non-UNKNOWN type left `engine` as None without any report.
        self.logger.error("Model type must be either Diffusion or LLM!")

    self._initialized = True
    self._verify_args()
| 63 | |
| 64 | def _verify_args(self) -> None: |
| 65 | """Verify the input args""" |
nothing calls this directly
no test coverage detected