Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
| 20 | |
| 21 | |
class AttentionPool2d(nn.Module):
    """
    Pool a feature map into one vector per sample using attention.

    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py

    The input's trailing dimensions are flattened into a single axis, the
    mean over that axis is prepended as an extra position, a learned
    positional embedding is added, and the result is passed through
    ``qkv_proj``, ``QKVAttention`` and ``c_proj``. Only the slice at index 0
    of the last axis (the prepended mean position) is returned.
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        """
        Build the positional embedding, the projections and the attention op.

        :param spacial_dim: side length used to size the positional embedding,
                            which holds ``spacial_dim ** 2 + 1`` positions.
        :param embed_dim: channel count of the input and of the embedding.
        :param num_heads_channels: divisor of ``embed_dim`` used to derive the
                                   number of heads (integer division).
        :param output_dim: channel count produced by ``c_proj``; a falsy value
                           (including the default ``None``) falls back to
                           ``embed_dim``.
        """
        super().__init__()

        # One slot per flattened position plus one for the prepended mean.
        position_count = spacial_dim ** 2 + 1
        init_scale = embed_dim ** 0.5
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, position_count) / init_scale
        )

        # NOTE(review): conv_nd(1, ..., 1) is presumably a 1-D convolution
        # with kernel size 1 -- confirm against the helper's definition.
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)

        projected_channels = output_dim or embed_dim
        self.c_proj = conv_nd(1, embed_dim, projected_channels, 1)

        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        """
        Apply attention pooling to ``x``.

        :param x: a tensor with at least two dimensions; everything after the
                  first two is flattened into one axis. The flattened length
                  plus one must be compatible with the positional embedding.
        :return: the output of ``c_proj`` sliced at index 0 of its last axis.
        """
        batch, channels, *_ = x.shape

        # N x C x (HW): collapse every trailing dimension into one axis.
        flat = x.reshape(batch, channels, -1)

        # N x C x (HW+1): the mean over positions goes in front, at index 0.
        mean_position = flat.mean(dim=-1, keepdim=True)
        sequence = th.cat([mean_position, flat], dim=-1)

        # Broadcast the embedding over the batch, cast to the input's dtype.
        embedding = self.positional_embedding.unsqueeze(0).to(sequence.dtype)
        sequence = sequence + embedding

        qkv = self.qkv_proj(sequence)
        attended = self.attention(qkv)
        projected = self.c_proj(attended)

        # Keep only the position that was prepended above.
        return projected[:, :, 0]
| 52 | |
| 53 | |
| 54 | class TimestepBlock(nn.Module): |