MCP · copy · Index your code
hub / github.com/openai/guided-diffusion / AttentionPool2d

Class AttentionPool2d

guided_diffusion/unet.py:22–51  ·  view source on GitHub ↗

Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py

Source from the content-addressed store, hash-verified

20
21
class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py

    Pools a feature map to one vector per example with a single attention
    step. The spatial mean is prepended as an extra token and a learned
    positional embedding is added. The tokens are then projected to q/k/v,
    passed through ``QKVAttention`` and projected by ``c_proj``. The output
    at the prepended token (index 0) is returned.

    :param spacial_dim: the input's flattened spatial size (H * W) must equal
        ``spacial_dim ** 2``; the positional embedding holds
        ``spacial_dim ** 2 + 1`` positions (one extra for the mean token).
    :param embed_dim: number of input channels.
    :param num_heads_channels: channels per attention head; must be a
        positive divisor of ``embed_dim``.
    :param output_dim: number of output channels; ``embed_dim`` if not given.
    :raises ValueError: if ``num_heads_channels`` is not a positive divisor
        of ``embed_dim``.
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        # Reject configurations that floor division would otherwise silently
        # turn into a different per-head width (or a non-positive head count).
        if num_heads_channels <= 0 or embed_dim % num_heads_channels != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be a positive multiple of "
                f"num_heads_channels ({num_heads_channels})"
            )
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
        )
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        """
        Pool ``x`` to a single vector per example.

        :param x: an [N x C x ...] Tensor with C == embed_dim whose spatial
            dimensions flatten to ``spacial_dim ** 2`` positions.
        :return: the ``c_proj`` output at token 0 (the prepended mean token);
            presumably an [N x (output_dim or embed_dim)] Tensor, assuming
            ``conv_nd`` and ``QKVAttention`` preserve the layout — confirm.
        """
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        # Prepend the global average as an extra token at index 0; the value
        # returned below is read back from that position.
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        # Cast the parameter to the activation dtype (e.g. under mixed
        # precision) so the addition does not change x's dtype.
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        return x[:, :, 0]
52
53
54class TimestepBlock(nn.Module):

Callers 1

__init__ (Method) · 0.85

Calls

no outgoing calls detected

Tested by

no test coverage detected