(self, tensor: torch.Tensor, seq_len: int, bsz: int)
| 165 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) |
| 166 | |
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    """Split the last dimension of ``tensor`` into attention heads.

    The input is reinterpreted as ``(bsz, seq_len, num_heads, head_dim)``.
    The sequence and head axes are then swapped, and the result is returned
    as a contiguous tensor of shape ``(bsz, num_heads, seq_len, head_dim)``.

    Args:
        tensor: Projected states. Presumably ``(bsz, seq_len, embed_dim)``
            with ``embed_dim == num_heads * head_dim``; TODO confirm against
            callers.
        seq_len: Sequence length used in the reshape. ``Tensor.view`` also
            accepts ``-1`` here, in which case the length is inferred.
        bsz: Batch size used in the reshape.

    Returns:
        A contiguous tensor of shape ``(bsz, num_heads, seq_len, head_dim)``.

    Raises:
        RuntimeError: Raised by ``view`` when the element count does not
            equal ``bsz * seq_len * num_heads * head_dim`` or the input
            strides cannot be viewed in that shape.
    """
    per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
    # Axis order (0, 2, 1, 3) moves the head axis ahead of the sequence axis.
    # On a 4-D tensor this is the same permutation as transpose(1, 2).
    heads_major = per_head.permute(0, 2, 1, 3)
    # Return a densely laid-out tensor instead of a strided view.
    return heads_major.contiguous()
| 173 | |
| 174 | def forward( |
| 175 | self, |