| 489 | self.get_modules(self, mdl, input) |
| 490 | |
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, checkpoint_activations=False):
    """Feed ``hidden_states`` through every module in ``self.layer``, in order.

    Args:
        hidden_states: Input to the first layer; each layer's output becomes
            the next layer's input.
        attention_mask: Passed along to every layer.
        output_all_encoded_layers: When true and checkpointing is off, the
            output of every layer is collected. Otherwise only the final
            output is returned.
        checkpoint_activations: When true, the layers are executed in chunks
            through ``checkpoint.checkpoint`` and only the final output is
            returned, regardless of ``output_all_encoded_layers``.

    Returns:
        list: Either one entry per layer, or a single-element list holding
        the final hidden state.
    """
    all_encoder_layers = []

    def run_chunk(start, end):
        """Build a callable that applies ``self.layer[start:end]`` in sequence."""

        def chunk_forward(*inputs):
            # inputs[0] is the hidden state, inputs[1] the attention mask.
            out = inputs[0]
            for layer in self.layer[start:end]:
                # NOTE(review): layers are called with two arguments here, but
                # the non-checkpointed path below also passes self.grads and
                # collect_all_grads=True -- confirm the layer signature
                # accepts both call shapes.
                out = layer(out, inputs[1])
            return out

        return chunk_forward

    def record_hidden_grad(grad):
        # Hook body: store whatever the hook receives (presumably the
        # gradient of the hidden state -- TODO confirm) tagged with its kind.
        # self.grads is looked up when the hook fires, not when it is
        # registered.
        self.grads.append([grad, "hidden_state"])

    if checkpoint_activations:
        num_layers = len(self.layer)
        # Chunks of ceil(sqrt(num_layers)) layers per checkpoint call.
        chunk_length = math.ceil(math.sqrt(num_layers))
        start = 0
        # A while loop (not range) so that an empty layer list, which gives
        # chunk_length == 0, simply performs no iterations.
        while start < num_layers:
            # `attention_mask * 1` is kept from the original; presumably it
            # hands checkpoint a fresh tensor -- TODO confirm.
            hidden_states = checkpoint.checkpoint(
                run_chunk(start, start + chunk_length), hidden_states, attention_mask * 1
            )
            start += chunk_length
    else:
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask, self.grads, collect_all_grads=True)
            hidden_states.register_hook(record_hidden_grad)

            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)

    # Intermediate outputs were not collected in these cases, so return just
    # the final hidden state.
    if not output_all_encoded_layers or checkpoint_activations:
        all_encoder_layers.append(hidden_states)
    return all_encoder_layers