(self, seq_max_len, embedding_dim, padding_idx)
| 214 | ''' Reference: attention is all you need ''' |
| 215 | |
def __init__(self, seq_max_len, embedding_dim, padding_idx):
    """Create a frozen lookup table of sinusoidal position vectors.

    Args:
        seq_max_len: Maximum sequence length. The table is built with
            ``seq_max_len + 1`` rows, i.e. one more than this value.
        embedding_dim: Width of each position vector.
        padding_idx: Passed through unchanged to
            ``get_sinusoid_encoding_table``; how it is used there is not
            visible here (presumably that row is zeroed; confirm).
    """
    super(PositionEmbedding, self).__init__()

    # NOTE(review): the extra row likely lets positions be 1-based with
    # a reserved padding slot; verify against get_sinusoid_encoding_table.
    row_count = seq_max_len + 1
    sinusoid_table = self.get_sinusoid_encoding_table(
        row_count,
        embedding_dim,
        padding_idx=padding_idx,
    )

    # freeze=True keeps the precomputed values fixed during training.
    self.position_enc = nn.Embedding.from_pretrained(sinusoid_table, freeze=True)
| 224 | |
def forward(self, src_pos):
    """Return the position vectors selected by the indices in ``src_pos``.

    Args:
        src_pos: Indices fed directly to ``self.position_enc``.

    Returns:
        Whatever ``self.position_enc`` returns for ``src_pos``.
    """
    lookup = self.position_enc
    embedded = lookup(src_pos)
    return embedded
nothing calls this directly
no test coverage detected