MCPcopy
hub / github.com/lllyasviel/Paints-UNDO / Attention

Class Attention

diffusers_vdm/vae.py:359–413  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

357
358
class Attention(nn.Module):
    """Single-head spatial self-attention block with a residual connection.

    The input feature map is passed through ``self.norm`` and projected to
    queries, keys and values with 1x1 convolutions. The three tensors are
    flattened to one token per spatial position and combined by
    ``chunked_attention``. The result goes through a 1x1 output projection and
    is added back onto the input.

    Args:
        in_channels: Number of channels ``C`` of the input feature map. Every
            projection maps ``C -> C`` channels, so ``forward`` returns a
            tensor that can be added to its input.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        # GroupNorm is defined elsewhere in this module; its configuration
        # (e.g. number of groups) is not visible here.
        self.norm = GroupNorm(in_channels)
        # kernel_size=1 / stride=1 / padding=0: per-pixel linear projections
        # that preserve the spatial size.
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def attention(self, h_: torch.Tensor) -> torch.Tensor:
        """Normalise ``h_`` and run self-attention across its spatial positions.

        Args:
            h_: 4-D feature map ``(B, C, H, W)`` (the rank is fixed by the
                ``B, C, H, W = q.shape`` unpacking below).

        Returns:
            The attention output, reshaped back to ``(B, C, H, W)``. The output
            projection and the residual add are applied by ``forward``.
        """
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        B, C, H, W = q.shape
        # Flatten the spatial grid into a token axis: (B, C, H, W) -> (B, H*W, C).
        # Only one attention head is used, so no head split/merge is needed.
        # .contiguous() hands chunked_attention densely laid-out tensors.
        q, k, v = (
            rearrange(t, "b c h w -> b (h w) c").contiguous() for t in (q, k, v)
        )

        # NOTE(review): chunked_attention is defined elsewhere; it is assumed to
        # take and return (batch, tokens, channels) tensors and to process
        # `batch_chunk` batch entries at a time — confirm against its definition.
        out = chunked_attention(q, k, v, batch_chunk=1)

        # Restore the spatial layout; the explicit axis sizes make einops
        # verify that the output has the expected (B, H*W, C) shape.
        return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)

    def forward(self, x, **kwargs):
        """Return ``x + proj_out(attention(x))``.

        Any extra keyword arguments are accepted for call-site compatibility
        and ignored.
        """
        return x + self.proj_out(self.attention(x))
414
415
416class VideoDecoder(nn.Module):

Callers 2

__init__ (method) · 0.85
__init__ (method) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected