hub / github.com/NVIDIA/TensorRT-LLM / forward

Method forward

tensorrt_llm/_torch/modules/embedding.py:250–280
(self, input)

Source from the content-addressed store, hash-verified

        self.vocab_end_index = num_embeddings

    def forward(self, input):
        if self.tp_size > 1:
            # Run the ops before all_reduce/all_gather.
            # We use torch.compile() to fuse the tiny pointwise ops before all_reduce/all_gather for Embedding module.
            embedding_ops_func = torch.compile(
                pre_comm_embedding_ops,
                options={"max-autotune": True},
                disable=not self.enable_torch_compile_for_embedding)
        else:
            # Skip torch.compile when TP size is 1 to avoid unnecessary host overhead
            embedding_ops_func = pre_comm_embedding_ops
        output = embedding_ops_func(input, self.weight, self.tp_size,
                                    self.tp_rank, self.tp_mode,
                                    self.vocab_start_index,
                                    self.vocab_end_index, self.gather_output,
                                    self.padding_size)

        # Run the all_reduce/all_gather.
        if self.tp_size > 1:
            if self.tp_mode == TensorParallelMode.COLUMN:
                # Reduce across all the model parallel GPUs.
                output = self.all_reduce(output)
            elif self.tp_mode == TensorParallelMode.ROW:
                if self.gather_output:
                    # Run allgather.
                    output = allgather(output, self.mapping)
                    # Remove the padding.
                    if self.padding_size > 0:
                        output = output[..., :-self.padding_size]

        return output

    def skip_forward(self, input):
        return embedding_skip_forward(input, self.embedding_dim, self.dtype)
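
The communication step in forward can be illustrated with a single-process sketch. This is not TensorRT-LLM code: the helper names (shard_lookup_by_vocab, simulated_all_reduce, simulated_all_gather, strip_padding) are hypothetical, the collectives are simulated with plain Python lists, and the sharding layouts (vocabulary split for the reduce path, padded hidden dimension split for the gather path) are assumptions, since the body of pre_comm_embedding_ops is not shown here.

```python
# Single-process sketch: each "rank" is a list entry and the collectives are
# simulated in plain Python. All helper names are hypothetical.


def shard_lookup_by_vocab(token_ids, weight_shard, vocab_start, vocab_end):
    """Look up rows owned by this rank; tokens outside [vocab_start, vocab_end) yield zeros."""
    hidden = len(weight_shard[0])
    out = []
    for tok in token_ids:
        if vocab_start <= tok < vocab_end:
            out.append(list(weight_shard[tok - vocab_start]))
        else:
            out.append([0.0] * hidden)
    return out


def simulated_all_reduce(per_rank_outputs):
    """Element-wise sum across ranks, standing in for all_reduce."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*per_rank_outputs)]


def simulated_all_gather(per_rank_outputs):
    """Concatenate along the last dimension, standing in for allgather."""
    return [[x for row in rows for x in row] for rows in zip(*per_rank_outputs)]


def strip_padding(rows, padding_size):
    """Equivalent of output[..., :-padding_size], guarded so that 0 is a no-op."""
    if padding_size > 0:
        return [row[:-padding_size] for row in rows]
    return rows


# Full table: 4 vocabulary entries, hidden size 3.
full_weight = [[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0],
               [7.0, 8.0, 9.0],
               [10.0, 11.0, 12.0]]
token_ids = [2, 0, 3]
expected = [full_weight[t] for t in token_ids]

# Reduce path (assumed layout): vocabulary split across 2 ranks, partial
# results summed so that each token's row comes from exactly one rank.
vocab_shards = [(full_weight[0:2], 0, 2), (full_weight[2:4], 2, 4)]
partials = [shard_lookup_by_vocab(token_ids, w, s, e) for w, s, e in vocab_shards]
reduced = simulated_all_reduce(partials)

# Gather path (assumed layout): hidden size 3 padded to 4 so it splits evenly
# across 2 ranks; the gathered result carries 1 trailing padding column.
padding_size = 1
padded = [row + [0.0] * padding_size for row in full_weight]
hidden_shards = [[row[0:2] for row in padded], [row[2:4] for row in padded]]
gathered = simulated_all_gather(
    [[shard[t] for t in token_ids] for shard in hidden_shards])
unpadded = strip_padding(gathered, padding_size)

print(reduced == expected, unpadded == expected)
```

The guard in strip_padding mirrors the `if self.padding_size > 0` check in forward: slicing with `[:-0]` would return an empty result instead of leaving the output unchanged.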

Callers

nothing calls this directly

Calls 2

allgather · Function · 0.50
compile · Method · 0.45

Tested by

no test coverage detected