x: [B, L, C].
(self, x)
| 72 | self.proj = nn.Linear(dim, dim) |
| 73 | |
| 74 | def forward(self, x): |
| 75 | """ |
| 76 | x: [B, L, C]. |
| 77 | """ |
| 78 | b, s, c, n, d = *x.size(), self.num_heads, self.head_dim |
| 79 | |
| 80 | # compute query, key, value |
| 81 | q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2) |
| 82 | |
| 83 | # compute attention |
| 84 | p = self.attn_dropout if self.training else 0.0 |
| 85 | x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2) |
| 86 | x = x.reshape(b, s, c) |
| 87 | |
| 88 | # output |
| 89 | x = self.proj(x) |
| 90 | x = F.dropout(x, self.proj_dropout, self.training) |
| 91 | return x |
| 92 | |
| 93 | |
| 94 | class SwiGLU(nn.Module): |
nothing calls this directly
no test coverage detected