| 410 | |
| 411 | |
class BertAttention(nn.Module):
    """Attention block composed of two sub-modules applied in sequence.

    ``self.self`` is a ``BertSelfAttention`` and ``self.output`` is a
    ``BertSelfOutput``; both are built from the same ``config`` object.
    """

    def __init__(self, config):
        """Build the two sub-modules from ``config``.

        Args:
            config: Configuration object forwarded unchanged to
                ``BertSelfAttention`` and ``BertSelfOutput``.
        """
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        """Run self-attention, then the output sub-module.

        Args:
            input_tensor: Input handed to the self-attention sub-module and
                also passed, unmodified, as the second argument of the
                output sub-module.
            attention_mask: Mask forwarded to the self-attention sub-module.

        Returns:
            Whatever ``self.output`` returns for the pair
            ``(self-attention result, input_tensor)``.
        """
        # NOTE(review): the output sub-module receives the original input
        # alongside the attention result -- presumably for a residual
        # connection; confirm against BertSelfOutput.
        return self.output(
            self.self(input_tensor, attention_mask),
            input_tensor,
        )
| 422 | |
| 423 | |
| 424 | class BertIntermediate(nn.Module): |