| 1758 | |
| 1759 | |
| 1760 | class Fp8RowwiseFusedGatedMLP(Module): |
| 1761 | |
def __init__(
    self,
    hidden_size,
    ffn_hidden_size,
    hidden_act,
    bias=True,
    dtype=None,
    tp_group=None,
    tp_size=1,
    quant_mode=QuantMode(0),
    clamp_val=None,
):
    """Build the FP8-rowwise fused gated MLP sub-layers.

    Args:
        hidden_size: Input/output width of the MLP.
        ffn_hidden_size: Intermediate width. The fused input projection
            produces ``ffn_hidden_size * 2`` features (both halves of the
            gated activation), and ``proj`` maps ``ffn_hidden_size`` back
            to ``hidden_size``.
        hidden_act: Activation name. ``forward`` dispatches ``'silu'`` to
            swiglu and ``'gelu'`` to geglu.
        bias: Whether both linear layers carry a bias.
        dtype: Forwarded unchanged to both linear layers.
        tp_group: Tensor-parallel group, forwarded to both linear layers.
        tp_size: Tensor-parallel world size, forwarded to both linear layers.
        quant_mode: ``QuantMode`` forwarded to both FP8-rowwise linears.
        clamp_val: Optional two-element list or tuple, stored as a float32
            buffer ``self.clamp_val``. Any falsy value (``None``, empty
            list/tuple) registers ``clamp_val`` as ``None`` instead. How
            the stored bounds are consumed is outside this constructor;
            presumably ``(min, max)`` clamp bounds, so confirm with callers.

    Raises:
        ValueError: If ``clamp_val`` is truthy but not a two-element list
            or tuple.
    """
    super().__init__()
    self.hidden_size = hidden_size
    self.ffn_hidden_size = ffn_hidden_size
    self.hidden_act = hidden_act
    self.bias = bias
    self.dtype = dtype
    self.tp_group = tp_group
    self.tp_size = tp_size
    self.quant_mode = quant_mode

    # One column-parallel GEMM computes both gated-activation inputs at
    # once (hence ``* 2``). gather_output=False keeps the output un-gathered
    # so it can feed the row-parallel ``proj`` directly (presumably
    # rank-local shards; verify against the linear-layer implementation).
    self.fused_fc = Fp8RowwiseColumnLinear(hidden_size,
                                           ffn_hidden_size * 2,
                                           bias=bias,
                                           dtype=dtype,
                                           tp_group=tp_group,
                                           tp_size=tp_size,
                                           gather_output=False,
                                           quant_mode=quant_mode)

    self.proj = Fp8RowwiseRowLinear(ffn_hidden_size,
                                    hidden_size,
                                    bias=bias,
                                    dtype=dtype,
                                    tp_group=tp_group,
                                    tp_size=tp_size,
                                    quant_mode=quant_mode)

    # Truthiness (not ``is not None``) is kept deliberately so that an empty
    # list/tuple still means "no clamp", as before.
    if clamp_val:
        # Accept tuples as well as lists: np.array converts both identically,
        # and ``(min, max)`` is the natural Python spelling of a pair.
        if not (isinstance(clamp_val, (list, tuple))
                and len(clamp_val) == 2):
            raise ValueError(f'unsupported clamp_val {clamp_val}')
        # Registered as a buffer (not a trainable weight) in float32,
        # regardless of the layer's ``dtype``.
        self.clamp_val = Parameter(np.array(clamp_val, dtype=np.float32),
                                   dtype='float32',
                                   is_buffer=True)
    else:
        # Register the attribute as None so later code can test for it
        # without hasattr checks.
        self.register_parameter('clamp_val', None)
| 1809 | |
| 1810 | def forward(self, hidden_states, lora_layer_params=None): |
| 1811 | assert lora_layer_params is None, f"lora is not supported on {self.__class__.__name__} now" |
| 1812 | inter = self.fused_fc(hidden_states) |
| 1813 | |
| 1814 | if self.hidden_act == 'silu': |
| 1815 | inter = ACT2FN['swiglu'](inter) |
| 1816 | elif self.hidden_act == 'gelu': |
| 1817 | inter = ACT2FN['geglu'](inter) |