| 357 | |
| 358 | |
class Attention(nn.Module):
    """Single-head spatial self-attention block with a residual connection.

    The input feature map is normalised, projected to queries / keys / values
    with 1x1 convolutions, attended over all ``H * W`` spatial positions,
    projected back with another 1x1 convolution and added to the input.

    Args:
        in_channels: Number of channels ``C`` of the input feature map. The
            query, key, value and output projections all keep ``C`` channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        # ``GroupNorm`` is a helper defined elsewhere in this module; presumably
        # it builds a group-normalisation layer over ``in_channels`` channels.
        self.norm = GroupNorm(in_channels)
        # 1x1 convolutions act as per-position linear projections.
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def attention(self, h_: torch.Tensor) -> torch.Tensor:
        """Normalise ``h_`` and apply single-head self-attention over its pixels.

        Args:
            h_: Feature map of shape ``(B, C, H, W)``.

        Returns:
            Attention output of shape ``(B, C, H, W)``, before ``proj_out``
            and before the residual addition.
        """
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        B, C, H, W = q.shape
        # Flatten the spatial grid into a token sequence:
        # (B, C, H, W) -> (B, H*W, C). With a single head no head split or
        # merge is needed. ``rearrange`` returns a permuted view, so
        # ``.contiguous()`` keeps the inputs to ``chunked_attention``
        # contiguous (it may rely on that — confirm before removing).
        q, k, v = [
            rearrange(t, "b c h w -> b (h w) c").contiguous() for t in (q, k, v)
        ]

        # ``chunked_attention`` is defined elsewhere in this module; with
        # ``batch_chunk=1`` it presumably processes one batch element at a
        # time to bound peak memory. Expected to return (B, H*W, C).
        out = chunked_attention(q, k, v, batch_chunk=1)

        # (B, H*W, C) -> (B, C, H, W). The explicit sizes also make einops
        # verify that ``out`` has the expected shape.
        return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)

    def forward(self, x, **kwargs):
        """Return ``x + proj_out(attention(x))``.

        Args:
            x: Feature map of shape ``(B, C, H, W)``.
            **kwargs: Ignored. Presumably accepted so callers can pass the
                same keyword arguments to every block (verify against callers).

        Returns:
            Tensor with the same shape as ``x``.
        """
        h_ = self.attention(x)
        h_ = self.proj_out(h_)
        return x + h_
| 414 | |
| 415 | |
| 416 | class VideoDecoder(nn.Module): |