(self, x, lora_runtime_params=None)
| 1423 | } |
| 1424 | |
| 1425 | def forward(self, x, lora_runtime_params=None): |
| 1426 | assert lora_runtime_params is None or default_net( |
| 1427 | ).plugin_config.lora_plugin == self.dtype |
| 1428 | |
| 1429 | if default_net().strongly_typed: |
| 1430 | assert default_net().plugin_config.user_buffer or is_same_dtype( |
| 1431 | x.dtype, |
| 1432 | self.dtype), f"Got input type {x.dtype}, expecting {self.dtype}" |
| 1433 | |
| 1434 | alpha = self.weights_scaling_factor.raw_value * self.activation_scaling_factor.raw_value |
| 1435 | activation_scaling_factor = constant( |
| 1436 | self.activation_scaling_factor.raw_value) |
| 1437 | activation_scaling_factor = cast(activation_scaling_factor, self.dtype) |
| 1438 | if x.dtype != trt.fp8: |
| 1439 | quantized_out = quantize(x, activation_scaling_factor, 'fp8') |
| 1440 | lora_hidden_state = x if lora_runtime_params is not None else None |
| 1441 | else: |
| 1442 | quantized_out = x |
| 1443 | # TODO: add fp8 LoRA support |
| 1444 | lora_hidden_state = dequantize( |
| 1445 | x, activation_scaling_factor, -1, |
| 1446 | self.dtype) if lora_runtime_params is not None else None |
| 1447 | |
| 1448 | weights_scaling_factor = cast(self.weights_scaling_factor.value, |
| 1449 | self.dtype) |
| 1450 | |
| 1451 | if self.weight.value.dtype != trt.fp8: |
| 1452 | w_quant_out = quantize(self.weight.value, weights_scaling_factor, |
| 1453 | 'fp8') |
| 1454 | else: |
| 1455 | w_quant_out = self.weight.value |
| 1456 | |
| 1457 | gemm_plugin = default_net().plugin_config.gemm_plugin |
| 1458 | |
| 1459 | low_latency_gemm_plugin = default_net( |
| 1460 | ).plugin_config.low_latency_gemm_plugin |
| 1461 | if (low_latency_gemm_plugin == "fp8"): |
| 1462 | return self.multiply_collect( |
| 1463 | quantized_out, |
| 1464 | w_quant_out, |
| 1465 | gemm_plugin=None, |
| 1466 | low_latency_gemm_plugin=low_latency_gemm_plugin, |
| 1467 | use_fp8=True, |
| 1468 | alpha=alpha, |
| 1469 | lora_runtime_params=lora_runtime_params, |
| 1470 | lora_hidden_state=lora_hidden_state) |
| 1471 | elif gemm_plugin == 'fp8': |
| 1472 | return self.multiply_collect( |
| 1473 | quantized_out, |
| 1474 | w_quant_out, |
| 1475 | gemm_plugin=gemm_plugin, |
| 1476 | use_fp8=True, |
| 1477 | alpha=alpha, |
| 1478 | lora_runtime_params=lora_runtime_params, |
| 1479 | lora_hidden_state=lora_hidden_state) |
| 1480 | else: |
| 1481 | dequantized_out = dequantize(quantized_out, |
| 1482 | activation_scaling_factor, -1, |
nothing calls this directly
no test coverage detected