(
self,
num_embeddings: int,
embedding_dim: int,
dtype: Optional[str] = None,
tp_size: int = 1,
tp_group: Optional[list] = None,
sharding_dim: int = 0,
tp_rank: Optional[int] = None,
quant_mode=QuantMode.use_weight_only(),
)
| 897 | class WeightOnlyQuantEmbedding(Embedding): |
| 898 | |
| 899 | def __init__( |
| 900 | self, |
| 901 | num_embeddings: int, |
| 902 | embedding_dim: int, |
| 903 | dtype: Optional[str] = None, |
| 904 | tp_size: int = 1, |
| 905 | tp_group: Optional[list] = None, |
| 906 | sharding_dim: int = 0, |
| 907 | tp_rank: Optional[int] = None, |
| 908 | quant_mode=QuantMode.use_weight_only(), |
| 909 | ): |
| 910 | super().__init__( |
| 911 | num_embeddings, |
| 912 | embedding_dim, |
| 913 | dtype, # dtype, |
| 914 | tp_size, |
| 915 | tp_group, |
| 916 | sharding_dim, |
| 917 | tp_rank) |
| 918 | # only support int8 wo now |
| 919 | # TODO support int4 wo |
| 920 | self.quant_mode = quant_mode |
| 921 | self.per_token_scale = Parameter(shape=(self.num_embeddings, ), |
| 922 | dtype=dtype) |
| 923 | |
| 924 | if sharding_dim == 1: |
| 925 | self.weight = Parameter(shape=(self.num_embeddings, |
| 926 | self.embedding_dim // self.tp_size), |
| 927 | dtype="int8") |
| 928 | elif sharding_dim == 0: |
| 929 | self.weight = Parameter(shape=(math.ceil( |
| 930 | self.num_embeddings / self.tp_size), self.embedding_dim), |
| 931 | dtype="int8") |
| 932 | |
| 933 | def forward(self, x): |
| 934 | result = embedding(x, |
Nothing calls this directly.
No test coverage was detected.