MCPcopy Index your code
hub / github.com/NVIDIA/TensorRT-LLM / forward

Method forward

tensorrt_llm/quantization/layers.py:2154–2191  ·  view source on GitHub ↗
(self, x, lora_runtime_params=None)

Source from the content-addressed store, hash-verified

2152 }
2153
def forward(self, x, lora_runtime_params=None):
    """Apply the FP4-quantized linear projection to ``x``.

    ``x`` may be either an activation that still has to be quantized, or
    a ``(fp4_x, act_per_block_scale)`` tuple/list that is used as-is.

    Two execution paths exist, selected by the network's
    ``plugin_config.gemm_plugin`` setting:

    * ``'nvfp4'``: the activation is quantized with
      ``quantize_to_fp4_tensor`` (using the reciprocal of the activation
      global scaling factor) and multiplied through ``fp4_gemm`` with the
      interleaved weight block scales and ``alpha``.
    * anything else: the activation is quantized with
      ``dynamic_quantize``, then weights and activations are both
      dequantized to ``trt.float16`` and multiplied with a regular
      ``matmul`` (weights transposed); the result is cast to
      ``self.dtype``.

    The bias is added when one is set, and the result is all-gathered
    along dimension 1 when ``gather_output`` is enabled, ``tp_size`` is
    greater than 1 and a ``tp_group`` is configured.

    Args:
        x: Activation, or a pre-quantized ``(fp4_x, act_per_block_scale)``
            tuple/list.
        lora_runtime_params: Must be ``None``; LoRA is not supported here.

    Returns:
        The projected (and possibly bias-added / gathered) output.

    Raises:
        AssertionError: If ``lora_runtime_params`` is not ``None``.
    """
    assert lora_runtime_params is None, "lora is not supported on FP4Linear now"

    # A tuple/list input already carries the FP4 data and its block scales.
    already_quantized = isinstance(x, (tuple, list))
    if already_quantized:
        fp4_x, act_per_block_scale = x

    if default_net().plugin_config.gemm_plugin == 'nvfp4':
        # Plugin path: quantize (if needed) and run the fused FP4 GEMM.
        if not already_quantized:
            fp4_x, act_per_block_scale = quantize_to_fp4_tensor(
                x, div(1, self.activation_global_scaling_factor.value))
        output = fp4_gemm(
            fp4_x,
            act_per_block_scale,
            self.weight.value,
            self.weights_block_scaling_factor_interleaved.value,
            self.alpha.value,
            self.dtype,
        )
    else:
        # Fallback path: quantize (if needed), dequantize both operands
        # to float16 and use a regular matmul.
        if not already_quantized:
            fp4_x, act_per_block_scale = dynamic_quantize(
                x, self.activation_global_scaling_factor.value)
        weight_fp16 = block_double_dequantize(
            self.weight.value,
            self.weights_block_scaling_factor.value,
            self.weights_global_scaling_factor.value,
            dtype=trt.float16)
        activation_fp16 = block_double_dequantize(
            fp4_x,
            act_per_block_scale,
            self.activation_global_scaling_factor.value,
            dtype=trt.float16)
        product = matmul(activation_fp16, weight_fp16, transb=True)
        output = product.cast(self.dtype)

    if self.bias is not None:
        output = output + self.bias.value

    # Nothing to gather unless output gathering is requested and a
    # tensor-parallel group with more than one rank is configured.
    if not self.gather_output or self.tp_size <= 1 or self.tp_group is None:
        return output

    # [dim0, local_dim] -> [dim0 * tp_size, local_dim] --> [dim0, local_dim * tp_size]
    return allgather(output, self.tp_group, gather_dim=1)
2192
2193 def postprocess(self, tllm_key, weights, **kwargs):
2194 if not any([

Callers

nothing calls this directly

Calls 8

default_netFunction · 0.85
quantize_to_fp4_tensorFunction · 0.85
dynamic_quantizeFunction · 0.85
fp4_gemmFunction · 0.85
block_double_dequantizeFunction · 0.85
matmulFunction · 0.85
castMethod · 0.80
allgatherFunction · 0.50

Tested by

no test coverage detected