(self, x, lora_runtime_params=None)
| 759 | self.quant_mode = quant_mode |
| 760 | |
def forward(self, x, lora_runtime_params=None):
    """Run the weight-only quantized matmul, then optional LoRA, bias and all-gather.

    Args:
        x: Input handed to ``weight_only_quant_matmul`` and, unchanged,
            to ``self.lora`` when the LoRA branch is taken.
        lora_runtime_params: Forwarded to ``self.lora``. The LoRA term is
            added only when this is not ``None`` and the network's
            ``plugin_config.lora_plugin`` is truthy.

    Returns:
        The matmul result with the LoRA output and bias added where
        applicable, all-gathered over ``self.tp_group`` along dim 1 when
        ``gather_output`` is set, ``tp_size > 1`` and ``tp_group`` is
        not ``None``.

    Raises:
        TypeError: If ``weight_only_quant_mode`` is 2 (int4, per the error
            message) and the weight-only quant matmul plugin is disabled.
    """
    # The out-of-the-box (non-plugin) path does not support int4 yet.
    if self.weight_only_quant_mode == 2:
        if not default_net().plugin_config.weight_only_quant_matmul_plugin:
            raise TypeError(
                "Int4 Weight-only Quant MatMul is only supported with plugin")

    # Keep `x` untouched so the LoRA branch below sees the original input.
    out = weight_only_quant_matmul(
        x,
        self.weight.value,
        self.per_channel_scale.value,
        self.weight_only_quant_mode,
        self.dtype,
        self.transa,
        self.transb,
    )

    lora_enabled = default_net().plugin_config.lora_plugin
    if lora_enabled and lora_runtime_params is not None:
        out = out + self.lora(x, lora_runtime_params=lora_runtime_params)

    if self.bias is not None:
        out = out + self.bias.value

    should_gather = (self.gather_output and self.tp_size > 1
                     and self.tp_group is not None)
    if should_gather:
        # [dim0, local_dim] -> [dim0 * tp_size, local_dim] --> [dim0, local_dim * tp_size]
        out = allgather(out, self.tp_group, gather_dim=1)

    return out
| 786 | |
| 787 | def postprocess(self, tllm_key, weights, **kwargs): |
| 788 | if "per_channel_scale" in tllm_key: |
nothing calls this directly
no test coverage detected