An attention block that allows spatial positions to attend to each other. Adapted to the N-d case; originally ported from: https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66
| 198 | |
| 199 | |
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Adapted to the N-d case; originally ported from:
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66

    The input is treated as ``(batch, channels, *spatial)``: every trailing
    dimension after the first two is flattened into a single axis before
    attention and restored afterwards, so the output has the input's shape.

    :param channels: number of channels of the input (and of the output).
    :param num_heads: number of attention heads; must be a positive divisor
        of ``channels`` so the channels split evenly across heads.
    :param use_checkpoint: forwarded unchanged to ``checkpoint`` in
        ``forward`` (presumably toggles gradient checkpointing -- confirm
        against the helper's definition).
    :raises ValueError: if ``num_heads`` is not a positive divisor of
        ``channels``.
    """

    def __init__(self, channels, num_heads=1, use_checkpoint=False):
        super().__init__()
        # Fail early with a clear message: _forward folds the heads into the
        # batch axis via reshape(b * num_heads, -1, T), which cannot give each
        # head an equal slice of q, k and v unless num_heads divides channels.
        if num_heads < 1 or channels % num_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by a positive "
                f"num_heads (got {num_heads})"
            )
        self.channels = channels
        self.num_heads = num_heads
        self.use_checkpoint = use_checkpoint

        self.norm = normalization(channels)
        # One projection emits 3 * channels outputs (q, k and v together).
        # NOTE(review): conv_nd(1, ..., 1) is presumably a 1-D convolution
        # with kernel size 1 -- confirm against conv_nd.
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.attention = QKVAttention()
        # NOTE(review): zero_module presumably zero-initialises the output
        # projection so the block starts as an identity mapping -- confirm.
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        """
        Apply the attention block to ``x`` via the ``checkpoint`` helper.

        :param x: tensor whose first two dimensions are batch and channels.
        :return: whatever ``checkpoint`` returns for ``_forward`` (a tensor
            with the same shape as ``x`` when it simply calls ``_forward``).
        """
        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)

    def _forward(self, x):
        """
        Compute ``x + proj_out(attention(qkv(norm(x))))`` over flattened
        spatial positions and return it in the original shape of ``x``.
        """
        b, c, *spatial = x.shape
        # Collapse all spatial dimensions into one axis of length T.
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        # Fold the heads into the batch axis: (b * num_heads, -1, T).
        qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
        h = self.attention(qkv)
        # Unfold the heads back into the channel axis: (b, -1, T).
        h = h.reshape(b, -1, h.shape[-1])
        h = self.proj_out(h)
        # Residual connection on the flattened (un-normalised) input, then
        # restore the original spatial layout.
        return (x + h).reshape(b, c, *spatial)
| 231 | |
| 232 | |
| 233 | class QKVAttention(nn.Module): |