(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None)
| 184 | await asyncio.gather(*init_tasks) |
| 185 | |
async def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
    """Ask every worker to initialise the model, awaiting all calls together.

    The engine's inference config is converted once via ``to_rpc_param()``.
    So is the optional model policy. Then ``worker.init_model`` is scheduled
    for each worker through ``self.async_parallel_wrapper``, and all of those
    calls are awaited with ``asyncio.gather``.

    Args:
        model_or_path: The model object or a path to it. It is forwarded
            unchanged to every worker. NOTE(review): an ``nn.Module`` is sent
            as-is over RPC; confirm the worker side can receive a module
            object rather than only a path.
        model_policy: Optional policy. When given, its ``to_rpc_param()``
            form is sent. Otherwise ``None`` is sent.

    Raises:
        AssertionError: If ``len(self.workers) != self.tp_size``, meaning the
            workers have not been initialised yet.
    """
    # Explicit raise instead of ``assert`` so the check is not stripped under
    # ``python -O``. AssertionError is kept for callers that already catch it.
    if len(self.workers) != self.tp_size:
        raise AssertionError("init workers first")

    # Convert parameters to their RPC form once, not once per worker.
    inference_config_param = self.inference_config.to_rpc_param()
    model_policy_param = model_policy.to_rpc_param() if model_policy is not None else None

    init_tasks = [
        self.async_parallel_wrapper(worker.init_model, inference_config_param, model_or_path, model_policy_param)
        for worker in self.workers
    ]

    # Await every worker's initialisation together; gather propagates the
    # first exception raised by any of them.
    await asyncio.gather(*init_tasks)
| 199 | |
def init_scheduler(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> RPCRequestHandler:
    """Create the request handler that this engine uses as its scheduler.

    Args:
        inference_config: Inference configuration, passed through unchanged.
        model_config: Model configuration, passed through unchanged.

    Returns:
        A freshly constructed ``RPCRequestHandler`` built from both configs.
        ``self`` is not consulted.
    """
    request_handler = RPCRequestHandler(inference_config, model_config)
    return request_handler
no test coverage detected