MCPcopy Index your code
hub / github.com/NVIDIA/TensorRT-LLM / Fp8RowwiseFusedGatedMLP

Class Fp8RowwiseFusedGatedMLP

tensorrt_llm/quantization/layers.py:1760–1829  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

1758
1759
1760class Fp8RowwiseFusedGatedMLP(Module):
1761
def __init__(
    self,
    hidden_size,
    ffn_hidden_size,
    hidden_act,
    bias=True,
    dtype=None,
    tp_group=None,
    tp_size=1,
    quant_mode=QuantMode(0),
    clamp_val=None,
):
    """Create the fused input projection and the output projection.

    Args:
        hidden_size: First size argument of ``fused_fc`` and second size
            argument of ``proj``.
        ffn_hidden_size: Intermediate width; ``fused_fc`` is built with
            twice this value and ``proj`` with this value as-is.
        hidden_act: Activation name, stored on the instance unchanged.
        bias: Forwarded to both linear layers.
        dtype: Forwarded to both linear layers.
        tp_group: Forwarded to both linear layers.
        tp_size: Forwarded to both linear layers.
        quant_mode: Forwarded to both linear layers.
        clamp_val: Optional two-element list. When truthy it is stored as
            a float32 buffer parameter; when falsy the ``clamp_val``
            parameter is registered as ``None``.

    Raises:
        ValueError: If ``clamp_val`` is truthy but is not a list of
            exactly two elements.
    """
    super().__init__()

    # Keep the raw configuration on the instance.
    self.hidden_size = hidden_size
    self.ffn_hidden_size = ffn_hidden_size
    self.hidden_act = hidden_act
    self.bias = bias
    self.dtype = dtype
    self.tp_group = tp_group
    self.tp_size = tp_size
    self.quant_mode = quant_mode

    # Keyword arguments that both linear layers receive unchanged.
    linear_kwargs = dict(
        bias=bias,
        dtype=dtype,
        tp_group=tp_group,
        tp_size=tp_size,
        quant_mode=quant_mode,
    )

    # NOTE(review): the doubled width presumably holds the two halves that
    # the gated activation consumes - confirm against ``forward``.
    self.fused_fc = Fp8RowwiseColumnLinear(
        hidden_size,
        ffn_hidden_size * 2,
        gather_output=False,
        **linear_kwargs,
    )
    self.proj = Fp8RowwiseRowLinear(
        ffn_hidden_size,
        hidden_size,
        **linear_kwargs,
    )

    # No clamp requested: register an empty slot and finish.
    if not clamp_val:
        self.register_parameter('clamp_val', None)
        return

    # Only a list of exactly two elements is accepted.
    if not isinstance(clamp_val, list) or len(clamp_val) != 2:
        raise ValueError(f'unsupported clamp_val {clamp_val}')

    self.clamp_val = Parameter(
        np.array(clamp_val, dtype=np.float32),
        dtype='float32',
        is_buffer=True,
    )
1809
1810 def forward(self, hidden_states, lora_layer_params=None):
1811 assert lora_layer_params is None, f"lora is not supported on {self.__class__.__name__} now"
1812 inter = self.fused_fc(hidden_states)
1813
1814 if self.hidden_act == 'silu':
1815 inter = ACT2FN['swiglu'](inter)
1816 elif self.hidden_act == 'gelu':
1817 inter = ACT2FN['geglu'](inter)

Callers 1

fuse_gate_mlp (Function) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected