| 2301 | return False |
| 2302 | |
def forward(
    self,
    input: Union[torch.Tensor, Fp4QuantizedTensor],
    *,
    all_reduce_params: Optional[AllReduceParams] = None,
    lora_params: Optional[dict] = None,
    layer_idx: Optional[int] = None,
) -> torch.Tensor:
    """Run the linear layer, dispatching on ``self.tp_mode``.

    * ``TensorParallelMode.ROW``: the bias is handed to the GEMM only when
      ``self.tp_rank`` is not greater than 0. When ``self.reduce_output``
      is falsy the local ``apply_linear`` result is returned as is.
      Otherwise the result comes either from ``apply_linear_allreduce``
      (fused path) or from ``apply_linear`` followed by ``all_reduce``.
    * ``TensorParallelMode.COLUMN``: ``apply_linear`` with ``self.bias``,
      followed by ``allgather`` over ``self.mapping`` when
      ``self.gather_output`` is truthy.
    * Any other mode: a plain ``apply_linear`` with ``self.bias``.

    Args:
        input: Activation passed through to ``apply_linear`` /
            ``apply_linear_allreduce``.
        all_reduce_params: Optional all-reduce settings. Forwarded to
            ``all_reduce`` and ``_maybe_fuse_bias_into_allreduce``; its
            ``enable_allreduce`` and ``fusion_op`` fields also gate the
            fused path.
        lora_params: Optional LoRA arguments forwarded to ``apply_linear``.
            Any non-``None`` value disables the fused path.
        layer_idx: Optional layer index forwarded to the GEMM helpers.

    Returns:
        The value produced by the selected helper chain.
    """
    if self.tp_mode == TensorParallelMode.ROW:
        # The fused GEMM+all-reduce path requires the layer flag to be set
        # and no LoRA arguments to be present.
        fused_path = self.use_fused_gemm_allreduce and lora_params is None
        if fused_path and all_reduce_params is not None:
            # With explicit params, additionally require the all-reduce to
            # be enabled and to carry no fusion op.
            fused_path = (
                all_reduce_params.enable_allreduce
                and all_reduce_params.fusion_op == AllReduceFusionOp.NONE)

        # Ranks above 0 do not pass a bias to the GEMM -- presumably so the
        # bias is counted once across the reduction; confirm with callers.
        if self.tp_rank > 0:
            rank_bias = None
        else:
            rank_bias = self.bias

        if not self.reduce_output:
            return self.apply_linear(input, rank_bias, lora_params,
                                     layer_idx)

        if fused_path:
            # NOTE(review): this path receives self.bias, not the
            # rank-masked bias -- confirm apply_linear_allreduce accounts
            # for that itself.
            return self.apply_linear_allreduce(input, self.bias, layer_idx)

        # When the helper reports the bias as fused into the all-reduce,
        # drop it from the GEMM call.
        if self._maybe_fuse_bias_into_allreduce(rank_bias,
                                                all_reduce_params):
            rank_bias = None
        partial = self.apply_linear(input, rank_bias, lora_params,
                                    layer_idx)
        return self.all_reduce(partial,
                               all_reduce_params=all_reduce_params)

    if self.tp_mode == TensorParallelMode.COLUMN:
        shard = self.apply_linear(input, self.bias, lora_params, layer_idx)
        if not self.gather_output:
            return shard
        # Imported at the point of use, as in the original code.
        from ..distributed import allgather
        return allgather(shard, self.mapping)

    # Neither ROW nor COLUMN: plain linear with the layer's bias.
    return self.apply_linear(input, self.bias, lora_params, layer_idx)
| 2340 | |
| 2341 | def load_weights(self, |
| 2342 | weights: List[Dict], |