Causal self-attention with a single head.
| 221 | |
| 222 | |
class AttentionBlock(nn.Module):
    """Single-head self-attention applied to each frame of a video tensor.

    The (b, c, t, h, w) input is folded so that every frame becomes an
    independent (c, h, w) image. Attention then runs over the h * w spatial
    positions of that frame only. Frames never attend to one another, so the
    block adds no dependency along the time axis; this is the sense in which
    it is "causal". The attention output is added back onto the input as a
    residual.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # Pre-attention normalisation, a fused 1x1 projection producing
        # query/key/value (3 * dim channels), and a 1x1 output projection.
        # Creation order is kept stable (affects parameter init order and
        # state_dict layout).
        self.norm = RMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

        # Zero the output projection's weight so that, at initialisation, the
        # attention branch contributes only the projection bias on top of the
        # residual. NOTE: only the weight is zeroed; the bias keeps its
        # default initialisation.
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        """Apply per-frame spatial self-attention with a residual connection.

        Args:
            x: 5-D tensor laid out as (b, c, t, h, w); c must equal
                ``self.dim`` because the 1x1 convolutions expect that many
                input channels.

        Returns:
            Tensor with the same shape as ``x``: ``x`` plus the projected
            attention output.
        """
        residual = x
        batch, channels, frames, height, width = x.size()
        n_images = batch * frames

        # Fold time into the batch so each frame is processed on its own.
        frames_2d = self.norm(rearrange(x, 'b c t h w -> (b t) c h w'))

        # (n, 3c, h, w) -> (n, 1, h*w, 3c): a single head with one token per
        # spatial position. The channel axis is then split into q, k, v.
        qkv = self.to_qkv(frames_2d).reshape(n_images, 1, channels * 3, -1)
        qkv = qkv.permute(0, 1, 3, 2).contiguous()
        query, key, value = qkv.chunk(3, dim=-1)

        # No mask is passed: every position attends to every other position
        # within the same frame.
        attended = F.scaled_dot_product_attention(query, key, value)

        # (n, 1, h*w, c) -> (n, c, h, w)
        attended = attended.squeeze(1).permute(0, 2, 1)
        attended = attended.reshape(n_images, channels, height, width)

        out = rearrange(self.proj(attended), '(b t) c h w-> b c t h w', t=frames)
        return out + residual
| 263 | |
| 264 | |
| 265 | class Encoder3d(nn.Module): |