Method forward

Repository: github.com/NVIDIA/TensorRT-LLM
Source: tensorrt_llm/quantization/layers.py, lines 1356–1383
Signature: forward(self, x, lora_runtime_params=None)


```python
def forward(self, x, lora_runtime_params=None):
    # Keep the original, not yet pre-scaled input for the LoRA branch (if any).
    lora_hidden_state = x if lora_runtime_params is not None else None
    if default_net().strongly_typed:
        # In strongly typed mode, require the input dtype to match the layer dtype.
        assert is_same_dtype(
            x.dtype,
            self.dtype), f"Got input type {x.dtype}, expecting {self.dtype}"
    # Apply the pre-quantization scale, then cast to the activation-scale dtype.
    x = mul(x, self.prequant_scaling_factor.value)
    x = cast(x, self.activation_scaling_factor.value.dtype)

    # Fake-quantize the activations: quantize to int8 with the activation
    # scale, then immediately dequantize back (a Q/DQ pair).
    quantized_out = quantize(x, self.activation_scaling_factor.value,
                             'int8')

    dequantized_out = dequantize(quantized_out,
                                 self.activation_scaling_factor.value, -1,
                                 self.activation_scaling_factor.value.dtype)

    dequantized_out = cast(dequantized_out, self.dtype)

    # Dequantize the stored int8 weights, applying the weight scales along axis 0.
    w_deq_out = dequantize(self.weight.value,
                           self.weights_scaling_factor.value, 0,
                           self.weights_scaling_factor.value.dtype)
    w_deq_out = cast(w_deq_out, self.dtype)

    # gemm_plugin=None: the GEMM plugin is not used for this matmul.
    return self.multiply_collect(dequantized_out,
                                 w_deq_out,
                                 gemm_plugin=None,
                                 lora_runtime_params=lora_runtime_params,
                                 lora_hidden_state=lora_hidden_state)
```
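forward builds a TensorRT graph rather than computing values, so the NumPy sketch below is only a rough numerical picture of the same sequence on concrete arrays: pre-scaling, int8 fake quantization of the activations, weight dequantization with one scale per output row, and a matmul. It is a sketch under stated assumptions, not the library's implementation. It assumes TensorRT's quantize convention (divide by the scale, round half to even, clamp to [-128, 127]), a per-tensor activation scale, and weights stored as [out_features, in_features]. The helpers `quantize_int8`, `dequantize_np` and `emulate_forward` are hypothetical stand-ins, and LoRA and tensor-parallel collection are left out.

```python
import numpy as np


def quantize_int8(x, scale):
    # Assumed TensorRT-style quantize: divide by scale, round half to even
    # (np.rint), then saturate to the int8 range.
    return np.clip(np.rint(x / scale), -128, 127).astype(np.int8)


def dequantize_np(q, scale):
    # Multiply the integer codes back by the scale.
    return q.astype(np.float32) * scale


def emulate_forward(x, prequant_scale, act_scale, w_int8, w_scale):
    """Hypothetical NumPy emulation of the graph that forward() builds.

    x:              [tokens, in_features] float activations
    prequant_scale: [in_features] pre-quantization scale (assumed per channel)
    act_scale:      scalar activation scale (assumed per tensor)
    w_int8:         [out_features, in_features] int8 weights (assumed layout)
    w_scale:        [out_features] weight scales, one per row (axis 0)
    """
    x = x * prequant_scale
    x_dq = dequantize_np(quantize_int8(x, act_scale), act_scale)  # Q/DQ pair
    w_dq = dequantize_np(w_int8, w_scale[:, None])
    return x_dq @ w_dq.T  # [tokens, out_features]


# Usage: quantize a random float weight, then compare with the float matmul.
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)
w = rng.standard_normal((3, 8)).astype(np.float32)
prequant = np.ones(8, dtype=np.float32)  # identity pre-scale for the demo
act_scale = np.abs(x).max() / 127
w_scale = np.abs(w).max(axis=1) / 127
w_int8 = quantize_int8(w, w_scale[:, None])

y = emulate_forward(x, prequant, act_scale, w_int8, w_scale)
print(np.max(np.abs(y - x @ w.T)))  # small quantization error
```

With an identity pre-scale, `y` differs from the unquantized `x @ w.T` only by int8 rounding error. In the real layer, the pre-quantization scale is meant to reshape the activation range before rounding.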

Callers

No direct callers found; forward is normally invoked through the module's __call__.

Calls (6)

default_net (function) · 0.85
is_same_dtype (function) · 0.85
cast (function) · 0.85
dequantize (function) · 0.85
quantize (function) · 0.70
multiply_collect (method) · 0.45
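The quantize, dequantize and multiply_collect calls together form a Q/DQ pattern around an ordinary matmul. The identity below shows why that pattern is equivalent, up to the intermediate dtype casts, to an integer GEMM followed by one rescale per output column. It assumes a per-tensor activation scale $s_x$ and one weight scale $s_{w,j}$ per output row. The notation ($q^x$, $q^w$, $s_{\text{pre}}$, $\tilde{x}$) is introduced here for illustration and is not from the source. This page does not state whether TensorRT actually fuses the pattern for this layer.

```latex
\tilde{x} = x \odot s_{\text{pre}}, \qquad
q^{x}_{ik} = \operatorname{clamp}\!\left(\operatorname{round}\!\left(\tilde{x}_{ik} / s_x\right), -128, 127\right)

y_{ij} = \sum_{k} \left(s_x\, q^{x}_{ik}\right)\left(s_{w,j}\, q^{w}_{jk}\right)
       = s_x\, s_{w,j} \sum_{k} q^{x}_{ik}\, q^{w}_{jk}
```

Because both scales factor out of the sum over $k$, the inner product can be accumulated purely over int8 codes, with the float scale applied once per output element.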

Tested by

No test coverage detected.