Construct the embeddings from word, position and token_type embeddings.
| 308 | return self.weight * x + self.bias |
| 309 | |
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.

    The three embedding lookups are summed element-wise, then passed through
    layer normalization and dropout.
    """

    def __init__(self, config):
        """Create the embedding tables, layer norm and dropout.

        Args:
            config: object providing the attributes ``vocab_size``,
                ``max_position_embeddings``, ``type_vocab_size``,
                ``hidden_size`` and ``hidden_dropout_prob``.
        """
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        """Embed ``input_ids`` and return the normalized, dropped-out sum.

        Args:
            input_ids: integer tensor of token indices. Dimension 1 is taken
                as the sequence length when ``position_ids`` is not given
                (expected layout is (batch, seq_length) -- TODO confirm
                against callers).
            token_type_ids: optional integer tensor of segment indices.
                Defaults to all zeros with the shape of ``input_ids``.
            position_ids: optional integer tensor of position indices. When
                omitted, positions ``0 .. seq_length - 1`` are used for every
                row, which is the original behavior. When given, its
                embedded form must be addable to the word embeddings (for
                example, the same shape as ``input_ids``).

        Returns:
            The sum of the word, position and token-type embeddings after
            ``self.LayerNorm`` and ``self.dropout`` have been applied.
        """
        if position_ids is None:
            # Default: consecutive positions starting at 0, shared by all rows.
            seq_length = input_ids.size(1)
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
| 339 | |
| 340 | |
| 341 | class BertSelfAttention(nn.Module): |