| 474 | # all_encoder_layers.append(hidden_states) |
| 475 | # return all_encoder_layers |
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, checkpoint_activations=False):
    """Feed ``hidden_states`` through every module in ``self.layer``, in order.

    Args:
        hidden_states: Input to the first layer; each layer's output becomes
            the next layer's input.
        attention_mask: Passed as the second argument to every layer.
        output_all_encoded_layers: When true and checkpointing is off, the
            output of every layer is collected. Otherwise only the final
            output is returned.
        checkpoint_activations: When true, layers run in chunks through
            ``checkpoint.checkpoint`` and only the final output is returned,
            regardless of ``output_all_encoded_layers``.

    Returns:
        A list of layer outputs. It holds one entry per layer when all
        outputs are requested without checkpointing (so it is empty if there
        are no layers); in every other case it holds exactly one entry, the
        final value of ``hidden_states``.
    """

    def run_span(first, stop):
        # Build a callable that applies layers [first, stop) back to back.
        # The slice is taken at call time, inside the returned function.
        def run(states, mask):
            for block in self.layer[first:stop]:
                states = block(states, mask)
            return states

        return run

    if checkpoint_activations:
        total = len(self.layer)
        # Chunk size is ceil(sqrt(number of layers)).
        span = math.ceil(math.sqrt(total))
        begin = 0
        # A while loop (not range) so that zero layers -> zero-sized span
        # simply skips the body instead of raising on a zero step.
        while begin < total:
            hidden_states = checkpoint.checkpoint(
                run_span(begin, begin + span),
                hidden_states,
                attention_mask * 1,
            )
            begin += span
        # Intermediate outputs are never collected on this path.
        return [hidden_states]

    outputs = []
    for layer_module in self.layer:
        hidden_states = layer_module(hidden_states, attention_mask)
        if output_all_encoded_layers:
            outputs.append(hidden_states)

    if not output_all_encoded_layers:
        outputs.append(hidden_states)
    return outputs
| 505 | |
| 506 | #class BertEncoder(nn.Module): |
| 507 | # def __init__(self, config): |