| 587 | |
| 588 | |
class BertPreTrainingHeads(nn.Module):
    """Bundle the two BERT pre-training output heads into one module.

    Holds a ``BertLMPredictionHead`` applied to the per-token sequence output,
    and a linear layer mapping the pooled output to two scores.
    """

    def __init__(self, config, bert_model_embedding_weights):
        super(BertPreTrainingHeads, self).__init__()
        # Token-level head. It is handed the model's embedding weights;
        # presumably they are tied to its output projection (confirm in
        # BertLMPredictionHead).
        self.predictions = BertLMPredictionHead(
            config, bert_model_embedding_weights)
        # Two-way classifier over the pooled representation (presumably the
        # next-sentence-prediction task: is-next / not-next).
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        """Run both heads.

        Args:
            sequence_output: Input to the LM prediction head.
            pooled_output: Input to the 2-output linear layer
                (last dimension expected to be ``config.hidden_size``).

        Returns:
            Tuple ``(prediction_scores, seq_relationship_score)``. The first
            element comes from the LM head and the second from the linear
            layer.
        """
        return (self.predictions(sequence_output),
                self.seq_relationship(pooled_output))
| 599 | |
| 600 | |
| 601 | class BertPreTrainedModel(nn.Module): |