| 454 | self.out_proj = Linear(hidden_size, num_labels) |
| 455 | |
def forward(self, hidden_states, input_lengths, remove_input_padding):
    """Pool the first token of each sequence and project it to label logits.

    Args:
        hidden_states: Encoder output.
            - Padded layout: the token axis is dim 1, and the first token is
              taken with ``select(hidden_states, 1, 0)``.
            - ``remove_input_padding`` layout: sequences are packed as
              ``[num_tokens, hidden_size]``.
        input_lengths: Number of tokens in each packed sequence. Only read
            when ``remove_input_padding`` is true.
        remove_input_padding: True when ``hidden_states`` is packed (no
            padding) rather than padded per sequence.

    Returns:
        ``self.out_proj(tanh(self.dense(pooled)))``, where ``pooled`` is the
        first-token hidden state of every sequence.
    """
    if remove_input_padding:
        # Packed layout: each sequence starts at the exclusive prefix sum of
        # the lengths that precede it. For example:
        #   lengths           [8, 5, 6]
        #   inclusive cumsum  [8, 13, 19]
        #   minus lengths     [0, 8, 13]  <- index of each first token
        starts = cumsum(input_lengths, 0) - input_lengths
        pooled = index_select(hidden_states, 0, starts)
    else:
        # Padded layout: "pool" by taking the hidden state at token 0.
        pooled = select(hidden_states, 1, 0)

    # dense -> tanh -> out_proj, applied in that order.
    logits = self.out_proj(ACT2FN['tanh'](self.dense(pooled)))
    return logits
| 483 | |
| 484 | |
| 485 | class BertForSequenceClassification(BertBase): |