| 599 | self.layers.append(nn.LayerList([norm_fn(), layer, residual_fn])) |
| 600 | |
| 601 | def forward( |
| 602 | self, |
| 603 | x, |
| 604 | context=None, |
| 605 | mask=None, |
| 606 | context_mask=None, |
| 607 | mems=None, |
| 608 | seq_len=0, |
| 609 | return_hiddens=False, |
| 610 | ): |
| 611 | assert not ( |
| 612 | self.cross_attend ^ exists(context) |
| 613 | ), "context must be passed in if cross_attend is set to True" |
| 614 | |
| 615 | hiddens = [] |
| 616 | intermediates = [] |
| 617 | prev_attn = None |
| 618 | prev_cross_attn = None |
| 619 | rotary_pos_emb = None |
| 620 | |
| 621 | mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers |
| 622 | |
| 623 | for ind, (layer_type, (norm, block, residual_fn)) in enumerate( |
| 624 | zip(self.layer_types, self.layers) |
| 625 | ): |
| 626 | is_last = ind == (len(self.layers) - 1) |
| 627 | |
| 628 | if layer_type == "a": |
| 629 | hiddens.append(x) |
| 630 | layer_mem = mems.pop(0) |
| 631 | |
| 632 | residual = x |
| 633 | |
| 634 | if self.pre_norm: |
| 635 | x = norm(x) |
| 636 | |
| 637 | if layer_type == "a": |
| 638 | out, inter = block( |
| 639 | x, |
| 640 | mask=mask, |
| 641 | sinusoidal_emb=self.pia_pos_emb, |
| 642 | rel_pos=self.rel_pos, |
| 643 | rotary_pos_emb=rotary_pos_emb, |
| 644 | prev_attn=prev_attn, |
| 645 | mem=layer_mem, |
| 646 | ) |
| 647 | elif layer_type == "c": |
| 648 | out, inter = block( |
| 649 | x, |
| 650 | context=context, |
| 651 | mask=mask, |
| 652 | context_mask=context_mask, |
| 653 | prev_attn=prev_cross_attn, |
| 654 | ) |
| 655 | elif layer_type == "f": |
| 656 | out = block(x) |
| 657 | |
| 658 | x = residual_fn(out, residual) |