(self, x, lora_runtime_params=None)
| 2152 | } |
| 2153 | |
def forward(self, x, lora_runtime_params=None):
    """Apply the FP4-quantized linear layer to ``x``.

    ``x`` is either an already-quantized ``(fp4_x, act_per_block_scale)``
    pair (tuple or list) or an unquantized activation. An unquantized
    activation is quantized here using ``self.activation_global_scaling_factor``.

    With the ``'nvfp4'`` gemm plugin, the product is computed by ``fp4_gemm``
    against ``self.weights_block_scaling_factor_interleaved`` and
    ``self.alpha``. Otherwise, weights and activations are each dequantized to
    float16 with ``block_double_dequantize`` and multiplied with ``matmul``
    (weights transposed), and the result is cast to ``self.dtype``.

    Afterwards, ``self.bias`` (if not None) is added. When ``gather_output``
    is set, ``tp_size > 1`` and ``tp_group`` is not None, the result is
    all-gathered along dim 1.

    Args:
        x: Activation tensor, or a ``(fp4_x, act_per_block_scale)`` pair.
        lora_runtime_params: Must be ``None``; LoRA is not supported here.

    Returns:
        The layer output.

    Raises:
        NotImplementedError: If ``lora_runtime_params`` is not ``None``.
    """
    # Explicit raise instead of ``assert``: an assert is stripped under
    # ``python -O``, which would silently ignore LoRA params and give
    # wrong results.
    if lora_runtime_params is not None:
        raise NotImplementedError("lora is not supported on FP4Linear now")

    # Read the plugin choice once so the quantize and GEMM branches below
    # are guaranteed to agree.
    use_nvfp4_plugin = default_net().plugin_config.gemm_plugin == 'nvfp4'

    if isinstance(x, (tuple, list)):
        # Caller already supplied quantized activations and their block scales.
        fp4_x, act_per_block_scale = x
    elif use_nvfp4_plugin:
        # NOTE(review): this quantizer is given the reciprocal of the global
        # scale, while ``dynamic_quantize`` below gets it directly; presumably
        # the two helpers use opposite conventions. Confirm against their
        # definitions.
        fp4_x, act_per_block_scale = quantize_to_fp4_tensor(
            x, div(1, self.activation_global_scaling_factor.value))
    else:
        fp4_x, act_per_block_scale = dynamic_quantize(
            x, self.activation_global_scaling_factor.value)

    if use_nvfp4_plugin:
        x = fp4_gemm(fp4_x, act_per_block_scale, self.weight.value,
                     self.weights_block_scaling_factor_interleaved.value,
                     self.alpha.value, self.dtype)
    else:
        # Non-plugin path: dequantize both operands to float16 and do a
        # regular matmul, then cast to the layer dtype.
        dequant_w = block_double_dequantize(
            self.weight.value,
            self.weights_block_scaling_factor.value,
            self.weights_global_scaling_factor.value,
            dtype=trt.float16)
        dequant_x = block_double_dequantize(
            fp4_x,
            act_per_block_scale,
            self.activation_global_scaling_factor.value,
            dtype=trt.float16)
        x = matmul(dequant_x, dequant_w, transb=True).cast(self.dtype)

    if self.bias is not None:
        x = x + self.bias.value

    if self.gather_output and self.tp_size > 1 and self.tp_group is not None:
        # [dim0, local_dim] -> [dim0 * tp_size, local_dim] --> [dim0, local_dim * tp_size]
        x = allgather(x, self.tp_group, gather_dim=1)

    return x
| 2192 | |
| 2193 | def postprocess(self, tllm_key, weights, **kwargs): |
| 2194 | if not any([ |
nothing calls this directly
no test coverage detected