github.com/NVIDIA/TensorRT-LLM

Method forward

tensorrt_llm/quantization/layers.py, lines 2302–2368
(self, x, lora_runtime_params=None, all_reduce_params=None)


```python
def forward(self, x, lora_runtime_params=None, all_reduce_params=None):
    assert lora_runtime_params is None, "lora is not supported on FP4Linear now"

    if isinstance(x, (tuple, list)):
        # Input is already quantized: (FP4 codes, per-block scales).
        fp4_x, act_per_block_scale = x
    else:
        if default_net().plugin_config.gemm_plugin == "nvfp4":
            fp4_x, act_per_block_scale = quantize_to_fp4_tensor(
                x, div(1.0, self.activation_global_scaling_factor.value))
        else:
            # WAR for FP8 output attention
            if x.dtype == trt.fp8:
                # Since the scale is NVFP4 scale, we need to make it back to fp8 scale
                new_scale_factor = self.activation_global_scaling_factor.raw_value
                new_scale_factor = constant(new_scale_factor * 6)
                x = dequantize(x, new_scale_factor, 0,
                               new_scale_factor.dtype)
            fp4_x, act_per_block_scale = dynamic_quantize(
                x, self.activation_global_scaling_factor.value)

    if default_net().plugin_config.gemm_allreduce_plugin:
        # Fused GEMM + all-reduce across the tensor-parallel group.
        x = gemm_allreduce(
            a=fp4_x,
            b=self.weight.value,
            a_sf=act_per_block_scale,
            b_sf=self.weights_block_scaling_factor_interleaved.value,
            transa=False,  # row-major
            transb=True,  # col-major
            alpha=self.alpha.value,
            group=self.tp_group,  # ranks participating
            fp8_inputs_override=False)
    else:
        if default_net().plugin_config.gemm_plugin == "nvfp4":
            x = fp4_gemm(
                fp4_x, act_per_block_scale, self.weight.value,
                self.weights_block_scaling_factor_interleaved.value,
                self.alpha.value, self.dtype)
        else:
            # Fallback: dequantize both operands to FP16, then x @ w^T.
            quant_w = self.weight.value
            scale_w = self.weights_block_scaling_factor.value
            dequant_x = block_double_dequantize(
                fp4_x, act_per_block_scale,
                self.activation_global_scaling_factor.value, trt.float16)
            dequant_w = block_double_dequantize(
                quant_w, scale_w, self.weights_global_scaling_factor.value,
                trt.float16)
            x = matmul(dequant_x, dequant_w, transb=True).cast(self.dtype)

        if self.tp_size > 1 and self.tp_group is not None:
            need_bias = self.bias is not None
            fuse_bias_into_all_reduce = need_bias and (
                all_reduce_params
                is not None) and (all_reduce_params.fusion_op
                                  == AllReduceFusionOp.RESIDUAL_RMS_NORM)
            if fuse_bias_into_all_reduce:
                all_reduce_params.bias = self.bias.value
            x = allreduce(x,
                          self.tp_group,
                          # ... listing truncated here (the method continues through line 2368)
```
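Both non-plugin branches rely on NVFP4's two-level scaling: FP4 (E2M1) codes, one scale per 16-element block, and one global scale per tensor, which `block_double_dequantize` presumably combines, as its name and arguments suggest. The pure-Python sketch below illustrates that scheme under simplifying assumptions: block scales stay as Python floats rather than FP8 E4M3, rounding is to the nearest E2M1 value, the global scale is multiplied in (whether TensorRT-LLM stores that scale or its reciprocal is not visible in this listing), and `quantize_nvfp4` / `double_dequantize` are hypothetical helpers, not TensorRT-LLM APIs.

```python
import math

# Non-negative magnitudes representable by FP4 E2M1; the sign is a separate bit.
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
FP4_MAX = 6.0          # largest E2M1 magnitude
FP8_E4M3_MAX = 448.0   # largest FP8 E4M3 magnitude (range available to block scales)
BLOCK_SIZE = 16        # NVFP4: one scale per 16 consecutive elements


def round_to_e2m1(v):
    """Round to the nearest signed E2M1 value; values beyond 6 saturate to +/-6."""
    mag = min(E2M1_GRID, key=lambda g: abs(abs(v) - g))  # ties -> smaller magnitude
    return math.copysign(mag, v)


def quantize_nvfp4(values, global_scale):
    """Hypothetical helper returning (E2M1 codes, per-block scales).

    Block scales are kept as floats; real NVFP4 stores them as FP8 E4M3
    (simplifying assumption)."""
    codes, block_scales = [], []
    for start in range(0, len(values), BLOCK_SIZE):
        block = values[start:start + BLOCK_SIZE]
        amax = max(abs(v) for v in block)
        # Choose the block scale so the block's amax lands on FP4_MAX.
        block_scale = amax / (FP4_MAX * global_scale) if amax > 0 else 1.0
        block_scales.append(block_scale)
        codes.extend(round_to_e2m1(v / (block_scale * global_scale)) for v in block)
    return codes, block_scales


def double_dequantize(codes, block_scales, global_scale):
    """Hypothetical helper: x ~= code * block_scale * global_scale."""
    return [c * block_scales[i // BLOCK_SIZE] * global_scale
            for i, c in enumerate(codes)]


values = [3.0 * math.sin(i) for i in range(32)]  # two 16-element blocks
# Global scale chosen so every block scale fits in the FP8 E4M3 range.
global_scale = max(abs(v) for v in values) / (FP4_MAX * FP8_E4M3_MAX)
codes, block_scales = quantize_nvfp4(values, global_scale)
restored = double_dequantize(codes, block_scales, global_scale)
```

Dividing the tensor amax by 6 · 448 caps every block scale at 448, which is what lets real NVFP4 store block scales in 8 bits. Because the widest gap in the E2M1 grid is between 4 and 6, each restored element is within one sixth of its block's amax.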

Callers

nothing calls this directly

Calls (11)

default_net · Function · 0.85
quantize_to_fp4_tensor · Function · 0.85
constant · Function · 0.85
dequantize · Function · 0.85
dynamic_quantize · Function · 0.85
gemm_allreduce · Function · 0.85
fp4_gemm · Function · 0.85
block_double_dequantize · Function · 0.85
matmul · Function · 0.85
cast · Method · 0.80
allreduce · Function · 0.50
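The FP8 workaround in `forward` builds a `constant` from `activation_global_scaling_factor` multiplied by 6 and passes it to `dequantize`. Assuming the NVFP4 global scale follows the common convention of mapping the tensor's absolute maximum onto the product of the largest FP4 (E2M1) magnitude, 6, and the largest FP8 (E4M3) magnitude, 448 (an assumption; this listing does not show how the scale is calibrated), the factor is exactly the ratio between the two per-tensor scales:

$$
s_{\mathrm{nvfp4}} = \frac{\max|x|}{6 \cdot 448},
\qquad
s_{\mathrm{fp8}} = \frac{\max|x|}{448} = 6\, s_{\mathrm{nvfp4}},
\qquad
x \approx q_{\mathrm{fp8}} \cdot s_{\mathrm{fp8}}.
$$

Dequantizing the FP8 input with $6\, s_{\mathrm{nvfp4}}$ therefore recovers a higher-precision $x$, which `dynamic_quantize` then re-quantizes to NVFP4.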

Tested by

no test coverage detected
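With no test coverage detected, one property a test could pin down is that the plugin path (`fp4_gemm`, which applies a single `alpha`) and the fallback path (`block_double_dequantize` on both operands, then `matmul(..., transb=True)`) compute the same product. The pure-Python sketch below models that identity. The helper names are hypothetical, the 16-element blocking is the NVFP4 block size, and `alpha = global_x * global_w` is an assumption about how `self.alpha` is derived, not something this listing shows.

```python
import random

BLOCK = 16  # NVFP4 block size along the reduction (K) dimension
E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def fused_fp4_gemm(qx, sx, qw, sw, alpha):
    """Hypothetical model of a plugin-style GEMM: accumulate block-scaled
    code products, then apply one scalar alpha (assumed = gx * gw)."""
    k = len(qx[0])
    return [[alpha * sum(qx[i][t] * sx[i][t // BLOCK] * qw[j][t] * sw[j][t // BLOCK]
                         for t in range(k))
             for j in range(len(qw))]
            for i in range(len(qx))]


def dequantize_rows(q, s, g):
    """Double dequantization of each row: code * block scale * global scale."""
    return [[row[t] * srow[t // BLOCK] * g for t in range(len(row))]
            for row, srow in zip(q, s)]


def reference_gemm(qx, sx, gx, qw, sw, gw):
    """Fallback-style path: dequantize both operands, then x @ w^T
    (the weight is laid out as [out_features, in_features])."""
    dx = dequantize_rows(qx, sx, gx)
    dw = dequantize_rows(qw, sw, gw)
    return [[sum(a * b for a, b in zip(xr, wr)) for wr in dw] for xr in dx]


def random_operand(rows, k, rng):
    """Random signed E2M1 codes plus one random scale per 16-element block."""
    codes = [[rng.choice(E2M1) * rng.choice((-1.0, 1.0)) for _ in range(k)]
             for _ in range(rows)]
    scales = [[rng.uniform(0.5, 2.0) for _ in range(k // BLOCK)]
              for _ in range(rows)]
    return codes, scales


rng = random.Random(0)
M, N, K = 3, 4, 32
qx, sx = random_operand(M, K, rng)
qw, sw = random_operand(N, K, rng)
gx, gw = 0.02, 0.005  # illustrative global scales
fused = fused_fp4_gemm(qx, sx, qw, sw, alpha=gx * gw)
ref = reference_gemm(qx, sx, gx, qw, sw, gw)
max_diff = max(abs(f - r) for fr, rr in zip(fused, ref) for f, r in zip(fr, rr))
```

Because both global scales are per-tensor constants, they factor out of the K-sum, so a kernel can fold them into one `alpha` and apply it once per output element instead of once per product.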