A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
| 326 | |
| 327 | |
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.

    Legacy layout: the channel axis of the input is split per head first, and
    each head's channels are then split into consecutive Q, K and V chunks.
    """

    def __init__(self, n_heads):
        """
        :param n_heads: the number of attention heads the channel axis is split into.
        """
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        :raises ValueError: if the channel dimension is not divisible by
            ``3 * n_heads``.
        """
        bs, width, length = qkv.shape
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if width % (3 * self.n_heads) != 0:
            raise ValueError(
                f"qkv width {width} is not divisible by 3 * n_heads "
                f"({3 * self.n_heads})"
            )
        ch = width // (3 * self.n_heads)
        # Fold heads into the batch axis, then split each head's 3*ch channels
        # into consecutive Q, K, V chunks of ch channels each.
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        # Scale q and k by ch**-0.25 each (1/sqrt(ch) in total) before the
        # product; more stable with f16 than dividing the logits afterwards.
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum("bct,bcs->bts", q * scale, k * scale)
        # Softmax computed in float32, then cast back to the input dtype.
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        # Unfold heads back into the channel axis: [N x (H * C) x T].
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        """
        Delegate FLOP counting to ``count_flops_attn`` (defined elsewhere in
        this module). NOTE(review): the (model, inputs, outputs) signature
        presumably matches a profiler custom-op hook — confirm against caller.
        """
        return count_flops_attn(model, _x, y)
| 359 | |
| 360 | |
| 361 | class QKVAttention(nn.Module): |