MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / forward

Method forward

tensorrt_llm/_torch/modules/linear.py:2303–2339  ·  view source on GitHub ↗
(
        self,
        input: Union[torch.Tensor, Fp4QuantizedTensor],
        *,
        all_reduce_params: Optional[AllReduceParams] = None,
        lora_params: Optional[dict] = None,
        layer_idx: Optional[int] = None,
    )

Source from the content-addressed store, hash-verified

2301 return False
2302
def forward(
    self,
    input: Union[torch.Tensor, Fp4QuantizedTensor],
    *,
    all_reduce_params: Optional[AllReduceParams] = None,
    lora_params: Optional[dict] = None,
    layer_idx: Optional[int] = None,
) -> torch.Tensor:
    """Run the linear layer, dispatching on the tensor-parallel mode.

    ROW mode:
        * The bias handed to ``apply_linear`` is ``self.bias`` only when
          ``self.tp_rank`` is not greater than 0; other ranks pass ``None``
          (presumably so the bias is added once when partial outputs are
          summed -- confirm).
        * With ``self.reduce_output`` set, either the fused
          ``apply_linear_allreduce`` path is taken, or ``apply_linear`` is
          followed by ``self.all_reduce``. In the latter case
          ``_maybe_fuse_bias_into_allreduce`` is consulted first and, if it
          returns a truthy value, the bias is withheld from ``apply_linear``.
        * Without ``self.reduce_output`` only ``apply_linear`` runs.

    COLUMN mode:
        ``apply_linear`` with ``self.bias``, then ``allgather`` over
        ``self.mapping`` when ``self.gather_output`` is set.

    Any other mode:
        ``apply_linear`` with ``self.bias``.

    Args:
        input: Activation forwarded unchanged to ``apply_linear`` or
            ``apply_linear_allreduce``.
        all_reduce_params: Optional all-reduce settings. When provided, its
            ``enable_allreduce`` and ``fusion_op`` fields gate the fused
            path; it is also forwarded to
            ``_maybe_fuse_bias_into_allreduce`` and ``all_reduce``.
        lora_params: Forwarded to ``apply_linear``. A non-``None`` value
            disables the fused GEMM + all-reduce path.
        layer_idx: Forwarded to ``apply_linear`` /
            ``apply_linear_allreduce``.

    Returns:
        The layer output, after the all-reduce or all-gather step when the
        corresponding mode and flag request one.
    """
    if self.tp_mode == TensorParallelMode.ROW:
        # The fused kernel is only eligible without LoRA, and -- when the
        # caller supplies all-reduce settings -- only if all-reduce is
        # enabled and no fusion op is requested.
        fused_path = self.use_fused_gemm_allreduce and lora_params is None
        if fused_path and all_reduce_params is not None:
            fused_path = (
                all_reduce_params.enable_allreduce
                and all_reduce_params.fusion_op == AllReduceFusionOp.NONE)

        if self.tp_rank > 0:
            rank_bias = None
        else:
            rank_bias = self.bias

        if not self.reduce_output:
            return self.apply_linear(input, rank_bias, lora_params, layer_idx)

        if fused_path:
            # NOTE(review): the fused path receives self.bias on every rank,
            # not the rank-gated bias; presumably the fused kernel handles
            # that itself -- confirm.
            return self.apply_linear_allreduce(input, self.bias, layer_idx)

        # If the bias is fused into the all-reduce, do not add it in the GEMM.
        if self._maybe_fuse_bias_into_allreduce(rank_bias, all_reduce_params):
            rank_bias = None
        partial = self.apply_linear(input, rank_bias, lora_params, layer_idx)
        return self.all_reduce(partial, all_reduce_params=all_reduce_params)

    if self.tp_mode == TensorParallelMode.COLUMN:
        result = self.apply_linear(input, self.bias, lora_params, layer_idx)
        if self.gather_output:
            from ..distributed import allgather
            result = allgather(result, self.mapping)
        return result

    return self.apply_linear(input, self.bias, lora_params, layer_idx)
2340
2341 def load_weights(self,
2342 weights: List[Dict],

Callers 15

test_linear_mxfp4Function · 0.95
column_lm_head_forwardFunction · 0.95
row_lm_head_forwardFunction · 0.95
mlp_forwardFunction · 0.95
column_linear_forwardFunction · 0.95
row_linear_forwardFunction · 0.95
fp4_row_linear_allreduceFunction · 0.95
test_fp8_linearFunction · 0.95
test_fp8_rowwise_linearFunction · 0.95
test_w4a16_linearFunction · 0.95

Calls 4

apply_linearMethod · 0.95
allgatherFunction · 0.50

Tested by 15

test_linear_mxfp4Function · 0.76
column_lm_head_forwardFunction · 0.76
row_lm_head_forwardFunction · 0.76
mlp_forwardFunction · 0.76
column_linear_forwardFunction · 0.76
row_linear_forwardFunction · 0.76
fp4_row_linear_allreduceFunction · 0.76
test_fp8_linearFunction · 0.76
test_fp8_rowwise_linearFunction · 0.76
test_w4a16_linearFunction · 0.76