MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / forward

Method forward

tensorrt_llm/quantization/layers.py:1425–1492  ·  view source on GitHub ↗
(self, x, lora_runtime_params=None)

Source from the content-addressed store, hash-verified

1423 }
1424
1425 def forward(self, x, lora_runtime_params=None):
1426 assert lora_runtime_params is None or default_net(
1427 ).plugin_config.lora_plugin == self.dtype
1428
1429 if default_net().strongly_typed:
1430 assert default_net().plugin_config.user_buffer or is_same_dtype(
1431 x.dtype,
1432 self.dtype), f"Got input type {x.dtype}, expecting {self.dtype}"
1433
1434 alpha = self.weights_scaling_factor.raw_value * self.activation_scaling_factor.raw_value
1435 activation_scaling_factor = constant(
1436 self.activation_scaling_factor.raw_value)
1437 activation_scaling_factor = cast(activation_scaling_factor, self.dtype)
1438 if x.dtype != trt.fp8:
1439 quantized_out = quantize(x, activation_scaling_factor, 'fp8')
1440 lora_hidden_state = x if lora_runtime_params is not None else None
1441 else:
1442 quantized_out = x
1443 # TODO: add fp8 LoRA support
1444 lora_hidden_state = dequantize(
1445 x, activation_scaling_factor, -1,
1446 self.dtype) if lora_runtime_params is not None else None
1447
1448 weights_scaling_factor = cast(self.weights_scaling_factor.value,
1449 self.dtype)
1450
1451 if self.weight.value.dtype != trt.fp8:
1452 w_quant_out = quantize(self.weight.value, weights_scaling_factor,
1453 'fp8')
1454 else:
1455 w_quant_out = self.weight.value
1456
1457 gemm_plugin = default_net().plugin_config.gemm_plugin
1458
1459 low_latency_gemm_plugin = default_net(
1460 ).plugin_config.low_latency_gemm_plugin
1461 if (low_latency_gemm_plugin == "fp8"):
1462 return self.multiply_collect(
1463 quantized_out,
1464 w_quant_out,
1465 gemm_plugin=None,
1466 low_latency_gemm_plugin=low_latency_gemm_plugin,
1467 use_fp8=True,
1468 alpha=alpha,
1469 lora_runtime_params=lora_runtime_params,
1470 lora_hidden_state=lora_hidden_state)
1471 elif gemm_plugin == 'fp8':
1472 return self.multiply_collect(
1473 quantized_out,
1474 w_quant_out,
1475 gemm_plugin=gemm_plugin,
1476 use_fp8=True,
1477 alpha=alpha,
1478 lora_runtime_params=lora_runtime_params,
1479 lora_hidden_state=lora_hidden_state)
1480 else:
1481 dequantized_out = dequantize(quantized_out,
1482 activation_scaling_factor, -1,

Callers

nothing calls this directly

Calls 7

default_net — Function · 0.85
is_same_dtype — Function · 0.85
constant — Function · 0.85
cast — Function · 0.85
dequantize — Function · 0.85
quantize — Function · 0.70
multiply_collect — Method · 0.45

Tested by

no test coverage detected