| 521 | |
| 522 | |
class BertPooler(nn.Module):
    """Reduce per-token hidden states to one pooled vector per example.

    The hidden state at index 0 of the second axis (the first token) is fed
    through ``dense_act``, a ``LinearActivation`` constructed with
    ``config.hidden_size`` for both size arguments and ``act="tanh"``.
    """

    def __init__(self, config):
        """Build the activation layer from ``config.hidden_size``."""
        super().__init__()
        hidden_size = config.hidden_size
        self.dense_act = LinearActivation(hidden_size, hidden_size, act="tanh")

    def forward(self, hidden_states):
        """Return ``dense_act`` applied to ``hidden_states[:, 0]``.

        NOTE(review): presumably ``hidden_states`` is laid out as
        (batch, sequence, hidden) -- confirm against the caller.
        """
        # "Pooling" is nothing more than selecting the first token's hidden
        # state and passing it through the activation layer.
        return self.dense_act(hidden_states[:, 0])
| 534 | |
| 535 | |
| 536 | class BertPredictionHeadTransform(nn.Module): |