(
engine: Union[Path, Engine],
rpc_addr: str,
*,
executor_config: Optional[tllm.ExecutorConfig] = None,
batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
is_llm_executor: Optional[bool] = None,
llm_args: Optional[BaseLlmArgs] = None,
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
**kwargs,
)
| 111 | |
| 112 | @staticmethod |
| 113 | def main_task( |
| 114 | engine: Union[Path, Engine], |
| 115 | rpc_addr: str, |
| 116 | *, |
| 117 | executor_config: Optional[tllm.ExecutorConfig] = None, |
| 118 | batched_logits_processor: Optional[BatchedLogitsProcessor] = None, |
| 119 | postproc_worker_config: Optional[PostprocWorkerConfig] = None, |
| 120 | is_llm_executor: Optional[bool] = None, |
| 121 | llm_args: Optional[BaseLlmArgs] = None, |
| 122 | hf_model_dir: Optional[Path] = None, |
| 123 | tokenizer: Optional[TokenizerBase] = None, |
| 124 | **kwargs, |
| 125 | ) -> None: |
| 126 | nvtx.push_range(f"RpcWorker.main_task_{mpi_rank()}", color="pink") |
| 127 | |
| 128 | if enable_llm_debug(): |
| 129 | set_level("debug") |
| 130 | |
| 131 | # Step 1: Create the worker instance |
| 132 | worker = RpcWorker( |
| 133 | engine=engine, |
| 134 | executor_config=executor_config, |
| 135 | is_llm_executor=is_llm_executor, |
| 136 | llm_args=llm_args, |
| 137 | batched_logits_processor=batched_logits_processor, |
| 138 | postproc_worker_config=postproc_worker_config, |
| 139 | hf_model_dir=hf_model_dir, |
| 140 | tokenizer=tokenizer, |
| 141 | ) |
| 142 | |
| 143 | if mpi_rank() != 0: |
| 144 | # The non-leader worker will setup the engine immediately. |
| 145 | # The leader worker will wait for the RPC call to propagate the |
| 146 | # potential error. |
| 147 | logger_debug( |
| 148 | f"[worker] Worker {mpi_rank()} is setting up the engine", |
| 149 | color="yellow") |
| 150 | worker.setup_engine() |
| 151 | |
| 152 | else: |
| 153 | logger_debug( |
| 154 | f"[worker] Worker {mpi_rank()} is creating the RPC service with {worker.num_workers} workers", |
| 155 | color="yellow") |
| 156 | # Step 2: Create the RPC service, it will expose all the APIs of the worker as remote call to the client |
| 157 | # Set num_workers to larger than 1 since some streaming tasks, such as await_responses_async, run indefinitely. |
| 158 | hmac_key = kwargs.get("hmac_key") |
| 159 | rpc_server = RPCServer(worker, |
| 160 | num_workers=worker.num_workers, |
| 161 | hmac_key=hmac_key) |
| 162 | rpc_server.bind(rpc_addr) |
| 163 | rpc_server.start() |
| 164 | logger_debug(f"[worker] RPC server {mpi_rank()} is started", |
| 165 | color="yellow") |
| 166 | |
| 167 | # Step 3: Wait for the worker to shutdown |
| 168 | logger_debug( |
| 169 | f"[worker] Worker {mpi_rank()} is waiting for shutdown event", |
| 170 | color="yellow") |
nothing calls this directly
no test coverage detected