| 248 | self.vocab_end_index = num_embeddings |
| 249 | |
def forward(self, input):
    """Compute this rank's embedding output and combine it across TP ranks.

    The per-rank work is delegated to ``pre_comm_embedding_ops``.

    When tensor parallelism is active (``tp_size > 1``), that helper is
    wrapped in ``torch.compile`` so its small pointwise ops can be fused
    before the collective. The wrapper is a no-op when
    ``enable_torch_compile_for_embedding`` is false. Otherwise the helper
    is called directly, which avoids ``torch.compile`` host overhead.

    Once the local ops have run and ``tp_size > 1``, the result is
    combined according to ``tp_mode``:

    * ``COLUMN``: all-reduced via ``self.all_reduce``.
    * ``ROW`` with ``gather_output`` set: all-gathered via ``allgather``.
      Then the trailing ``padding_size`` entries of the last dimension
      are sliced off (only when ``padding_size > 0``).
    * Any other case: returned as produced locally.

    Args:
        input: Indices to embed. This is presumably an integer token-id
            tensor; confirm against callers.

    Returns:
        The embedding output after any tensor-parallel collective.
    """
    multi_rank = self.tp_size > 1

    # Under TP, fuse the tiny pointwise ops that precede the collective
    # with torch.compile. With a single rank, skip torch.compile entirely
    # to avoid its host-side overhead.
    local_ops = (
        torch.compile(
            pre_comm_embedding_ops,
            options={"max-autotune": True},
            disable=not self.enable_torch_compile_for_embedding,
        )
        if multi_rank
        else pre_comm_embedding_ops
    )

    output = local_ops(
        input,
        self.weight,
        self.tp_size,
        self.tp_rank,
        self.tp_mode,
        self.vocab_start_index,
        self.vocab_end_index,
        self.gather_output,
        self.padding_size,
    )

    # Single rank: there is nothing to communicate.
    if not multi_rank:
        return output

    mode = self.tp_mode
    if mode == TensorParallelMode.COLUMN:
        # Reduce across all the model-parallel GPUs.
        return self.all_reduce(output)

    if mode == TensorParallelMode.ROW and self.gather_output:
        output = allgather(output, self.mapping)
        # Strip the trailing padding from the last dimension of the
        # gathered result.
        if self.padding_size > 0:
            output = output[..., :-self.padding_size]

    return output
| 281 | |
def skip_forward(self, input):
    """Delegate the skip path to ``embedding_skip_forward``.

    Passes ``input`` through, together with this module's
    ``embedding_dim`` and ``dtype``. That helper defines what the result
    contains; presumably it is an output of the right width and dtype
    produced without a real lookup, which should be confirmed against
    ``embedding_skip_forward``.

    Args:
        input: The same input ``forward`` would receive.

    Returns:
        Whatever ``embedding_skip_forward`` returns.
    """
    width, out_dtype = self.embedding_dim, self.dtype
    return embedding_skip_forward(input, width, out_dtype)