(x, heads)
| 17 | |
| 18 | |
def reshape_tensor(x, heads):
    """Split the last axis of ``x`` into attention heads.

    Args:
        x: Tensor with exactly three axes, unpacked here as
            ``(bs, length, width)``. ``width`` must be divisible by
            ``heads``; otherwise ``view`` raises.
        heads: Number of heads to split ``width`` into.

    Returns:
        A 4-D tensor of shape ``(bs, heads, length, width // heads)``.
        The head axis is NOT folded into the batch axis. The result is a
        strided view that shares storage with ``x``.
    """
    # Unpacking three names also enforces that x has exactly three axes.
    bs, length, _ = x.shape
    # (bs, length, width) --> (bs, length, heads, dim_per_head)
    per_head = x.view(bs, length, heads, -1)
    # (bs, length, heads, dim_per_head) --> (bs, heads, length, dim_per_head)
    # The transpose already yields the final layout, so no further reshape
    # is needed.
    return per_head.transpose(1, 2)
| 28 | |
| 29 | |
| 30 | class PerceiverAttention(nn.Module): |