| 802 | self.apply(self.init_bert_weights) |
| 803 | |
def forward(self, input_ids, token_type_ids=None, attention_mask=None,
            output_all_encoded_layers=True, checkpoint_activations=False):
    """Run the embeddings, the encoder stack and the pooler over ``input_ids``.

    Args:
        input_ids: token-id tensor. The default masks below are built from it
            with ``ones_like``/``zeros_like``, so they share its shape and dtype.
        token_type_ids: segment ids; all zeros when omitted.
        attention_mask: 1 for positions to attend to, 0 for masked positions;
            all ones when omitted. Presumably 2-D ``[batch_size, seq_length]``
            (inferred from the broadcast shape built below; confirm with callers).
        output_all_encoded_layers: when falsy, only the last element of the
            encoder output is returned instead of the whole encoder output.
        checkpoint_activations: forwarded unchanged to ``self.encoder``.

    Returns:
        ``(encoded_layers, pooled_output)``, where ``pooled_output`` is
        ``self.pooler`` applied to the last encoder output.
    """
    if token_type_ids is None:
        token_type_ids = torch.zeros_like(input_ids)
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    # [batch, to_seq] -> [batch, 1, 1, to_seq]: a 4-D mask that broadcasts over
    # heads and query positions ([batch, heads, from_seq, to_seq]). Unlike the
    # triangular causal mask of OpenAI GPT, only the broadcast axes are needed.
    broadcast_mask = attention_mask.unsqueeze(1).unsqueeze(2)

    # Cast to the dtype of the model's parameters (fp16 compatibility).
    param_dtype = next(self.parameters()).dtype
    broadcast_mask = broadcast_mask.to(dtype=param_dtype)

    # Turn 1/0 keep-flags into an additive bias: 0.0 where attended, -10000.0
    # where masked. Added to the raw scores before the softmax, this
    # effectively removes the masked positions.
    additive_mask = (1.0 - broadcast_mask) * -10000.0

    hidden = self.embeddings(input_ids, token_type_ids)
    layer_outputs = self.encoder(
        hidden,
        additive_mask,
        output_all_encoded_layers=output_all_encoded_layers,
        checkpoint_activations=checkpoint_activations,
    )
    final_layer = layer_outputs[-1]
    pooled_output = self.pooler(final_layer)

    returned_layers = layer_outputs if output_all_encoded_layers else final_layer
    return returned_layers, pooled_output
| 834 | |
| 835 | |
| 836 | class BertForPreTraining(BertPreTrainedModel): |