(self, x, lora_runtime_params=None, all_reduce_params=None)
| 2300 | } |
| 2301 | |
| 2302 | def forward(self, x, lora_runtime_params=None, all_reduce_params=None): |
| 2303 | assert lora_runtime_params is None, "lora is not supported on FP4Linear now" |
| 2304 | |
| 2305 | if isinstance(x, (tuple, list)): |
| 2306 | fp4_x, act_per_block_scale = x |
| 2307 | else: |
| 2308 | if default_net().plugin_config.gemm_plugin == "nvfp4": |
| 2309 | fp4_x, act_per_block_scale = quantize_to_fp4_tensor( |
| 2310 | x, div(1.0, self.activation_global_scaling_factor.value)) |
| 2311 | else: |
| 2312 | # WAR for FP8 output attention |
| 2313 | if x.dtype == trt.fp8: |
| 2314 | # Since the scale is NVFP4 scale, we need to make it back to fp8 scale |
| 2315 | new_scale_factor = self.activation_global_scaling_factor.raw_value |
| 2316 | new_scale_factor = constant(new_scale_factor * 6) |
| 2317 | x = dequantize(x, new_scale_factor, 0, |
| 2318 | new_scale_factor.dtype) |
| 2319 | fp4_x, act_per_block_scale = dynamic_quantize( |
| 2320 | x, self.activation_global_scaling_factor.value) |
| 2321 | |
| 2322 | if default_net().plugin_config.gemm_allreduce_plugin: |
| 2323 | x = gemm_allreduce( |
| 2324 | a=fp4_x, |
| 2325 | b=self.weight.value, |
| 2326 | a_sf=act_per_block_scale, |
| 2327 | b_sf=self.weights_block_scaling_factor_interleaved.value, |
| 2328 | transa=False, # row-major |
| 2329 | transb=True, # col-major |
| 2330 | alpha=self.alpha.value, |
| 2331 | group=self.tp_group, # ranks participating |
| 2332 | fp8_inputs_override=False) |
| 2333 | else: |
| 2334 | if default_net().plugin_config.gemm_plugin == "nvfp4": |
| 2335 | x = fp4_gemm( |
| 2336 | fp4_x, act_per_block_scale, self.weight.value, |
| 2337 | self.weights_block_scaling_factor_interleaved.value, |
| 2338 | self.alpha.value, self.dtype) |
| 2339 | else: |
| 2340 | quant_w = self.weight.value |
| 2341 | scale_w = self.weights_block_scaling_factor.value |
| 2342 | dequant_x = block_double_dequantize( |
| 2343 | fp4_x, act_per_block_scale, |
| 2344 | self.activation_global_scaling_factor.value, trt.float16) |
| 2345 | dequant_w = block_double_dequantize( |
| 2346 | quant_w, scale_w, self.weights_global_scaling_factor.value, |
| 2347 | trt.float16) |
| 2348 | x = matmul(dequant_x, dequant_w, transb=True).cast(self.dtype) |
| 2349 | |
| 2350 | if self.tp_size > 1 and self.tp_group is not None: |
| 2351 | need_bias = self.bias is not None |
| 2352 | fuse_bias_into_all_reduce = need_bias and ( |
| 2353 | all_reduce_params |
| 2354 | is not None) and (all_reduce_params.fusion_op |
| 2355 | == AllReduceFusionOp.RESIDUAL_RMS_NORM) |
| 2356 | if fuse_bias_into_all_reduce: |
| 2357 | all_reduce_params.bias = self.bias.value |
| 2358 | x = allreduce(x, |
| 2359 | self.tp_group, |
nothing calls this directly
no test coverage detected