MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / forward

Method forward

tensorrt_llm/quantization/layers.py:1560–1632  ·  view source on GitHub ↗
(self, x, lora_runtime_params=None, all_reduce_params=None)

Source from the content-addressed store, hash-verified

1558 }
1559
1560 def forward(self, x, lora_runtime_params=None, all_reduce_params=None):
1561 assert lora_runtime_params is None or default_net(
1562 ).plugin_config.lora_plugin == self.dtype
1563
1564 alpha = self.weights_scaling_factor.raw_value * self.activation_scaling_factor.raw_value
1565 activation_scaling_factor = cast(self.activation_scaling_factor.value,
1566 self.dtype)
1567 if x.dtype != trt.fp8:
1568 quantized_out = quantize(x, activation_scaling_factor, 'fp8')
1569 lora_hidden_state = x if lora_runtime_params is not None else None
1570 else:
1571 quantized_out = x
1572 # TODO: add fp8 LoRA support
1573 lora_hidden_state = dequantize(
1574 x, activation_scaling_factor, -1,
1575 self.dtype) if lora_runtime_params is not None else None
1576
1577 weights_scaling_factor = cast(self.weights_scaling_factor.value,
1578 self.dtype)
1579 if self.weight.value.dtype != trt.fp8:
1580 w_quant_out = quantize(self.weight.value, weights_scaling_factor,
1581 'fp8')
1582 else:
1583 w_quant_out = self.weight.value
1584
1585 gemm_plugin = default_net().plugin_config.gemm_plugin
1586 low_latency_gemm_plugin = default_net(
1587 ).plugin_config.low_latency_gemm_plugin
1588 gemm_allreduce_plugin = default_net(
1589 ).plugin_config.gemm_allreduce_plugin
1590 if gemm_allreduce_plugin:
1591 ret = self.multiply_collect(quantized_out,
1592 w_quant_out,
1593 gemm_plugin=None,
1594 use_fp8=True,
1595 alpha=alpha,
1596 lora_runtime_params=lora_runtime_params,
1597 lora_hidden_state=lora_hidden_state,
1598 all_reduce_params=all_reduce_params)
1599 elif (low_latency_gemm_plugin == "fp8"):
1600 ret = self.multiply_collect(
1601 quantized_out,
1602 w_quant_out,
1603 gemm_plugin=None,
1604 low_latency_gemm_plugin=low_latency_gemm_plugin,
1605 use_fp8=True,
1606 alpha=alpha,
1607 lora_runtime_params=lora_runtime_params,
1608 lora_hidden_state=lora_hidden_state,
1609 all_reduce_params=all_reduce_params)
1610 elif gemm_plugin == 'fp8':
1611 ret = self.multiply_collect(quantized_out,
1612 w_quant_out,
1613 gemm_plugin=gemm_plugin,
1614 use_fp8=True,
1615 alpha=alpha,
1616 lora_runtime_params=lora_runtime_params,
1617 lora_hidden_state=lora_hidden_state,

Callers

nothing calls this directly

Calls 5

default_net · Function · 0.85
cast · Function · 0.85
dequantize · Function · 0.85
quantize · Function · 0.70
multiply_collect · Method · 0.45

Tested by

no test coverage detected