| 25 | |
| 26 | |
| 27 | class DiffusionEngine(BaseEngine): |
def __init__(
    self,
    model_or_path: Union[DiffusionPipeline, str],
    inference_config: InferenceConfig = None,
    verbose: bool = False,
    model_policy: Union[Policy, Type[Policy]] = None,
) -> None:
    """Build a diffusion inference engine around a ``DiffusionPipeline``.

    Args:
        model_or_path: a ``DiffusionPipeline`` instance, or a ``str`` that is handed to
            ``init_model`` (which loads it with ``DiffusionPipeline.from_pretrained``).
        inference_config (InferenceConfig): required despite the ``None`` default. It
            supplies ``dtype``, ``high_precision`` and the model-shard inference config.
        verbose (bool): stored as ``self.verbose``.
        model_policy: optional policy instance or class, forwarded to ``init_model``.

    Raises:
        ValueError: if ``inference_config`` is ``None``.
    """
    # The ``None`` default is kept for signature compatibility. Every line below
    # dereferences the config, so fail with a clear message rather than an
    # AttributeError on ``None.dtype``.
    if inference_config is None:
        raise ValueError("DiffusionEngine requires an InferenceConfig, but got None")

    self.inference_config = inference_config
    self.dtype = inference_config.dtype
    self.high_precision = inference_config.high_precision

    self.verbose = verbose
    self.logger = get_dist_logger(__name__)
    self.model_shard_infer_config = inference_config.to_model_shard_inference_config()

    # Computed from the raw argument before init_model wraps it.
    self.model_type = get_model_type(model_or_path=model_or_path)

    self.init_model(model_or_path, model_policy, self.model_shard_infer_config)

    self.request_handler = NaiveRequestHandler()

    # NOTE(review): presumably itertools.count, used to number requests; confirm.
    self.counter = count()

    # Sanity-check the state that init_model produced.
    self._verify_args()
| 52 | |
def _verify_args(self) -> None:
    """Check that ``self.model`` is a ``DiffusionPipe``.

    Raises:
        TypeError: if ``self.model`` is not a ``DiffusionPipe`` instance.
    """
    # Use an explicit raise instead of ``assert`` so the check still runs under
    # ``python -O``, which strips assert statements.
    if not isinstance(self.model, DiffusionPipe):
        raise TypeError(f"model must be DiffusionPipe, got {type(self.model).__name__}")
| 55 | |
| 56 | def init_model( |
| 57 | self, |
| 58 | model_or_path: Union[str, nn.Module, DiffusionPipeline], |
| 59 | model_policy: Union[Policy, Type[Policy]] = None, |
| 60 | model_shard_infer_config: ModelShardInferenceConfig = None, |
| 61 | ): |
| 62 | """ |
| 63 | Shard model or/and Load weight |
| 64 | |
| 65 | Args: |
| 66 | model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. |
| 67 | model_policy (Policy): the policy to replace the model. |
| 68 | model_inference_config: the configuration for modeling initialization when inference. |
| 69 | model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference. |
| 70 | """ |
| 71 | if isinstance(model_or_path, str): |
| 72 | model = DiffusionPipeline.from_pretrained(model_or_path, torch_dtype=self.dtype) |
| 73 | policy_map_key = model.__class__.__name__ |
| 74 | model = DiffusionPipe(model) |
| 75 | elif isinstance(model_or_path, DiffusionPipeline): |
| 76 | policy_map_key = model_or_path.__class__.__name__ |
| 77 | model = DiffusionPipe(model_or_path) |
| 78 | else: |
| 79 | self.logger.error(f"model_or_path support only str or DiffusionPipeline currently!") |
| 80 | |
| 81 | torch.cuda.empty_cache() |
| 82 | init_gpu_memory = torch.cuda.mem_get_info()[0] |
| 83 | |
| 84 | self.device = get_accelerator().get_current_device() |
no outgoing calls
no test coverage detected
searching dependent graphs…