(
engine: Path | Engine,
worker_queues: WorkerCommIpcAddrs,
log_level: str,
executor_config: Optional[tllm.ExecutorConfig] = None,
batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
worker_cls: type = GenerationExecutorWorker,
tracer_init_kwargs: Optional[dict] = None,
_torch_model_class_mapping: Optional[dict] = None,
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
ready_signal: Optional[str] = None,
is_llm_executor: Optional[
bool] = True, # whether it's the main executor instance
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[BaseLlmArgs] = None,
rpc_addr: Optional[str] = None,
hmac_key: Optional[bytes] = None,
)
| 136 | |
| 137 | @print_traceback_on_error |
| 138 | def worker_main( |
| 139 | engine: Path | Engine, |
| 140 | worker_queues: WorkerCommIpcAddrs, |
| 141 | log_level: str, |
| 142 | executor_config: Optional[tllm.ExecutorConfig] = None, |
| 143 | batched_logits_processor: Optional[BatchedLogitsProcessor] = None, |
| 144 | worker_cls: type = GenerationExecutorWorker, |
| 145 | tracer_init_kwargs: Optional[dict] = None, |
| 146 | _torch_model_class_mapping: Optional[dict] = None, |
| 147 | postproc_worker_config: Optional[PostprocWorkerConfig] = None, |
| 148 | ready_signal: Optional[str] = None, |
| 149 | is_llm_executor: Optional[ |
| 150 | bool] = True, # whether it's the main executor instance |
| 151 | hf_model_dir: Optional[Path] = None, |
| 152 | tokenizer: Optional[TokenizerBase] = None, |
| 153 | llm_args: Optional[BaseLlmArgs] = None, |
| 154 | rpc_addr: Optional[str] = None, |
| 155 | hmac_key: Optional[bytes] = None, |
| 156 | ) -> None: |
| 157 | |
| 158 | def _print_stacks(): |
| 159 | counter = 0 |
| 160 | while True: |
| 161 | time.sleep(print_stacks_period) |
| 162 | counter += 1 |
| 163 | logger.error(f"Printing stacks {counter} times") |
| 164 | print_all_stacks() |
| 165 | |
| 166 | print_stacks_period = int( |
| 167 | os.getenv("TRTLLM_WORKER_PRINT_STACKS_PERIOD", "-1")) |
| 168 | if print_stacks_period > 0: |
| 169 | print_stacks_thread = threading.Thread(target=_print_stacks, |
| 170 | daemon=True) |
| 171 | print_stacks_thread.start() |
| 172 | |
| 173 | mpi_comm().barrier() |
| 174 | |
| 175 | if llm_args is not None and llm_args.env_overrides: |
| 176 | # this is needed because MPI_Init seems to cache the env at import time. |
| 177 | # The cached env snapshot is used to spawn workers. |
| 178 | # Any env overrides to the main process after tensorrt_llm import |
| 179 | # may not get reflected in the spawned worker process, no matter how early, |
| 180 | # unless we update it explicitly here. |
| 181 | os.environ.update(llm_args.env_overrides) |
| 182 | |
| 183 | if llm_args is not None and llm_args.trust_remote_code: |
| 184 | _init_hf_modules() |
| 185 | |
| 186 | logger_debug(f"Worker {mpi_rank()} entering worker_main...\n", "green") |
| 187 | |
| 188 | result_queue: Optional[IpcQueue] = None |
| 189 | result_queues: Optional[List[IpcQueue]] = None |
| 190 | |
| 191 | postproc_worker_config = postproc_worker_config or PostprocWorkerConfig() |
| 192 | |
| 193 | is_leader: bool = mpi_rank() == 0 |
| 194 | if tracer_init_kwargs is not None and is_leader: |
| 195 | tracer = VizTracer(**tracer_init_kwargs) |
nothing calls this directly
no test coverage detected