MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / main_task

Method main_task

tensorrt_llm/executor/rpc_worker.py:113–172  ·  view source on GitHub ↗
(
        engine: Union[Path, Engine],
        rpc_addr: str,
        *,
        executor_config: Optional[tllm.ExecutorConfig] = None,
        batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
        postproc_worker_config: Optional[PostprocWorkerConfig] = None,
        is_llm_executor: Optional[bool] = None,
        llm_args: Optional[BaseLlmArgs] = None,
        hf_model_dir: Optional[Path] = None,
        tokenizer: Optional[TokenizerBase] = None,
        **kwargs,
    )

Source from the content-addressed store, hash-verified

111
112 @staticmethod
113 def main_task(
114 engine: Union[Path, Engine],
115 rpc_addr: str,
116 *,
117 executor_config: Optional[tllm.ExecutorConfig] = None,
118 batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
119 postproc_worker_config: Optional[PostprocWorkerConfig] = None,
120 is_llm_executor: Optional[bool] = None,
121 llm_args: Optional[BaseLlmArgs] = None,
122 hf_model_dir: Optional[Path] = None,
123 tokenizer: Optional[TokenizerBase] = None,
124 **kwargs,
125 ) -> None:
126 nvtx.push_range(f"RpcWorker.main_task_{mpi_rank()}", color="pink")
127
128 if enable_llm_debug():
129 set_level("debug")
130
131 # Step 1: Create the worker instance
132 worker = RpcWorker(
133 engine=engine,
134 executor_config=executor_config,
135 is_llm_executor=is_llm_executor,
136 llm_args=llm_args,
137 batched_logits_processor=batched_logits_processor,
138 postproc_worker_config=postproc_worker_config,
139 hf_model_dir=hf_model_dir,
140 tokenizer=tokenizer,
141 )
142
143 if mpi_rank() != 0:
144 # The non-leader worker will setup the engine immediately.
145 # The leader worker will wait for the RPC call to propagate the
146 # potential error.
147 logger_debug(
148 f"[worker] Worker {mpi_rank()} is setting up the engine",
149 color="yellow")
150 worker.setup_engine()
151
152 else:
153 logger_debug(
154 f"[worker] Worker {mpi_rank()} is creating the RPC service with {worker.num_workers} workers",
155 color="yellow")
156 # Step 2: Create the RPC service, it will expose all the APIs of the worker as remote call to the client
157 # Set num_workers to larger than 1 since there are some streaming tasks runs infinitely, such as await_responses_async.
158 hmac_key = kwargs.get("hmac_key")
159 rpc_server = RPCServer(worker,
160 num_workers=worker.num_workers,
161 hmac_key=hmac_key)
162 rpc_server.bind(rpc_addr)
163 rpc_server.start()
164 logger_debug(f"[worker] RPC server {mpi_rank()} is started",
165 color="yellow")
166
167 # Step 3: Wait for the worker to shutdown
168 logger_debug(
169 f"[worker] Worker {mpi_rank()} is waiting for shutdown event",
170 color="yellow")

Callers

nothing calls this directly

Calls 12

setup_engine · Method · 0.95
bind · Method · 0.95
start · Method · 0.95
shutdown · Method · 0.95
enable_llm_debug · Function · 0.90
logger_debug · Function · 0.90
set_level · Function · 0.85
RpcWorker · Class · 0.85
RPCServer · Class · 0.85
wait · Method · 0.80
mpi_rank · Function · 0.50
get · Method · 0.45

Tested by

no test coverage detected