x: [B, L, C].
(self, x)
| 184 | nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout)) |
| 185 | |
def forward(self, x):
    """
    Pool a token sequence into one vector per sample via attention.

    x:  [B, L, C].

    A single query derived from ``self.cls_embedding`` attends over the
    keys/values projected from ``x``; the attended result goes through
    the output projection, dropout and a residual MLP, and the length-1
    sequence axis is indexed away before returning.
    """
    batch, seq_len, channels = x.size()
    heads, head_dim = self.num_heads, self.head_dim

    # query: one token taken from cls_embedding, broadcast over the batch
    query = self.to_q(self.cls_embedding)
    query = query.view(1, 1, heads, head_dim).expand(batch, -1, -1, -1)

    # key / value: one fused projection, split along the size-2 axis
    kv = self.to_kv(x).view(batch, seq_len, 2, heads, head_dim)
    key, value = kv.unbind(2)

    # attention, then fold the heads back into the channel axis
    pooled = flash_attention(query, key, value, version=2)
    pooled = pooled.reshape(batch, 1, channels)

    # output projection followed by dropout (active only in training mode)
    pooled = F.dropout(self.proj(pooled), self.proj_dropout, self.training)

    # residual mlp on the normalized activations
    pooled = pooled + self.mlp(self.norm(pooled))

    # drop the singleton sequence axis
    return pooled[:, 0]
| 207 | |
| 208 | |
| 209 | class VisionTransformer(nn.Module): |
nothing calls this directly
no test coverage detected