(self, x, lora_runtime_params=None, all_reduce_params=None)
| 1558 | } |
| 1559 | |
| 1560 | def forward(self, x, lora_runtime_params=None, all_reduce_params=None): |
| 1561 | assert lora_runtime_params is None or default_net( |
| 1562 | ).plugin_config.lora_plugin == self.dtype |
| 1563 | |
| 1564 | alpha = self.weights_scaling_factor.raw_value * self.activation_scaling_factor.raw_value |
| 1565 | activation_scaling_factor = cast(self.activation_scaling_factor.value, |
| 1566 | self.dtype) |
| 1567 | if x.dtype != trt.fp8: |
| 1568 | quantized_out = quantize(x, activation_scaling_factor, 'fp8') |
| 1569 | lora_hidden_state = x if lora_runtime_params is not None else None |
| 1570 | else: |
| 1571 | quantized_out = x |
| 1572 | # TODO: add fp8 LoRA support |
| 1573 | lora_hidden_state = dequantize( |
| 1574 | x, activation_scaling_factor, -1, |
| 1575 | self.dtype) if lora_runtime_params is not None else None |
| 1576 | |
| 1577 | weights_scaling_factor = cast(self.weights_scaling_factor.value, |
| 1578 | self.dtype) |
| 1579 | if self.weight.value.dtype != trt.fp8: |
| 1580 | w_quant_out = quantize(self.weight.value, weights_scaling_factor, |
| 1581 | 'fp8') |
| 1582 | else: |
| 1583 | w_quant_out = self.weight.value |
| 1584 | |
| 1585 | gemm_plugin = default_net().plugin_config.gemm_plugin |
| 1586 | low_latency_gemm_plugin = default_net( |
| 1587 | ).plugin_config.low_latency_gemm_plugin |
| 1588 | gemm_allreduce_plugin = default_net( |
| 1589 | ).plugin_config.gemm_allreduce_plugin |
| 1590 | if gemm_allreduce_plugin: |
| 1591 | ret = self.multiply_collect(quantized_out, |
| 1592 | w_quant_out, |
| 1593 | gemm_plugin=None, |
| 1594 | use_fp8=True, |
| 1595 | alpha=alpha, |
| 1596 | lora_runtime_params=lora_runtime_params, |
| 1597 | lora_hidden_state=lora_hidden_state, |
| 1598 | all_reduce_params=all_reduce_params) |
| 1599 | elif (low_latency_gemm_plugin == "fp8"): |
| 1600 | ret = self.multiply_collect( |
| 1601 | quantized_out, |
| 1602 | w_quant_out, |
| 1603 | gemm_plugin=None, |
| 1604 | low_latency_gemm_plugin=low_latency_gemm_plugin, |
| 1605 | use_fp8=True, |
| 1606 | alpha=alpha, |
| 1607 | lora_runtime_params=lora_runtime_params, |
| 1608 | lora_hidden_state=lora_hidden_state, |
| 1609 | all_reduce_params=all_reduce_params) |
| 1610 | elif gemm_plugin == 'fp8': |
| 1611 | ret = self.multiply_collect(quantized_out, |
| 1612 | w_quant_out, |
| 1613 | gemm_plugin=gemm_plugin, |
| 1614 | use_fp8=True, |
| 1615 | alpha=alpha, |
| 1616 | lora_runtime_params=lora_runtime_params, |
| 1617 | lora_hidden_state=lora_hidden_state, |
nothing calls this directly
no test coverage detected